diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 5688f4b143..036ba2aff7 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -350,6 +350,11 @@ Nets .. autoclass:: RegUNet :members: +`GlobalNet` +~~~~~~~~~~~~ +.. autoclass:: GlobalNet + :members: + `LocalNet` ~~~~~~~~~~~ .. autoclass:: LocalNet diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 7a39872525..f3def30736 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -19,7 +19,7 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .regressor import Regressor -from .regunet import LocalNet, RegUNet +from .regunet import GlobalNet, LocalNet, RegUNet from .segresnet import SegResNet, SegResNetVAE from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 from .unet import UNet, Unet, unet diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py index 3263a6b5bc..25455c2df7 100644 --- a/monai/networks/nets/regunet.py +++ b/monai/networks/nets/regunet.py @@ -250,6 +250,110 @@ def forward(self, x): return out +class AffineHead(nn.Module): + def __init__( + self, + spatial_dims: int, + image_size: List[int], + decode_size: List[int], + in_channels: int, + ): + super(AffineHead, self).__init__() + self.spatial_dims = spatial_dims + if spatial_dims == 2: + in_features = in_channels * decode_size[0] * decode_size[1] + out_features = 6 + elif spatial_dims == 3: + in_features = in_channels * decode_size[0] * decode_size[1] * decode_size[2] + out_features = 12 + else: + raise ValueError(f"only support 2D/3D operation, got spatial_dims={spatial_dims}") + + self.fc = nn.Linear(in_features=in_features, out_features=out_features) + self.grid = self.get_reference_grid(image_size) # (spatial_dims, ...) + + @staticmethod + def get_reference_grid(image_size: Union[Tuple[int], List[int]]) -> torch.Tensor: + mesh_points = [torch.arange(0, dim) for dim in image_size] + grid = torch.stack(torch.meshgrid(*mesh_points), dim=0) # (spatial_dims, ...) + return grid.to(dtype=torch.float) + + def affine_transform(self, theta: torch.Tensor): + # (spatial_dims, ...) -> (spatial_dims + 1, ...) + grid_padded = torch.cat([self.grid, torch.ones_like(self.grid[:1])]) + + # grid_warped[b,p,...] = sum_over_q(grid_padded[q,...] * theta[b,p,q] + if self.spatial_dims == 2: + grid_warped = torch.einsum("qij,bpq->bpij", grid_padded, theta.reshape(-1, 2, 3)) + elif self.spatial_dims == 3: + grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4)) + else: + raise ValueError(f"do not support spatial_dims={self.spatial_dims}") + return grid_warped + + def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor: + f = x[0] + self.grid = self.grid.to(device=f.device) + theta = self.fc(f.reshape(f.shape[0], -1)) + out: torch.Tensor = self.affine_transform(theta) - self.grid + return out + + +class GlobalNet(RegUNet): + """ + Build GlobalNet for image registration. + + Reference: + Hu, Yipeng, et al. + "Label-driven weakly-supervised learning + for multimodal deformable image registration," + https://arxiv.org/abs/1711.01666 + """ + + def __init__( + self, + image_size: List[int], + spatial_dims: int, + in_channels: int, + num_channel_initial: int, + depth: int, + out_kernel_initializer: Optional[str] = "kaiming_uniform", + out_activation: Optional[str] = None, + pooling: bool = True, + concat_skip: bool = False, + encode_kernel_sizes: Union[int, List[int]] = 3, + ): + for size in image_size: + if size % (2 ** depth) != 0: + raise ValueError( + f"given depth {depth}, " + f"all input spatial dimension must be divisible by {2 ** depth}, " + f"got input of size {image_size}" + ) + self.image_size = image_size + self.decode_size = [size // (2 ** depth) for size in image_size] + super().__init__( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channel_initial=num_channel_initial, + depth=depth, + out_kernel_initializer=out_kernel_initializer, + out_activation=out_activation, + out_channels=spatial_dims, + pooling=pooling, + concat_skip=concat_skip, + encode_kernel_sizes=encode_kernel_sizes, + ) + + def build_output_block(self): + return AffineHead( + spatial_dims=self.spatial_dims, + image_size=self.image_size, + decode_size=self.decode_size, + in_channels=self.num_channels[-1], + ) + + class AdditiveUpSampleBlock(nn.Module): def __init__( self, diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py new file mode 100644 index 0000000000..19e9db9137 --- /dev/null +++ b/tests/test_globalnet.py @@ -0,0 +1,79 @@ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import GlobalNet +from monai.networks.nets.regunet import AffineHead +from tests.utils import test_script_save + +TEST_CASES_AFFINE_TRANSFORM = [ + [ + {"spatial_dims": 3, "image_size": (2, 2, 2), "decode_size": (2, 2, 2), "in_channels": 1}, + torch.ones(2, 12), + torch.tensor([[[1, 2], [2, 3]], [[2, 3], [3, 4]]]).unsqueeze(0).unsqueeze(0).expand(2, 3, 2, 2, 2), + ], + [ + {"spatial_dims": 3, "image_size": (2, 2, 2), "decode_size": (2, 2, 2), "in_channels": 1}, + torch.arange(1, 13).reshape(1, 12).to(torch.float), + torch.tensor( + [ + [[[4.0, 7.0], [6.0, 9.0]], [[5.0, 8.0], [7.0, 10.0]]], + [[[8.0, 15.0], [14.0, 21.0]], [[13.0, 20.0], [19.0, 26.0]]], + [[[12.0, 23.0], [22.0, 33.0]], [[21.0, 32.0], [31.0, 42.0]]], + ] + ).unsqueeze(0), + ], +] + + +TEST_CASES_GLOBAL_NET = [ + [ + { + "image_size": (16, 16), + "spatial_dims": 2, + "in_channels": 1, + "num_channel_initial": 16, + "depth": 1, + "out_kernel_initializer": "kaiming_uniform", + "out_activation": None, + "pooling": True, + "concat_skip": True, + "encode_kernel_sizes": 3, + }, + (1, 1, 16, 16), + (1, 2, 16, 16), + ] +] + + +class TestAffineHead(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE_TRANSFORM) + def test_shape(self, input_param, theta, expected_val): + layer = AffineHead(**input_param) + result = layer.affine_transform(theta) + np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) + + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +class TestGlobalNet(unittest.TestCase): + @parameterized.expand(TEST_CASES_GLOBAL_NET) + def test_shape(self, input_param, input_shape, expected_shape): + net = GlobalNet(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_script(self): + input_param, input_shape, _ = TEST_CASES_GLOBAL_NET[0] + net = GlobalNet(**input_param) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + +if __name__ == "__main__": + unittest.main()