Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ Nets
.. autoclass:: RegUNet
:members:

`GlobalNet`
~~~~~~~~~~~~
.. autoclass:: GlobalNet
:members:

`LocalNet`
~~~~~~~~~~~
.. autoclass:: LocalNet
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .regressor import Regressor
from .regunet import LocalNet, RegUNet
from .regunet import GlobalNet, LocalNet, RegUNet
from .segresnet import SegResNet, SegResNetVAE
from .senet import SENet, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154
from .unet import UNet, Unet, unet
Expand Down
104 changes: 104 additions & 0 deletions monai/networks/nets/regunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,110 @@ def forward(self, x):
return out


class AffineHead(nn.Module):
    """
    Output head that predicts one affine transform per sample and expresses it
    as a dense displacement field over a fixed reference grid.

    A single linear layer maps a flattened feature map to the entries of a
    ``(spatial_dims, spatial_dims + 1)`` affine matrix (6 values in 2D, 12 in 3D).
    The matrix is applied to the homogeneous voxel coordinates of the reference
    grid, and the grid itself is subtracted so that the result is a displacement.
    """

    def __init__(
        self,
        spatial_dims: int,
        image_size: List[int],
        decode_size: List[int],
        in_channels: int,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions, 2 or 3.
            image_size: spatial size of the reference grid, and therefore of
                the displacement field returned by ``forward``.
            decode_size: spatial size of the feature map that ``forward``
                flattens; only the first ``spatial_dims`` entries are used.
            in_channels: number of channels of that feature map.

        Raises:
            ValueError: when ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()
        self.spatial_dims = spatial_dims
        if spatial_dims == 2:
            in_features = in_channels * decode_size[0] * decode_size[1]
            out_features = 6
        elif spatial_dims == 3:
            in_features = in_channels * decode_size[0] * decode_size[1] * decode_size[2]
            out_features = 12
        else:
            raise ValueError(f"only support 2D/3D operation, got spatial_dims={spatial_dims}")

        self.fc = nn.Linear(in_features=in_features, out_features=out_features)
        self.grid = self.get_reference_grid(image_size)  # (spatial_dims, ...)

    @staticmethod
    def get_reference_grid(image_size: Union[Tuple[int], List[int]]) -> torch.Tensor:
        """
        Build the identity coordinate grid for ``image_size``.

        Args:
            image_size: size of every spatial dimension.

        Returns:
            A float tensor of shape ``(len(image_size), *image_size)`` where
            channel ``d`` holds each voxel's integer index along axis ``d``
            (matrix / "ij" ordering).
        """
        # Built by broadcasting rather than ``torch.meshgrid``: calling meshgrid
        # without ``indexing`` warns on recent PyTorch and its default ordering is
        # documented to change, while passing ``indexing`` fails on older releases.
        ndim = len(image_size)
        axis_coords = []
        for axis, size in enumerate(image_size):
            view_shape = [1] * ndim
            view_shape[axis] = size
            axis_coords.append(torch.arange(0, size).reshape(view_shape).expand(*image_size))
        grid = torch.stack(axis_coords, dim=0)  # (spatial_dims, ...)
        return grid.to(dtype=torch.float)

    def affine_transform(self, theta: torch.Tensor):
        """
        Apply a batch of affine matrices to the reference grid.

        Args:
            theta: affine parameters; reshaped internally to
                ``(batch, spatial_dims, spatial_dims + 1)``.

        Returns:
            Warped coordinates of shape ``(batch, spatial_dims, *image_size)``.

        Raises:
            ValueError: when ``self.spatial_dims`` is neither 2 nor 3.
        """
        # (spatial_dims, ...) -> (spatial_dims + 1, ...): append a channel of ones
        # so the last column of theta acts as the translation
        grid_padded = torch.cat([self.grid, torch.ones_like(self.grid[:1])])

        # grid_warped[b,p,...] = sum_over_q(grid_padded[q,...] * theta[b,p,q])
        if self.spatial_dims == 2:
            grid_warped = torch.einsum("qij,bpq->bpij", grid_padded, theta.reshape(-1, 2, 3))
        elif self.spatial_dims == 3:
            grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4))
        else:
            raise ValueError(f"do not support spatial_dims={self.spatial_dims}")
        return grid_warped

    def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
        """
        Args:
            x: list of feature maps; only ``x[0]`` is used. It is flattened per
                sample and must provide the ``in_features`` expected by ``self.fc``.
            image_size: not used here; the grid size is fixed at construction.
                Presumably kept so the call signature matches other output
                blocks - confirm against the caller.

        Returns:
            Displacement field of shape ``(batch, spatial_dims, *image_size)``,
            i.e. the affinely warped grid minus the reference grid.
        """
        f = x[0]
        # the grid is a plain attribute (not a buffer), so Module.to() does not
        # move it; follow the device of the incoming features instead
        self.grid = self.grid.to(device=f.device)
        theta = self.fc(f.reshape(f.shape[0], -1))
        out: torch.Tensor = self.affine_transform(theta) - self.grid
        return out


