From 91d11babd58ca955f451efb19a76ef5f3d2e9ea5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 15 Mar 2022 20:20:36 +0800 Subject: [PATCH 1/4] enhance 1st conv layer of ResNet Signed-off-by: Yiheng Wang --- monai/networks/nets/resnet.py | 35 +++++++++++++++++++-------------- tests/test_resnet.py | 37 +++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index a263c8e8b3..c4e4e024ea 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -10,7 +10,7 @@ # limitations under the License. from functools import partial -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -32,14 +32,6 @@ def get_avgpool(): return [0, 1, (1, 1), (1, 1, 1)] -def get_conv1(conv1_t_size: int, conv1_t_stride: int): - return ( - [0, conv1_t_size, (conv1_t_size, 7), (conv1_t_size, 7, 7)], - [0, conv1_t_stride, (conv1_t_stride, 2), (conv1_t_stride, 2, 2)], - [0, (conv1_t_size // 2), (conv1_t_size // 2, 3), (conv1_t_size // 2, 3, 3)], - ) - - class ResNetBlock(nn.Module): expansion = 1 @@ -184,8 +176,8 @@ def __init__( block_inplanes: List[int], spatial_dims: int = 3, n_input_channels: int = 3, - conv1_t_size: int = 7, - conv1_t_stride: int = 1, + conv1_t_size: Union[Tuple[int], int] = 7, + conv1_t_stride: Union[Tuple[int], int] = 1, no_max_pool: bool = False, shortcut_type: str = "B", widen_factor: float = 1.0, @@ -207,18 +199,31 @@ def __init__( ] block_avgpool = get_avgpool() - conv1_kernel, conv1_stride, conv1_padding = get_conv1(conv1_t_size, conv1_t_stride) block_inplanes = [int(x * widen_factor) for x in block_inplanes] self.in_planes = block_inplanes[0] self.no_max_pool = no_max_pool + if isinstance(conv1_t_size, int): + conv1_kernel_size = (conv1_t_size,) * spatial_dims + else: + if len(conv1_t_size) != spatial_dims: + raise ValueError("Tuple conv1_t_size should have length {spatial_dims}.") + conv1_kernel_size = conv1_t_size + + if isinstance(conv1_t_stride, int): + conv1_stride = (conv1_t_stride,) * spatial_dims + else: + if len(conv1_t_stride) != spatial_dims: + raise ValueError("Tuple conv1_t_stride should have length {spatial_dims}.") + conv1_stride = conv1_t_stride + self.conv1 = conv_type( n_input_channels, self.in_planes, - kernel_size=conv1_kernel[spatial_dims], - stride=conv1_stride[spatial_dims], - padding=conv1_padding[spatial_dims], + kernel_size=conv1_kernel_size, # type: ignore + stride=conv1_stride, # type: ignore + padding=tuple([k // 2 for k in conv1_kernel_size]), # type: ignore bias=False, ) self.bn1 = norm_type(self.in_planes) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index ffb48125a4..688f7827b1 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -31,25 +31,54 @@ device = "cuda" if torch.cuda.is_available() else "cpu" TEST_CASE_1 = [ # 3D, batch 3, 2 input channel - {"pretrained": False, "spatial_dims": 3, "n_input_channels": 2, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 3, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": 7, + "conv1_t_stride": (2, 2, 2), + }, (3, 2, 32, 64, 48), (3, 3), ] TEST_CASE_2 = [ # 2D, batch 2, 1 input channel - {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 2, + "n_input_channels": 1, + "num_classes": 3, + "conv1_t_size": [7, 7], + "conv1_t_stride": [2, 2], + }, (2, 1, 32, 64), (2, 3), ] TEST_CASE_2_A = [ # 2D, batch 2, 1 input channel, shortcut type A - {"pretrained": False, "spatial_dims": 2, "n_input_channels": 1, "num_classes": 3, "shortcut_type": "A"}, + { + "pretrained": False, + "spatial_dims": 2, + "n_input_channels": 1, + "num_classes": 3, + "shortcut_type": "A", + "conv1_t_size": (7, 7), + "conv1_t_stride": 2, + }, (2, 1, 32, 64), (2, 3), ] TEST_CASE_3 = [ # 1D, batch 1, 2 input channels - {"pretrained": False, "spatial_dims": 1, "n_input_channels": 2, "num_classes": 3}, + { + "pretrained": False, + "spatial_dims": 1, + "n_input_channels": 2, + "num_classes": 3, + "conv1_t_size": [3], + "conv1_t_stride": 1, + }, (1, 2, 32), (1, 3), ] From 59b874aee1db70a7b929349676607b919f4551ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Mar 2022 12:21:47 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/nets/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index c4e4e024ea..67b7d828ba 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -223,7 +223,7 @@ def __init__( self.in_planes, kernel_size=conv1_kernel_size, # type: ignore stride=conv1_stride, # type: ignore - padding=tuple([k // 2 for k in conv1_kernel_size]), # type: ignore + padding=tuple(k // 2 for k in conv1_kernel_size), # type: ignore bias=False, ) self.bn1 = norm_type(self.in_planes) From 243aeac503891fdd549fce720dabb1caf395d8e1 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 15 Mar 2022 20:38:55 +0800 Subject: [PATCH 3/4] modify net adapeter unittest Signed-off-by: Yiheng Wang --- tests/test_net_adapter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py index 0d73499a6d..39201fb600 100644 --- a/tests/test_net_adapter.py +++ b/tests/test_net_adapter.py @@ -41,7 +41,9 @@ class TestNetAdapter(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_shape, expected_shape): - model = resnet18(spatial_dims=input_param["dim"]) + spatial_dims = input_param["dim"] + stride = (1, 2, 2)[:spatial_dims] + model = resnet18(spatial_dims=spatial_dims, conv1_t_stride=stride) input_param["model"] = model net = NetAdapter(**input_param).to(device) with eval_mode(net): From b9dae1ad944272a83838b51ca828f4d1f2a00be2 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 15 Mar 2022 20:44:51 +0800 Subject: [PATCH 4/4] use ensure_tuple_rep Signed-off-by: Yiheng Wang --- monai/networks/nets/resnet.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 67b7d828ba..c8be9f0e89 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -17,6 +17,7 @@ from monai.networks.layers.factories import Conv, Norm, Pool from monai.networks.layers.utils import get_pool_layer +from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option __all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] @@ -204,19 +205,8 @@ def __init__( self.in_planes = block_inplanes[0] self.no_max_pool = no_max_pool - if isinstance(conv1_t_size, int): - conv1_kernel_size = (conv1_t_size,) * spatial_dims - else: - if len(conv1_t_size) != spatial_dims: - raise ValueError("Tuple conv1_t_size should have length {spatial_dims}.") - conv1_kernel_size = conv1_t_size - - if isinstance(conv1_t_stride, int): - conv1_stride = (conv1_t_stride,) * spatial_dims - else: - if len(conv1_t_stride) != spatial_dims: - raise ValueError("Tuple conv1_t_stride should have length {spatial_dims}.") - conv1_stride = conv1_t_stride + conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims) + conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims) self.conv1 = conv_type( n_input_channels,