Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .localnet import LocalNet
from .regressor import Regressor
from .regunet import RegUNet
from .regunet import LocalNet, RegUNet
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154
from .unet import UNet, Unet, unet
Expand Down
129 changes: 0 additions & 129 deletions monai/networks/nets/localnet.py

This file was deleted.

91 changes: 91 additions & 0 deletions monai/networks/nets/regunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks.regunet_block import (
RegistrationDownSampleBlock,
Expand Down Expand Up @@ -247,3 +248,93 @@ def forward(self, x):

out = self.output_block(outs, image_size=image_size)
return out


class AdditiveUpSampleBlock(nn.Module):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
):
super(AdditiveUpSampleBlock, self).__init__()
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)

def forward(self, x: torch.Tensor) -> torch.Tensor:
output_size = (size * 2 for size in x.shape[2:])
deconved = self.deconv(x)
resized = F.interpolate(x, output_size)
resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
out: torch.Tensor = deconved + resized
return out


class LocalNet(RegUNet):
"""
Reimplementation of LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
<https://arxiv.org/abs/1711.01666>`_.

Adapted from:
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
num_channel_initial: int,
extract_levels: Tuple[int],
out_kernel_initializer: Optional[str] = "kaiming_uniform",
out_activation: Optional[str] = None,
out_channels: int = 3,
pooling: bool = True,
concat_skip: bool = False,
):
"""
Args:
spatial_dims: number of spatial dims
in_channels: number of input channels
num_channel_initial: number of initial channels
out_kernel_initializer: kernel initializer for the last layer
out_activation: activation at the last layer
out_channels: number of channels for the output
extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
"""
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
num_channel_initial=num_channel_initial,
depth=max(extract_levels),
out_kernel_initializer=out_kernel_initializer,
out_activation=out_activation,
out_channels=out_channels,
pooling=pooling,
concat_skip=concat_skip,
encode_kernel_sizes=[7] + [3] * max(extract_levels),
)

def build_bottom_block(self, in_channels: int, out_channels: int):
kernel_size = self.encode_kernel_sizes[self.depth]
return get_conv_block(
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
)

def build_up_sampling_block(
self,
in_channels: int,
out_channels: int,
) -> nn.Module:
if self._use_additive_upsampling:
return AdditiveUpSampleBlock(
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels
)

return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)
58 changes: 24 additions & 34 deletions tests/test_localnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.localnet import LocalNet
from monai.networks.nets.regunet import LocalNet
from tests.utils import test_script_save

device = "cuda" if torch.cuda.is_available() else "cpu"
Expand All @@ -15,39 +15,36 @@
{
"spatial_dims": 2,
"in_channels": 2,
"out_channels": 2,
"num_channel_initial": 16,
"extract_levels": [0, 1, 2],
"out_activation": act,
"out_kernel_initializer": "kaiming_uniform",
"out_activation": None,
"out_channels": 2,
"extract_levels": (0, 1),
"pooling": False,
"concat_skip": True,
},
(1, 2, 16, 16),
(1, 2, 16, 16),
]
for act in ["sigmoid", None]
]

TEST_CASE_LOCALNET_3D = []
for in_channels in [2, 3]:
for out_channels in [1, 3]:
for num_channel_initial in [4, 16, 32]:
for extract_levels in [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]:
for out_activation in ["sigmoid", None]:
for out_initializer in ["kaiming_uniform", "zeros"]:
TEST_CASE_LOCALNET_3D.append(
[
{
"spatial_dims": 3,
"in_channels": in_channels,
"out_channels": out_channels,
"num_channel_initial": num_channel_initial,
"extract_levels": extract_levels,
"out_activation": out_activation,
"out_initializer": out_initializer,
},
(1, in_channels, 16, 16, 16),
(1, out_channels, 16, 16, 16),
]
)
TEST_CASE_LOCALNET_3D = [
[
{
"spatial_dims": 3,
"in_channels": 2,
"num_channel_initial": 16,
"out_kernel_initializer": "zeros",
"out_activation": "sigmoid",
"out_channels": 2,
"extract_levels": (0, 1, 2, 3),
"pooling": True,
"concat_skip": False,
},
(1, 2, 16, 16, 16),
(1, 2, 16, 16, 16),
]
]


class TestLocalNet(unittest.TestCase):
Expand All @@ -58,13 +55,6 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_ill_shape(self):
with self.assertRaisesRegex(ValueError, ""):
input_param, _, _ = TEST_CASE_LOCALNET_2D[0]
input_shape = (1, input_param["in_channels"], 17, 17)
net = LocalNet(**input_param).to(device)
net.forward(torch.randn(input_shape).to(device))

def test_script(self):
input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0]
net = LocalNet(**input_param)
Expand Down