diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 720a3723dc..7607cd2701 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -473,6 +473,11 @@ Nets .. autoclass:: Unet .. autoclass:: unet +`AttentionUnet` +~~~~~~~~~~~~~~~ +.. autoclass:: AttentionUnet + :members: + `UNETR` ~~~~~~~ .. autoclass:: UNETR diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 22fcef4903..16686fa25c 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .ahnet import AHnet, Ahnet, AHNet +from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py new file mode 100644 index 0000000000..177a54e105 --- /dev/null +++ b/monai/networks/nets/attentionunet.py @@ -0,0 +1,257 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Norm + +__all__ = ["AttentionUnet"] + + +class ConvBlock(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + strides: int = 1, + dropout=0.0, + ): + super().__init__() + layers = [ + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=strides, + padding=None, + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=1, + padding=None, + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + ] + self.conv = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_c: torch.Tensor = self.conv(x) + return x_c + + +class UpConv(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0): + super().__init__() + self.up = Convolution( + spatial_dims, + in_channels, + out_channels, + strides=strides, + kernel_size=kernel_size, + act="relu", + adn_ordering="NDA", + norm=Norm.BATCH, + dropout=dropout, + is_transposed=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_u: torch.Tensor = self.up(x) + return x_u + + +class AttentionBlock(nn.Module): + def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0): + super().__init__() + self.W_g = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_g, + out_channels=f_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](f_int), + ) + + self.W_x = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_l, + out_channels=f_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](f_int), + ) + + self.psi = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=f_int, + out_channels=1, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](1), + nn.Sigmoid(), + ) + + self.relu = nn.ReLU() + + def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + g1 = self.W_g(g) + x1 = self.W_x(x) + psi: torch.Tensor = self.relu(g1 + x1) + psi = self.psi(psi) + + return x * psi + + +class AttentionLayer(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): + super().__init__() + self.attention = AttentionBlock( + spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2 + ) + self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) + self.merge = Convolution( + spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout + ) + self.submodule = submodule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + fromlower = self.upconv(self.submodule(x)) + att = self.attention(g=fromlower, x=x) + att_m: torch.Tensor = self.merge(torch.cat((att, fromlower), dim=1)) + return att_m + + +class AttentionUnet(nn.Module): + """ + Attention Unet based on + Otkay et al. "Attention U-Net: Learning Where to Look for the Pancreas" + https://arxiv.org/abs/1804.03999 + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of the input channel. + out_channels: number of the output classes. + channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides (Sequence[int]): stride to use for convolutions. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + dropout: dropout ratio. Defaults to no dropout. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + strides: Sequence[int], + kernel_size: Union[Sequence[int], int] = 3, + up_kernel_size: Union[Sequence[int], int] = 3, + dropout: float = 0.0, + ): + super().__init__() + self.dimensions = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.strides = strides + self.kernel_size = kernel_size + self.dropout = dropout + + head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + reduce_channels = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + kernel_size=1, + strides=1, + padding=0, + conv_only=True, + ) + self.up_kernel_size = up_kernel_size + + def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module: + if len(channels) > 2: + subblock = _create_block(channels[1:], strides[1:], level=level + 1) + return AttentionLayer( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + submodule=nn.Sequential( + ConvBlock( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + strides=strides[0], + dropout=self.dropout, + ), + subblock, + ), + dropout=dropout, + ) + else: + # the next layer is the bottom so stop recursion, + # create the bottom layer as the sublock for this layer + return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1) + + encdec = _create_block(self.channels, self.strides) + self.model = nn.Sequential(head, encdec, reduce_channels) + + def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module: + return AttentionLayer( + spatial_dims=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + submodule=ConvBlock( + spatial_dims=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + strides=strides, + dropout=self.dropout, + ), + dropout=self.dropout, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_m: torch.Tensor = self.model(x) + return x_m diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py new file mode 100644 index 0000000000..b2f53f9c16 --- /dev/null +++ b/tests/test_attentionunet.py @@ -0,0 +1,65 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +import monai.networks.nets.attentionunet as att +from tests.utils import skip_if_no_cuda + + +class TestAttentionUnet(unittest.TestCase): + def test_attention_block(self): + for dims in [2, 3]: + block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6) + shape = (4, 6) + (30,) * dims + x = torch.rand(*shape, dtype=torch.float32) + output = block(x, x) + self.assertEqual(output.shape, x.shape) + + block = att.AttentionBlock(dims, f_int=2, f_g=3, f_l=6) + xshape = (4, 6) + (30,) * dims + x = torch.rand(*xshape, dtype=torch.float32) + gshape = (4, 3) + (30,) * dims + g = torch.rand(*gshape, dtype=torch.float32) + output = block(g, x) + self.assertEqual(output.shape, x.shape) + + def test_attentionunet(self): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape) + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) + ) + output = model(input) + self.assertEqual(output.shape[2:], input.shape[2:]) + self.assertEqual(output.shape[0], input.shape[0]) + self.assertEqual(output.shape[1], 2) + + @skip_if_no_cuda + def test_attentionunet_gpu(self): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape).to("cuda:0") + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) + ).to("cuda:0") + with torch.no_grad(): + output = model(input) + self.assertEqual(output.shape[2:], input.shape[2:]) + self.assertEqual(output.shape[0], input.shape[0]) + self.assertEqual(output.shape[1], 2) + + +if __name__ == "__main__": + unittest.main()