Skip to content
Merged
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ Nets
.. autoclass:: Critic
:members:

`TorchVisionFullyConvModel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TorchVisionFullyConvModel
:members:

Utilities
---------
.. automodule:: monai.networks.utils
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .regunet import GlobalNet, LocalNet, RegUNet
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101
from .torchvision_fc import TorchVisionFullyConvModel
from .unet import UNet, Unet, unet
from .varautoencoder import VarAutoEncoder
from .vnet import VNet
67 changes: 67 additions & 0 deletions monai/networks/nets/torchvision_fc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Tuple, Union

import torch

from monai.utils import optional_import

models, _ = optional_import("torchvision.models")


class TorchVisionFullyConvModel(torch.nn.Module):
"""
Customize TorchVision models to replace fully connected layer by convolutional layer.

Args:
model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
- resnet18 (default)
- resnet34
- resnet50
- resnet101
- resnet152
- resnext50_32x4d
- resnext101_32x8d
- wide_resnet50_2
- wide_resnet101_2
n_classes: number of classes for the last classification layer. Default to 1.
pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

def __init__(
self,
model_name: str = "resnet18",
n_classes: int = 1,
pool_size: Union[int, Tuple[int, int]] = (7, 7),
pool_stride: Union[int, Tuple[int, int]] = 1,
pretrained: bool = False,
):
super().__init__()
model = getattr(models, model_name)(pretrained=pretrained)
layers = list(model.children())

# check if the model is compatible
if not str(layers[-1]).startswith("Linear"):
raise ValueError(f"Model ['{model_name}'] does not have a Linear layer at the end.")
if not str(layers[-2]).startswith("AdaptiveAvgPool2d"):
raise ValueError(f"Model ['{model_name}'] does not have a AdaptiveAvgPool2d layer next to the end.")

# remove the last Linear layer (fully connected) and the adaptive avg pooling
self.features = torch.nn.Sequential(*layers[:-2])

# add 7x7 avg pooling (in place of adaptive avg pooling)
self.pool = torch.nn.AvgPool2d(kernel_size=pool_size, stride=pool_stride)

# add 1x1 conv (it behaves like a FC layer)
self.fc = torch.nn.Conv2d(model.fc.in_features, n_classes, kernel_size=(1, 1))
Comment thread
wyli marked this conversation as resolved.

def forward(self, x):
x = self.features(x)

# apply 2D avg pooling
x = self.pool(x)

# apply last 1x1 conv layer that act like a linear layer
x = self.fc(x)

return x
106 changes: 106 additions & 0 deletions tests/test_torchvision_fc_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import TorchVisionFullyConvModel
from monai.utils import optional_import

_, has_tv = optional_import("torchvision")

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_0 = [
{"model_name": "resnet18", "n_classes": 1, "pretrained": False},
(2, 3, 224, 224),
(2, 1, 1, 1),
]

TEST_CASE_1 = [
{"model_name": "resnet18", "n_classes": 1, "pretrained": False},
(2, 3, 256, 256),
(2, 1, 2, 2),
]

TEST_CASE_2 = [
{"model_name": "resnet101", "n_classes": 5, "pretrained": False},
(2, 3, 256, 256),
(2, 5, 2, 2),
]

TEST_CASE_3 = [
{"model_name": "resnet101", "n_classes": 5, "pool_size": 6, "pretrained": False},
(2, 3, 224, 224),
(2, 5, 2, 2),
]

TEST_CASE_PRETRAINED_0 = [
{"model_name": "resnet18", "n_classes": 1, "pretrained": True},
(2, 3, 224, 224),
(2, 1, 1, 1),
-0.010419349186122417,
]

TEST_CASE_PRETRAINED_1 = [
{"model_name": "resnet18", "n_classes": 1, "pretrained": True},
(2, 3, 256, 256),
(2, 1, 2, 2),
-0.010419349186122417,
]

TEST_CASE_PRETRAINED_2 = [
{"model_name": "resnet18", "n_classes": 5, "pretrained": True},
(2, 3, 256, 256),
(2, 5, 2, 2),
-0.010419349186122417,
]


class TestTorchVisionFullyConvModel(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_0,
TEST_CASE_1,
TEST_CASE_2,
TEST_CASE_3,
]
)
@skipUnless(has_tv, "Requires TorchVision.")
def test_without_pretrained(self, input_param, input_shape, expected_shape):
net = TorchVisionFullyConvModel(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@parameterized.expand(
[
TEST_CASE_PRETRAINED_0,
TEST_CASE_PRETRAINED_1,
TEST_CASE_PRETRAINED_2,
]
)
@skipUnless(has_tv, "Requires TorchVision.")
def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value):
net = TorchVisionFullyConvModel(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
value = next(net.parameters())[0, 0, 0, 0].item()
self.assertEqual(value, expected_value)
self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()