class GlobalNet(RegUNet):
    """
    Build GlobalNet for image registration.

    A ``RegUNet`` whose output block is an ``AffineHead``, so the network's
    output comes from a single affine transform per sample expressed over a
    grid of ``image_size``.

    Reference:
        Hu, Yipeng, et al.
        "Label-driven weakly-supervised learning
        for multimodal deformable image registration,"
        https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: List[int],
        spatial_dims: int,
        in_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: Optional[str] = "kaiming_uniform",
        out_activation: Optional[str] = None,
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: Union[int, List[int]] = 3,
    ):
        """
        Args:
            image_size: spatial size of the input; every entry must be
                divisible by ``2 ** depth``.
            spatial_dims: number of spatial dimensions; also passed to
                ``RegUNet`` as ``out_channels``.
            in_channels: forwarded unchanged to ``RegUNet``.
            num_channel_initial: forwarded unchanged to ``RegUNet``.
            depth: forwarded to ``RegUNet``; also fixes the divisibility
                requirement on ``image_size`` and the value of ``decode_size``.
            out_kernel_initializer: forwarded unchanged to ``RegUNet``.
            out_activation: forwarded unchanged to ``RegUNet``.
            pooling: forwarded unchanged to ``RegUNet``.
            concat_skip: forwarded unchanged to ``RegUNet``.
            encode_kernel_sizes: forwarded unchanged to ``RegUNet``.

        Raises:
            ValueError: when any entry of ``image_size`` is not divisible by
                ``2 ** depth``.
        """
        not_divisible = any(size % (2 ** depth) != 0 for size in image_size)
        if not_divisible:
            raise ValueError(
                f"given depth {depth}, "
                f"all input spatial dimension must be divisible by {2 ** depth}, "
                f"got input of size {image_size}"
            )

        # build_output_block() reads both attributes; they are assigned before
        # the parent constructor runs, presumably because that constructor is
        # what invokes build_output_block() - confirm against RegUNet
        self.image_size = image_size
        self.decode_size = [extent // (2 ** depth) for extent in image_size]

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_channel_initial=num_channel_initial,
            depth=depth,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=spatial_dims,
            pooling=pooling,
            concat_skip=concat_skip,
            encode_kernel_sizes=encode_kernel_sizes,
        )

    def build_output_block(self):
        """Return the ``AffineHead`` used in place of the parent's output block."""
        head = AffineHead(
            spatial_dims=self.spatial_dims,
            image_size=self.image_size,
            decode_size=self.decode_size,
            in_channels=self.num_channels[-1],
        )
        return head


