From 024b73516bdc5c1399a01a5c336cf30b56061b55 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Thu, 18 Mar 2021 21:25:34 +0800 Subject: [PATCH 1/7] add pretrain options Signed-off-by: Yiheng Wang --- monai/networks/nets/densenet.py | 112 ++++++++++++++--------- monai/networks/nets/senet.py | 154 ++++++++++++++++++-------------- 2 files changed, 155 insertions(+), 111 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index a59ab99e68..6db5233890 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -115,6 +115,11 @@ class DenseNet(nn.Module): bn_size: multiplicative factor for number of bottle neck layers. (i.e. bn_size * k features in the bottleneck layer) dropout_prob: dropout rate after each dense layer. + pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`. + In order to load weights correctly, Please ensure that the `block_config` + is consistent with the corresponding arch. + pretrained_arch: the arch name for pretrained weights. + progress: If True, displays a progress bar of the download to stderr. """ def __init__( @@ -127,6 +132,9 @@ def __init__( block_config: Sequence[int] = (6, 12, 24, 16), bn_size: int = 4, dropout_prob: float = 0.0, + pretrained: bool = False, + pretrained_arch: str = "densenet121", + progress: bool = True, ) -> None: super(DenseNet, self).__init__() @@ -190,43 +198,48 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(torch.as_tensor(m.bias), 0) + if pretrained: + self._load_state_dict(pretrained_arch, progress) + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.class_layers(x) return x + def _load_state_dict(self, arch, progress): + """ + This function is used to load pretrained models. + Adapted from `PyTorch Hub 2D version + `_ + """ + model_urls = { + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + } + if arch in model_urls.keys(): + model_url = model_urls[arch] + else: + error_msg = "only densenet121, densenet169 and densenet201 are supported to load pretrained weights." + raise AssertionError(error_msg) + pattern = re.compile( + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) -model_urls = { - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", -} - - -def _load_state_dict(model, model_url, progress): - """ - This function is used to load pretrained models. - Adapted from `PyTorch Hub 2D version - `_ - """ - pattern = re.compile( - r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) - state_dict[new_key] = state_dict[key] - del state_dict[key] + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) + state_dict[new_key] = state_dict[key] + del state_dict[key] - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) + model_dict = self.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + self.load_state_dict(model_dict) def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet: @@ -235,10 +248,15 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> De from `PyTorch Hub 2D version `_ """ - model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) - if pretrained: - arch = "densenet121" - _load_state_dict(model, model_urls[arch], progress) + model = DenseNet( + init_features=64, + growth_rate=32, + block_config=(6, 12, 24, 16), + pretrained=pretrained, + pretrained_arch="densenet121", + progress=progress, + **kwargs, + ) return model @@ -248,10 +266,15 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs) -> De from `PyTorch Hub 2D version `_ """ - model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) - if pretrained: - arch = "densenet169" - _load_state_dict(model, model_urls[arch], progress) + model = DenseNet( + init_features=64, + growth_rate=32, + block_config=(6, 12, 32, 32), + pretrained=pretrained, + pretrained_arch="densenet169", + progress=progress, + **kwargs, + ) return model @@ -261,10 +284,15 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs) -> De from `PyTorch Hub 2D version `_ """ - model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) - if pretrained: - arch = "densenet201" - _load_state_dict(model, model_urls[arch], progress) + model = DenseNet( + init_features=64, + growth_rate=32, + block_config=(6, 12, 48, 32), + pretrained=pretrained, + pretrained_arch="densenet201", + progress=progress, + **kwargs, + ) return model diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index ef67f853d6..d6a657ae03 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -66,7 +66,11 @@ class SENet(nn.Module): - For SE-ResNeXt models: False num_classes: number of outputs in `last_linear` layer. for all models: 1000 - + pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`. + In order to load weights correctly, Please ensure that the `block_config` + is consistent with the corresponding arch. + pretrained_arch: the arch name for pretrained weights. + progress: If True, displays a progress bar of the download to stderr. """ def __init__( @@ -83,6 +87,9 @@ def __init__( downsample_kernel_size: int = 3, input_3x3: bool = True, num_classes: int = 1000, + pretrained: bool = False, + pretrained_arch: str = "se_resnet50", + progress: bool = True, ) -> None: super(SENet, self).__init__() @@ -176,6 +183,65 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(torch.as_tensor(m.bias), 0) + if pretrained: + self._load_state_dict(pretrained_arch, progress) + + def _load_state_dict(self, arch, progress): + """ + This function is used to load pretrained models. + """ + model_urls = { + "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", + } + if arch in model_urls.keys(): + model_url = model_urls[arch] + else: + error_msg = ( + "only senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d " + + "and se_resnext101_32x4d are supported to load pretrained weights." + ) + raise AssertionError(error_msg) + + pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") + pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") + pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") + pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") + pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") + pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + new_key = None + if pattern_conv.match(key): + new_key = re.sub(pattern_conv, r"\1conv.\2", key) + elif pattern_bn.match(key): + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) + elif pattern_se.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) + elif pattern_se2.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) + elif pattern_down_conv.match(key): + new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) + elif pattern_down_bn.match(key): + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) + if new_key: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = self.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + self.load_state_dict(model_dict) + def _make_layer( self, block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], @@ -248,56 +314,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -model_urls = { - "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", - "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", - "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", - "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", - "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", - "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", -} - - -def _load_state_dict(model, model_url, progress): - """ - This function is used to load pretrained models. - """ - pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") - pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") - pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") - pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") - pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") - pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - new_key = None - if pattern_conv.match(key): - new_key = re.sub(pattern_conv, r"\1conv.\2", key) - elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) - elif pattern_se.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) - elif pattern_se2.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) - elif pattern_down_conv.match(key): - new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) - elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) - if new_key: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = model.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - model.load_state_dict(model_dict) - - def senet154( spatial_dims: int, in_channels: int, @@ -320,10 +336,10 @@ def senet154( dropout_prob=0.2, dropout_dim=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="senet154", + progress=progress, ) - if pretrained: - arch = "senet154" - _load_state_dict(model, model_urls[arch], progress) return model @@ -347,10 +363,10 @@ def se_resnet50( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet50", + progress=progress, ) - if pretrained: - arch = "se_resnet50" - _load_state_dict(model, model_urls[arch], progress) return model @@ -375,10 +391,10 @@ def se_resnet101( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet101", + progress=progress, ) - if pretrained: - arch = "se_resnet101" - _load_state_dict(model, model_urls[arch], progress) return model @@ -403,10 +419,10 @@ def se_resnet152( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnet152", + progress=progress, ) - if pretrained: - arch = "se_resnet152" - _load_state_dict(model, model_urls[arch], progress) return model @@ -430,10 +446,10 @@ def se_resnext50_32x4d( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnext50_32x4d", + progress=progress, ) - if pretrained: - arch = "se_resnext50_32x4d" - _load_state_dict(model, model_urls[arch], progress) return model @@ -457,8 +473,8 @@ def se_resnext101_32x4d( input_3x3=False, downsample_kernel_size=1, num_classes=num_classes, + pretrained=pretrained, + pretrained_arch="se_resnext101_32x4d", + progress=progress, ) - if pretrained: - arch = "se_resnext101_32x4d" - _load_state_dict(model, model_urls[arch], progress) return model From 1fe5d1def266eeef3e448fce2cfc9ac94c9812f1 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 19 Mar 2021 13:56:34 +0800 Subject: [PATCH 2/7] rewrite error message add test cases Signed-off-by: Yiheng Wang --- monai/networks/nets/densenet.py | 5 +++-- monai/networks/nets/senet.py | 7 +++---- tests/test_densenet.py | 18 +++++++++++++++++- tests/test_senet.py | 26 +++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 6db5233890..4b4f2cc6a4 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -220,8 +220,9 @@ def _load_state_dict(self, arch, progress): if arch in model_urls.keys(): model_url = model_urls[arch] else: - error_msg = "only densenet121, densenet169 and densenet201 are supported to load pretrained weights." - raise AssertionError(error_msg) + raise ValueError( + "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights." + ) pattern = re.compile( r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index d6a657ae03..333a3b1159 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -201,11 +201,10 @@ def _load_state_dict(self, arch, progress): if arch in model_urls.keys(): model_url = model_urls[arch] else: - error_msg = ( - "only senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d " - + "and se_resnext101_32x4d are supported to load pretrained weights." + raise ValueError( + "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', \ + and se_resnext101_32x4d are supported to load pretrained weights." ) - raise AssertionError(error_msg) pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 41b5fbf7d6..5ead5f5818 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import densenet121, densenet169, densenet201, densenet264 +from monai.networks.nets import DenseNet, densenet121, densenet169, densenet201, densenet264 from monai.utils import optional_import from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save @@ -78,6 +78,17 @@ (1, 3, 32, 32), ] +TEST_PRETRAINED_2D_CASE_4 = [ + { + "pretrained": True, + "pretrained_arch": "densenet264", + "progress": False, + "spatial_dims": 2, + "in_channels": 3, + "out_channels": 1, + }, +] + class TestPretrainedDENSENET(unittest.TestCase): @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) @@ -100,6 +111,11 @@ def test_pretrain_consistency(self, model, input_param, input_shape): expected_result = torchvision_net.features.forward(example) self.assertTrue(torch.all(result == expected_result)) + @parameterized.expand([TEST_PRETRAINED_2D_CASE_4]) + def test_ill_pretrain(self, input_param): + with self.assertRaisesRegex(ValueError, ""): + net = DenseNet(**input_param) + class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) diff --git a/tests/test_senet.py b/tests/test_senet.py index c1327ceb7d..a2d96e1f18 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -17,7 +17,9 @@ from parameterized import parameterized from monai.networks import eval_mode +from monai.networks.blocks.squeeze_and_excitation import SEBottleneck from monai.networks.nets import ( + SENet, se_resnet50, se_resnet101, se_resnet152, @@ -46,7 +48,20 @@ TEST_CASE_5 = [se_resnext50_32x4d, NET_ARGS] TEST_CASE_6 = [se_resnext101_32x4d, NET_ARGS] -TEST_CASE_PRETRAINED = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}] +TEST_CASE_PRETRAINED_1 = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}] +TEST_CASE_PRETRAINED_2 = [ + { + "spatial_dims": 2, + "in_channels": 3, + "block": SEBottleneck, + "layers": [3, 8, 36, 3], + "groups": 64, + "reduction": 16, + "num_classes": 2, + "pretrained": True, + "pretrained_arch": "resnet50", + } +] class TestSENET(unittest.TestCase): @@ -67,7 +82,7 @@ def test_script(self, net, net_args): class TestPretrainedSENET(unittest.TestCase): - @parameterized.expand([TEST_CASE_PRETRAINED]) + @parameterized.expand([TEST_CASE_PRETRAINED_1]) def test_senet_shape(self, model, input_param): net = test_pretrained_networks(model, input_param, device) input_data = torch.randn(3, 3, 64, 64).to(device) @@ -77,7 +92,7 @@ def test_senet_shape(self, model, input_param): result = net(input_data) self.assertEqual(result.shape, expected_shape) - @parameterized.expand([TEST_CASE_PRETRAINED]) + @parameterized.expand([TEST_CASE_PRETRAINED_1]) @skipUnless(has_cadene_pretrain, "Requires `pretrainedmodels` package.") def test_pretrain_consistency(self, model, input_param): input_data = torch.randn(1, 3, 64, 64).to(device) @@ -92,6 +107,11 @@ def test_pretrain_consistency(self, model, input_param): # a conv layer with kernel size equals to 1. It may bring a little difference. self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5)) + @parameterized.expand([TEST_CASE_PRETRAINED_2]) + def test_ill_pretrain(self, input_param): + with self.assertRaisesRegex(ValueError, ""): + net = SENet(**input_param) + if __name__ == "__main__": unittest.main() From 16b2a9b4375b34903da65f754a4b7a89fc125e6c Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 20 Mar 2021 00:10:33 +0800 Subject: [PATCH 3/7] Use subclass for densenet Signed-off-by: Yiheng Wang --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/densenet.py | 198 +++++++++++++++++--------------- tests/test_densenet.py | 34 ++---- 3 files changed, 115 insertions(+), 119 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index f3def30736..c236df11db 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -13,7 +13,7 @@ from .autoencoder import AutoEncoder from .basic_unet import BasicUNet, BasicUnet, Basicunet from .classifier import Classifier, Critic, Discriminator -from .densenet import DenseNet, densenet121, densenet169, densenet201, densenet264 +from .densenet import DenseNet, DenseNet121, DenseNet169, DenseNet201, DenseNet264 from .dynunet import DynUNet, DynUnet, Dynunet from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet from .generator import Generator diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index 4b4f2cc6a4..d8ad8f89d3 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -115,11 +115,6 @@ class DenseNet(nn.Module): bn_size: multiplicative factor for number of bottle neck layers. (i.e. bn_size * k features in the bottleneck layer) dropout_prob: dropout rate after each dense layer. - pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`. - In order to load weights correctly, Please ensure that the `block_config` - is consistent with the corresponding arch. - pretrained_arch: the arch name for pretrained weights. - progress: If True, displays a progress bar of the download to stderr. """ def __init__( @@ -132,9 +127,6 @@ def __init__( block_config: Sequence[int] = (6, 12, 24, 16), bn_size: int = 4, dropout_prob: float = 0.0, - pretrained: bool = False, - pretrained_arch: str = "densenet121", - progress: bool = True, ) -> None: super(DenseNet, self).__init__() @@ -198,107 +190,127 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(torch.as_tensor(m.bias), 0) - if pretrained: - self._load_state_dict(pretrained_arch, progress) - def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.class_layers(x) return x - def _load_state_dict(self, arch, progress): - """ - This function is used to load pretrained models. - Adapted from `PyTorch Hub 2D version - `_ - """ - model_urls = { - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - } - if arch in model_urls.keys(): - model_url = model_urls[arch] - else: - raise ValueError( - "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights." - ) - pattern = re.compile( - r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + +def _load_state_dict(model, arch, progress): + """ + This function is used to load pretrained models. + Adapted from `PyTorch Hub 2D version + `_ + """ + model_urls = { + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + } + if arch in model_urls.keys(): + model_url = model_urls[arch] + else: + raise ValueError( + "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights." ) + pattern = re.compile( + r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) - state_dict[new_key] = state_dict[key] - del state_dict[key] + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + ".layers" + res.group(2) + res.group(3) + state_dict[new_key] = state_dict[key] + del state_dict[key] - model_dict = self.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - self.load_state_dict(model_dict) + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) -def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `PyTorch Hub 2D version - `_ - """ - model = DenseNet( - init_features=64, - growth_rate=32, - block_config=(6, 12, 24, 16), - pretrained=pretrained, - pretrained_arch="densenet121", - progress=progress, +class DenseNet121(DenseNet): + def __init__( + self, + init_features: int = 64, + growth_rate: int = 32, + block_config: Sequence[int] = (6, 12, 24, 16), + pretrained: bool = False, + progress: bool = True, **kwargs, - ) - return model + ) -> None: + super(DenseNet121, self).__init__( + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "densenet121", progress) -def densenet169(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `PyTorch Hub 2D version - `_ - """ - model = DenseNet( - init_features=64, - growth_rate=32, - block_config=(6, 12, 32, 32), - pretrained=pretrained, - pretrained_arch="densenet169", - progress=progress, +class DenseNet169(DenseNet): + def __init__( + self, + init_features: int = 64, + growth_rate: int = 32, + block_config: Sequence[int] = (6, 12, 32, 32), + pretrained: bool = False, + progress: bool = True, **kwargs, - ) - return model + ) -> None: + super(DenseNet169, self).__init__( + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "densenet169", progress) -def densenet201(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `PyTorch Hub 2D version - `_ - """ - model = DenseNet( - init_features=64, - growth_rate=32, - block_config=(6, 12, 48, 32), - pretrained=pretrained, - pretrained_arch="densenet201", - progress=progress, +class DenseNet201(DenseNet): + def __init__( + self, + init_features: int = 64, + growth_rate: int = 32, + block_config: Sequence[int] = (6, 12, 48, 32), + pretrained: bool = False, + progress: bool = True, **kwargs, - ) - return model + ) -> None: + super(DenseNet201, self).__init__( + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "densenet201", progress) -def densenet264(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet: - model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs) - if pretrained: - print("Currently PyTorch Hub does not provide densenet264 pretrained models.") - return model +class DenseNet264(DenseNet): + def __init__( + self, + init_features: int = 64, + growth_rate: int = 32, + block_config: Sequence[int] = (6, 12, 48, 32), + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(DenseNet264, self).__init__( + init_features=init_features, + growth_rate=growth_rate, + block_config=block_config, + **kwargs, + ) + if pretrained: + print("Currently PyTorch Hub does not provide densenet264 pretrained models.") diff --git a/tests/test_densenet.py b/tests/test_densenet.py index 5ead5f5818..c934841598 100644 --- a/tests/test_densenet.py +++ b/tests/test_densenet.py @@ -17,9 +17,9 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import DenseNet, densenet121, densenet169, densenet201, densenet264 +from monai.networks.nets import DenseNet121, DenseNet169, DenseNet201, DenseNet264 from monai.utils import optional_import -from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save +from tests.utils import skip_if_quick, test_script_save if TYPE_CHECKING: import torchvision @@ -51,50 +51,39 @@ TEST_CASES = [] for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]: - for model in [densenet121, densenet169, densenet201, densenet264]: + for model in [DenseNet121, DenseNet169, DenseNet201, DenseNet264]: TEST_CASES.append([model, *case]) -TEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [densenet121, densenet169, densenet201, densenet264]] +TEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [DenseNet121, DenseNet169, DenseNet201, DenseNet264]] TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2 - densenet121, + DenseNet121, {"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3}, (1, 2, 32, 64), (1, 3), ] TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2 - densenet121, + DenseNet121, {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 1}, (1, 2, 32, 64), (1, 1), ] TEST_PRETRAINED_2D_CASE_3 = [ - densenet121, + DenseNet121, {"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 3, "out_channels": 1}, (1, 3, 32, 32), ] -TEST_PRETRAINED_2D_CASE_4 = [ - { - "pretrained": True, - "pretrained_arch": "densenet264", - "progress": False, - "spatial_dims": 2, - "in_channels": 3, - "out_channels": 1, - }, -] - class TestPretrainedDENSENET(unittest.TestCase): @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2]) @skip_if_quick def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape): - net = test_pretrained_networks(model, input_param, device) + net = model(**input_param).to(device) with eval_mode(net): result = net.forward(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape) @@ -103,7 +92,7 @@ def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_s @skipUnless(has_torchvision, "Requires `torchvision` package.") def test_pretrain_consistency(self, model, input_param, input_shape): example = torch.randn(input_shape).to(device) - net = test_pretrained_networks(model, input_param, device) + net = model(**input_param).to(device) with eval_mode(net): result = net.features.forward(example) torchvision_net = torchvision.models.densenet121(pretrained=True).to(device) @@ -111,11 +100,6 @@ def test_pretrain_consistency(self, model, input_param, input_shape): expected_result = torchvision_net.features.forward(example) self.assertTrue(torch.all(result == expected_result)) - @parameterized.expand([TEST_PRETRAINED_2D_CASE_4]) - def test_ill_pretrain(self, input_param): - with self.assertRaisesRegex(ValueError, ""): - net = DenseNet(**input_param) - class TestDENSENET(unittest.TestCase): @parameterized.expand(TEST_CASES) From db480fc305140db2c48e819f85f2c9a3b9f35464 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 20 Mar 2021 17:06:52 +0800 Subject: [PATCH 4/7] Use subclass for senet Signed-off-by: Yiheng Wang --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/senet.py | 448 ++++++++++++++++---------------- tests/test_senet.py | 45 +--- 3 files changed, 230 insertions(+), 265 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c236df11db..cd00ea1aa1 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -21,7 +21,7 @@ from .regressor import Regressor from .regunet import GlobalNet, LocalNet, RegUNet from .segresnet import SegResNet, SegResNetVAE -from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 +from .senet import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101 from .unet import UNet, Unet, unet from .varautoencoder import VarAutoEncoder from .vnet import VNet diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 333a3b1159..704674a724 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -66,11 +66,6 @@ class SENet(nn.Module): - For SE-ResNeXt models: False num_classes: number of outputs in `last_linear` layer. for all models: 1000 - pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`. - In order to load weights correctly, Please ensure that the `block_config` - is consistent with the corresponding arch. - pretrained_arch: the arch name for pretrained weights. - progress: If True, displays a progress bar of the download to stderr. """ def __init__( @@ -87,9 +82,6 @@ def __init__( downsample_kernel_size: int = 3, input_3x3: bool = True, num_classes: int = 1000, - pretrained: bool = False, - pretrained_arch: str = "se_resnet50", - progress: bool = True, ) -> None: super(SENet, self).__init__() @@ -183,64 +175,6 @@ def __init__( elif isinstance(m, nn.Linear): nn.init.constant_(torch.as_tensor(m.bias), 0) - if pretrained: - self._load_state_dict(pretrained_arch, progress) - - def _load_state_dict(self, arch, progress): - """ - This function is used to load pretrained models. - """ - model_urls = { - "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", - "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", - "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", - "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", - "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", - "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", - } - if arch in model_urls.keys(): - model_url = model_urls[arch] - else: - raise ValueError( - "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', \ - and se_resnext101_32x4d are supported to load pretrained weights." - ) - - pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") - pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") - pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") - pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") - pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") - pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") - - state_dict = load_state_dict_from_url(model_url, progress=progress) - for key in list(state_dict.keys()): - new_key = None - if pattern_conv.match(key): - new_key = re.sub(pattern_conv, r"\1conv.\2", key) - elif pattern_bn.match(key): - new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) - elif pattern_se.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) - elif pattern_se2.match(key): - state_dict[key] = state_dict[key].squeeze() - new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) - elif pattern_down_conv.match(key): - new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) - elif pattern_down_bn.match(key): - new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) - if new_key: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - model_dict = self.state_dict() - state_dict = { - k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) - } - model_dict.update(state_dict) - self.load_state_dict(model_dict) - def _make_layer( self, block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], @@ -313,167 +247,225 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def senet154( - spatial_dims: int, - in_channels: int, - num_classes: int, - pretrained: bool = False, - progress: bool = True, -) -> SENet: +def _load_state_dict(model, arch, progress): """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEBottleneck, - layers=[3, 8, 36, 3], - groups=64, - reduction=16, - dropout_prob=0.2, - dropout_dim=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="senet154", - progress=progress, - ) - return model - - -def se_resnet50( - spatial_dims: int, in_channels: int, num_classes: int, pretrained: bool = False, progress: bool = True -) -> SENet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEResNetBottleneck, - layers=[3, 4, 6, 3], - groups=1, - reduction=16, - dropout_prob=None, - inplanes=64, - input_3x3=False, - downsample_kernel_size=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="se_resnet50", - progress=progress, - ) - return model - - -def se_resnet101( - spatial_dims: int, in_channels: int, num_classes: int, pretrained: bool = False, progress: bool = True -) -> SENet: + This function is used to load pretrained models. """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEResNetBottleneck, - layers=[3, 4, 23, 3], - groups=1, - reduction=16, - dropout_prob=0.2, - dropout_dim=1, - inplanes=64, - input_3x3=False, - downsample_kernel_size=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="se_resnet101", - progress=progress, - ) - return model - - -def se_resnet152( - spatial_dims: int, in_channels: int, num_classes: int, pretrained: bool = False, progress: bool = True -) -> SENet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEResNetBottleneck, - layers=[3, 8, 36, 3], - groups=1, - reduction=16, - dropout_prob=0.2, - dropout_dim=1, - inplanes=64, - input_3x3=False, - downsample_kernel_size=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="se_resnet152", - progress=progress, - ) - return model - - -def se_resnext50_32x4d( - spatial_dims: int, in_channels: int, num_classes: int, pretrained: bool = False, progress: bool = True -) -> SENet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEResNeXtBottleneck, - layers=[3, 4, 6, 3], - groups=32, - reduction=16, - dropout_prob=None, - inplanes=64, - input_3x3=False, - downsample_kernel_size=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="se_resnext50_32x4d", - progress=progress, - ) - return model - - -def se_resnext101_32x4d( - spatial_dims: int, in_channels: int, num_classes: int, pretrained: bool = False, progress: bool = True -) -> SENet: - """ - when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved - from `Cadene Hub 2D version - `_. - """ - model = SENet( - spatial_dims=spatial_dims, - in_channels=in_channels, - block=SEResNeXtBottleneck, - layers=[3, 4, 23, 3], - groups=32, - reduction=16, - dropout_prob=None, - inplanes=64, - input_3x3=False, - downsample_kernel_size=1, - num_classes=num_classes, - pretrained=pretrained, - pretrained_arch="se_resnext101_32x4d", - progress=progress, - ) - return model + model_urls = { + "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", + } + if arch in model_urls.keys(): + model_url = model_urls[arch] + else: + raise ValueError( + "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', \ + and se_resnext101_32x4d are supported to load pretrained weights." + ) + + pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$") + pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$") + pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$") + pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$") + pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$") + pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$") + + state_dict = load_state_dict_from_url(model_url, progress=progress) + for key in list(state_dict.keys()): + new_key = None + if pattern_conv.match(key): + new_key = re.sub(pattern_conv, r"\1conv.\2", key) + elif pattern_bn.match(key): + new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key) + elif pattern_se.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key) + elif pattern_se2.match(key): + state_dict[key] = state_dict[key].squeeze() + new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key) + elif pattern_down_conv.match(key): + new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key) + elif pattern_down_bn.match(key): + new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key) + if new_key: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + model_dict = model.state_dict() + state_dict = { + k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape) + } + model_dict.update(state_dict) + model.load_state_dict(model_dict) + + +class SENet154(SENet): + def __init__( + self, + layers: List[int] = [3, 8, 36, 3], + groups: int = 64, + reduction: int = 16, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SENet154, self).__init__( + block=SEBottleneck, + layers=layers, + groups=groups, + reduction=reduction, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "senet154", progress) + + +class SEResNet50(SENet): + def __init__( + self, + layers: List[int] = [3, 4, 6, 3], + groups: int = 1, + reduction: int = 16, + dropout_prob: Optional[float] = None, + inplanes: int = 64, + downsample_kernel_size: int = 1, + input_3x3: bool = False, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SEResNet50, self).__init__( + block=SEResNetBottleneck, + layers=layers, + groups=groups, + reduction=reduction, + dropout_prob=dropout_prob, + inplanes=inplanes, + downsample_kernel_size=downsample_kernel_size, + input_3x3=input_3x3, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "se_resnet50", progress) + + +class SEResNet101(SENet): + def __init__( + self, + layers: List[int] = [3, 4, 23, 3], + groups: int = 1, + reduction: int = 16, + inplanes: int = 64, + downsample_kernel_size: int = 1, + input_3x3: bool = False, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SEResNet101, self).__init__( + block=SEResNetBottleneck, + layers=layers, + groups=groups, + reduction=reduction, + inplanes=inplanes, + downsample_kernel_size=downsample_kernel_size, + input_3x3=input_3x3, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "se_resnet101", progress) + + +class SEResNet152(SENet): + def __init__( + self, + layers: List[int] = [3, 8, 36, 3], + groups: int = 1, + reduction: int = 16, + inplanes: int = 64, + downsample_kernel_size: int = 1, + input_3x3: bool = False, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SEResNet152, self).__init__( + block=SEResNetBottleneck, + layers=layers, + groups=groups, + reduction=reduction, + inplanes=inplanes, + downsample_kernel_size=downsample_kernel_size, + input_3x3=input_3x3, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "se_resnet152", progress) + + +class SEResNext50(SENet): + def __init__( + self, + layers: List[int] = [3, 4, 6, 3], + groups: int = 32, + reduction: int = 16, + dropout_prob: Optional[float] = None, + inplanes: int = 64, + downsample_kernel_size: int = 1, + input_3x3: bool = False, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SEResNext50, self).__init__( + block=SEResNeXtBottleneck, + layers=layers, + groups=groups, + dropout_prob=dropout_prob, + reduction=reduction, + inplanes=inplanes, + downsample_kernel_size=downsample_kernel_size, + input_3x3=input_3x3, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "se_resnext50_32x4d", progress) + + +class SEResNext101(SENet): + def __init__( + self, + layers: List[int] = [3, 4, 23, 3], + groups: int = 32, + reduction: int = 16, + dropout_prob: Optional[float] = None, + inplanes: int = 64, + downsample_kernel_size: int = 1, + input_3x3: bool = False, + pretrained: bool = False, + progress: bool = True, + **kwargs, + ) -> None: + super(SEResNext101, self).__init__( + block=SEResNeXtBottleneck, + layers=layers, + groups=groups, + dropout_prob=dropout_prob, + reduction=reduction, + inplanes=inplanes, + downsample_kernel_size=downsample_kernel_size, + input_3x3=input_3x3, + **kwargs, + ) + if pretrained: + # it only worked when `spatial_dims` is 2 + _load_state_dict(self, "se_resnext101_32x4d", progress) diff --git a/tests/test_senet.py b/tests/test_senet.py index a2d96e1f18..1c6222d6a0 100644 --- a/tests/test_senet.py +++ b/tests/test_senet.py @@ -17,16 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks.squeeze_and_excitation import SEBottleneck -from monai.networks.nets import ( - SENet, - se_resnet50, - se_resnet101, - se_resnet152, - se_resnext50_32x4d, - se_resnext101_32x4d, - senet154, -) +from monai.networks.nets import SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101 from monai.utils import optional_import from tests.utils import test_pretrained_networks, test_script_save @@ -41,27 +32,14 @@ device = "cuda" if torch.cuda.is_available() else "cpu" NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2} -TEST_CASE_1 = [senet154, NET_ARGS] -TEST_CASE_2 = [se_resnet50, NET_ARGS] -TEST_CASE_3 = [se_resnet101, NET_ARGS] -TEST_CASE_4 = [se_resnet152, NET_ARGS] -TEST_CASE_5 = [se_resnext50_32x4d, NET_ARGS] -TEST_CASE_6 = [se_resnext101_32x4d, NET_ARGS] - -TEST_CASE_PRETRAINED_1 = [se_resnet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}] -TEST_CASE_PRETRAINED_2 = [ - { - "spatial_dims": 2, - "in_channels": 3, - "block": SEBottleneck, - "layers": [3, 8, 36, 3], - "groups": 64, - "reduction": 16, - "num_classes": 2, - "pretrained": True, - "pretrained_arch": "resnet50", - } -] +TEST_CASE_1 = [SENet154, NET_ARGS] +TEST_CASE_2 = [SEResNet50, NET_ARGS] +TEST_CASE_3 = [SEResNet101, NET_ARGS] +TEST_CASE_4 = [SEResNet152, NET_ARGS] +TEST_CASE_5 = [SEResNext50, NET_ARGS] +TEST_CASE_6 = [SEResNext101, NET_ARGS] + +TEST_CASE_PRETRAINED_1 = [SEResNet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}] class TestSENET(unittest.TestCase): @@ -107,11 +85,6 @@ def test_pretrain_consistency(self, model, input_param): # a conv layer with kernel size equals to 1. It may bring a little difference. self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5)) - @parameterized.expand([TEST_CASE_PRETRAINED_2]) - def test_ill_pretrain(self, input_param): - with self.assertRaisesRegex(ValueError, ""): - net = SENet(**input_param) - if __name__ == "__main__": unittest.main() From d452d2d19b2e54b8e147121facfe418962212e46 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 20 Mar 2021 17:57:43 +0800 Subject: [PATCH 5/7] Fix type errors Signed-off-by: Yiheng Wang --- docs/source/networks.rst | 10 ---------- monai/networks/nets/senet.py | 16 ++++++++-------- tests/test_integration_classification_2d.py | 6 +++--- tests/test_vis_cam.py | 8 ++++---- tests/test_vis_gradcam.py | 8 ++++---- tests/test_vis_gradcampp.py | 8 ++++---- 6 files changed, 23 insertions(+), 33 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 036ba2aff7..f5d498a363 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -286,10 +286,6 @@ Nets ~~~~~~~~~~ .. autoclass:: DenseNet :members: -.. autofunction:: densenet121 -.. autofunction:: densenet169 -.. autofunction:: densenet201 -.. autofunction:: densenet264 `SegResNet` ~~~~~~~~~~~ @@ -305,12 +301,6 @@ Nets ~~~~~~~ .. autoclass:: SENet :members: -.. autofunction:: senet154 -.. autofunction:: se_resnet50 -.. autofunction:: se_resnet101 -.. autofunction:: se_resnet152 -.. autofunction:: se_resnext50_32x4d -.. autofunction:: se_resnext101_32x4d `HighResNet` ~~~~~~~~~~~~ diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 704674a724..50627c1513 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -11,7 +11,7 @@ import re from collections import OrderedDict -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union import torch import torch.nn as nn @@ -73,7 +73,7 @@ def __init__( spatial_dims: int, in_channels: int, block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]], - layers: List[int], + layers: Sequence[int], groups: int, reduction: int, dropout_prob: Optional[float] = 0.2, @@ -306,7 +306,7 @@ def _load_state_dict(model, arch, progress): class SENet154(SENet): def __init__( self, - layers: List[int] = [3, 8, 36, 3], + layers: Sequence[int] = (3, 8, 36, 3), groups: int = 64, reduction: int = 16, pretrained: bool = False, @@ -328,7 +328,7 @@ def __init__( class SEResNet50(SENet): def __init__( self, - layers: List[int] = [3, 4, 6, 3], + layers: Sequence[int] = (3, 4, 6, 3), groups: int = 1, reduction: int = 16, dropout_prob: Optional[float] = None, @@ -358,7 +358,7 @@ def __init__( class SEResNet101(SENet): def __init__( self, - layers: List[int] = [3, 4, 23, 3], + layers: Sequence[int] = (3, 4, 23, 3), groups: int = 1, reduction: int = 16, inplanes: int = 64, @@ -386,7 +386,7 @@ def __init__( class SEResNet152(SENet): def __init__( self, - layers: List[int] = [3, 8, 36, 3], + layers: Sequence[int] = (3, 8, 36, 3), groups: int = 1, reduction: int = 16, inplanes: int = 64, @@ -414,7 +414,7 @@ def __init__( class SEResNext50(SENet): def __init__( self, - layers: List[int] = [3, 4, 6, 3], + layers: Sequence[int] = (3, 4, 6, 3), groups: int = 32, reduction: int = 16, dropout_prob: Optional[float] = None, @@ -444,7 +444,7 @@ def __init__( class SEResNext101(SENet): def __init__( self, - layers: List[int] = [3, 4, 23, 3], + layers: Sequence[int] = (3, 4, 23, 3), groups: int = 32, reduction: int = 16, dropout_prob: Optional[float] = None, diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index 4be59cba41..6f8c949d78 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -22,7 +22,7 @@ from monai.apps import download_and_extract from monai.metrics import compute_roc_auc from monai.networks import eval_mode -from monai.networks.nets import densenet121 +from monai.networks.nets import DenseNet121 from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor from monai.utils import set_determinism from tests.testing_data.integration_answers import test_integration_value @@ -71,7 +71,7 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", val_ds = MedNISTDataset(val_x, val_y, val_transforms) val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers) - model = densenet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(train_y))).to(device) + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(train_y))).to(device) loss_function = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), 1e-5) epoch_num = 4 @@ -133,7 +133,7 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10 val_ds = MedNISTDataset(test_x, test_y, val_transforms) val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers) - model = densenet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(test_y))).to(device) + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(test_y))).to(device) model_filename = os.path.join(root_dir, "best_metric_model.pth") model.load_state_dict(torch.load(model_filename)) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index d400c27f02..47c116cd5d 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import CAM # 2D @@ -68,15 +68,15 @@ class TestClassActivationMap(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): if input_data["model"] == "densenet2d": - model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) if input_data["model"] == "densenet3d": model = DenseNet( spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) ) if input_data["model"] == "senet2d": - model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) if input_data["model"] == "senet3d": - model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index df47c4920e..f8e49f486f 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -15,7 +15,7 @@ import torch from parameterized import parameterized -from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import GradCAM # 2D @@ -65,15 +65,15 @@ class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): if input_data["model"] == "densenet2d": - model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) if input_data["model"] == "densenet3d": model = DenseNet( spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) ) if input_data["model"] == "senet2d": - model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) if input_data["model"] == "senet3d": - model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index fce68ccde0..92a4b2ac7b 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -14,7 +14,7 @@ import torch from parameterized import parameterized -from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import GradCAMpp # 2D @@ -64,15 +64,15 @@ class TestGradientClassActivationMapPP(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, input_data, expected_shape): if input_data["model"] == "densenet2d": - model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) if input_data["model"] == "densenet3d": model = DenseNet( spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) ) if input_data["model"] == "senet2d": - model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) if input_data["model"] == "senet3d": - model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() From b2dcd48fedf39a5b3cf67085f11429ec9ae0cdb7 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Sat, 20 Mar 2021 18:20:03 +0800 Subject: [PATCH 6/7] Fix name error Signed-off-by: Yiheng Wang --- tests/test_occlusion_sensitivity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index 47a13d01e1..d58359a598 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -14,13 +14,13 @@ import torch from parameterized import parameterized -from monai.networks.nets import DenseNet, densenet121 +from monai.networks.nets import DenseNet, DenseNet121 from monai.visualize import OcclusionSensitivity device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") out_channels_2d = 4 out_channels_3d = 3 -model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) +model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) model_3d = DenseNet( spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,) ).to(device) From 9e751bcf964a9f66d266ae12b9be636f01ed2478 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 22 Mar 2021 14:11:41 +0800 Subject: [PATCH 7/7] Update docstring Signed-off-by: Yiheng Wang --- monai/networks/nets/densenet.py | 2 +- monai/networks/nets/senet.py | 2 +- monai/visualize/class_activation_maps.py | 8 ++++---- monai/visualize/occlusion_sensitivity.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py index d8ad8f89d3..280bc6b0cb 100644 --- a/monai/networks/nets/densenet.py +++ b/monai/networks/nets/densenet.py @@ -207,7 +207,7 @@ def _load_state_dict(model, arch, progress): "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", } - if arch in model_urls.keys(): + if arch in model_urls: model_url = model_urls[arch] else: raise ValueError( diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 50627c1513..f5738edeeb 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -259,7 +259,7 @@ def _load_state_dict(model, arch, progress): "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", } - if arch in model_urls.keys(): + if arch in model_urls: model_url = model_urls[arch] else: raise ValueError( diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 6e93225af3..b310ec0834 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -209,10 +209,10 @@ class CAM(CAMBase): .. code-block:: python # densenet 2d - from monai.networks.nets import densenet121 + from monai.networks.nets import DenseNet121 from monai.visualize import CAM - model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out") result = cam(x=torch.rand((1, 1, 48, 64))) @@ -307,10 +307,10 @@ class GradCAM(CAMBase): .. code-block:: python # densenet 2d - from monai.networks.nets import densenet121 + from monai.networks.nets import DenseNet121 from monai.visualize import GradCAM - model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu") result = cam(x=torch.rand((1, 1, 48, 64))) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 5863614965..ee9a967da1 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -122,10 +122,10 @@ class OcclusionSensitivity: .. code-block:: python # densenet 2d - from monai.networks.nets import densenet121 + from monai.networks.nets import DenseNet121 from monai.visualize import OcclusionSensitivity - model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) occ_sens = OcclusionSensitivity(nn_module=model_2d) occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), class_idx=None, b_box=[-1, -1, 2, 40, 1, 62])