Skip to content
Merged
10 changes: 0 additions & 10 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,6 @@ Nets
~~~~~~~~~~
.. autoclass:: DenseNet
:members:
.. autofunction:: densenet121
.. autofunction:: densenet169
.. autofunction:: densenet201
.. autofunction:: densenet264

`SegResNet`
~~~~~~~~~~~
Expand All @@ -305,12 +301,6 @@ Nets
~~~~~~~
.. autoclass:: SENet
:members:
.. autofunction:: senet154
.. autofunction:: se_resnet50
.. autofunction:: se_resnet101
.. autofunction:: se_resnet152
.. autofunction:: se_resnext50_32x4d
.. autofunction:: se_resnext101_32x4d

`HighResNet`
~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from .autoencoder import AutoEncoder
from .basic_unet import BasicUNet, BasicUnet, Basicunet
from .classifier import Classifier, Critic, Discriminator
from .densenet import DenseNet, densenet121, densenet169, densenet201, densenet264
from .densenet import DenseNet, DenseNet121, DenseNet169, DenseNet201, DenseNet264
from .dynunet import DynUNet, DynUnet, Dynunet
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154
from .senet import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101
from .unet import UNet, Unet, unet
from .varautoencoder import VarAutoEncoder
from .vnet import VNet
198 changes: 105 additions & 93 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ class DenseNet(nn.Module):
bn_size: multiplicative factor for number of bottle neck layers.
(i.e. bn_size * k features in the bottleneck layer)
dropout_prob: dropout rate after each dense layer.
pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`.
In order to load weights correctly, Please ensure that the `block_config`
is consistent with the corresponding arch.
pretrained_arch: the arch name for pretrained weights.
progress: If True, displays a progress bar of the download to stderr.
"""

def __init__(
Expand All @@ -132,9 +127,6 @@ def __init__(
block_config: Sequence[int] = (6, 12, 24, 16),
bn_size: int = 4,
dropout_prob: float = 0.0,
pretrained: bool = False,
pretrained_arch: str = "densenet121",
progress: bool = True,
) -> None:

super(DenseNet, self).__init__()
Expand Down Expand Up @@ -198,107 +190,127 @@ def __init__(
elif isinstance(m, nn.Linear):
nn.init.constant_(torch.as_tensor(m.bias), 0)

if pretrained:
self._load_state_dict(pretrained_arch, progress)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.class_layers(x)
return x

def _load_state_dict(self, arch, progress):
"""
This function is used to load pretrained models.
Adapted from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model_urls = {
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
}
if arch in model_urls.keys():
model_url = model_urls[arch]
else:
raise ValueError(
"only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
)
pattern = re.compile(
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"

def _load_state_dict(model, arch, progress):
"""
This function is used to load pretrained models.
Adapted from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model_urls = {
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
}
if arch in model_urls:
model_url = model_urls[arch]
else:
raise ValueError(
"only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
)
pattern = re.compile(
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
state_dict[new_key] = state_dict[key]
del state_dict[key]
state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
state_dict[new_key] = state_dict[key]
del state_dict[key]

model_dict = self.state_dict()
state_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
}
model_dict.update(state_dict)
self.load_state_dict(model_dict)
model_dict = model.state_dict()
state_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
}
model_dict.update(state_dict)
model.load_state_dict(model_dict)


def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
"""
when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 24, 16),
pretrained=pretrained,
pretrained_arch="densenet121",
progress=progress,
class DenseNet121(DenseNet):
def __init__(
self,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 24, 16),
pretrained: bool = False,
progress: bool = True,
**kwargs,
)
return model
) -> None:
super(DenseNet121, self).__init__(
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
# it only worked when `spatial_dims` is 2
_load_state_dict(self, "densenet121", progress)


def densenet169(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
"""
when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 32, 32),
pretrained=pretrained,
pretrained_arch="densenet169",
progress=progress,
class DenseNet169(DenseNet):
def __init__(
self,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 32, 32),
pretrained: bool = False,
progress: bool = True,
**kwargs,
)
return model
) -> None:
super(DenseNet169, self).__init__(
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
# it only worked when `spatial_dims` is 2
_load_state_dict(self, "densenet169", progress)


def densenet201(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
"""
when `spatial_dims = 2`, specify `pretrained = True` can load Imagenet pretrained weights achieved
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 48, 32),
pretrained=pretrained,
pretrained_arch="densenet201",
progress=progress,
class DenseNet201(DenseNet):
def __init__(
self,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 48, 32),
pretrained: bool = False,
progress: bool = True,
**kwargs,
)
return model
) -> None:
super(DenseNet201, self).__init__(
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
# it only worked when `spatial_dims` is 2
_load_state_dict(self, "densenet201", progress)


def densenet264(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), **kwargs)
if pretrained:
print("Currently PyTorch Hub does not provide densenet264 pretrained models.")
return model
class DenseNet264(DenseNet):
def __init__(
self,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 48, 32),
pretrained: bool = False,
progress: bool = True,
**kwargs,
) -> None:
super(DenseNet264, self).__init__(
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
print("Currently PyTorch Hub does not provide densenet264 pretrained models.")
Loading