class AdditiveUpSampleBlock(nn.Module):
def __init__(
self,
Expand Down
79 changes: 79 additions & 0 deletions tests/test_globalnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import GlobalNet
from monai.networks.nets.regunet import AffineHead
from tests.utils import test_script_save

# Each case is [AffineHead constructor kwargs, theta, expected warped grid].
# The expected value at voxel (i, j, k) of channel p is
#   theta[p, 0] * i + theta[p, 1] * j + theta[p, 2] * k + theta[p, 3].
_AFFINE_HEAD_3D_KWARGS = {"spatial_dims": 3, "image_size": (2, 2, 2), "decode_size": (2, 2, 2), "in_channels": 1}

# all-ones theta: every channel of every sample equals i + j + k + 1
_ONES_THETA = torch.ones(2, 12)
_ONES_EXPECTED = torch.tensor([[[1, 2], [2, 3]], [[2, 3], [3, 4]]])[None, None].expand(2, 3, 2, 2, 2)

# theta = 1..12, i.e. rows (1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12)
_RAMP_THETA = torch.arange(1, 13).reshape(1, 12).to(torch.float)
_RAMP_EXPECTED = torch.tensor(
    [
        [[[4.0, 7.0], [6.0, 9.0]], [[5.0, 8.0], [7.0, 10.0]]],
        [[[8.0, 15.0], [14.0, 21.0]], [[13.0, 20.0], [19.0, 26.0]]],
        [[[12.0, 23.0], [22.0, 33.0]], [[21.0, 32.0], [31.0, 42.0]]],
    ]
)[None]

TEST_CASES_AFFINE_TRANSFORM = [
    [dict(_AFFINE_HEAD_3D_KWARGS), _ONES_THETA, _ONES_EXPECTED],
    [dict(_AFFINE_HEAD_3D_KWARGS), _RAMP_THETA, _RAMP_EXPECTED],
]


# Each case is [GlobalNet constructor kwargs, input shape, expected output shape].
_GLOBAL_NET_2D_KWARGS = dict(
    image_size=(16, 16),
    spatial_dims=2,
    in_channels=1,
    num_channel_initial=16,
    depth=1,
    out_kernel_initializer="kaiming_uniform",
    out_activation=None,
    pooling=True,
    concat_skip=True,
    encode_kernel_sizes=3,
)

TEST_CASES_GLOBAL_NET = [
    [_GLOBAL_NET_2D_KWARGS, (1, 1, 16, 16), (1, 2, 16, 16)],
]


class TestAffineHead(unittest.TestCase):
    """Checks ``AffineHead.affine_transform`` against hand-computed grids."""

    @parameterized.expand(TEST_CASES_AFFINE_TRANSFORM)
    def test_shape(self, input_param, theta, expected_val):
        # compares element values (not only shapes) within a 1e-4 tolerance
        head = AffineHead(**input_param)
        warped = head.affine_transform(theta)
        actual = warped.cpu().numpy()
        desired = expected_val.cpu().numpy()
        np.testing.assert_allclose(actual, desired, rtol=1e-4, atol=1e-4)


# run the GlobalNet shape test on the GPU when one is available
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


class TestGlobalNet(unittest.TestCase):
    """Output-shape and TorchScript save checks for ``GlobalNet``."""

    @parameterized.expand(TEST_CASES_GLOBAL_NET)
    def test_shape(self, input_param, input_shape, expected_shape):
        model = GlobalNet(**input_param).to(device)
        with eval_mode(model):
            batch = torch.randn(input_shape).to(device)
            output = model(batch)
        self.assertEqual(output.shape, expected_shape)

    def test_script(self):
        # only the first case is used; its expected shape is not needed here
        first_case = TEST_CASES_GLOBAL_NET[0]
        input_param = first_case[0]
        input_shape = first_case[1]
        model = GlobalNet(**input_param)
        sample = torch.randn(input_shape)
        test_script_save(model, sample)


# allow running this test module directly as a script
if "__main__" == __name__:
    unittest.main()