From 81503aacc7af800557cc02abeb26be834646dfe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Pablo=20de=20la=20Cruz=20Guti=C3=A9rrez?= Date: Tue, 15 Mar 2022 21:33:13 +0100 Subject: [PATCH 1/5] attention unet MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Juan Pablo de la Cruz Gutiérrez --- docs/source/networks.rst | 7 + monai/networks/nets/__init__.py | 1 + monai/networks/nets/attentionunet.py | 249 +++++++++++++++++++++++++++ tests/test_attentionunet.py | 64 +++++++ 4 files changed, 321 insertions(+) create mode 100644 monai/networks/nets/attentionunet.py create mode 100644 tests/test_attentionunet.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 720a3723dc..25157165aa 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -473,6 +473,13 @@ Nets .. autoclass:: Unet .. autoclass:: unet +`AttentionUnet` +~~~~~~~~~~~~~~~ +.. autoclass:: AttentionUnet + :members: +.. autoclass:: AttentionUNet +.. autoclass:: Attentionunet + `UNETR` ~~~~~~~ .. autoclass:: UNETR diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 22fcef4903..16686fa25c 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .ahnet import AHnet, Ahnet, AHNet +from .attentionunet import AttentionUnet from .autoencoder import AutoEncoder from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet from .classifier import Classifier, Critic, Discriminator diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py new file mode 100644 index 0000000000..e9f861208f --- /dev/null +++ b/monai/networks/nets/attentionunet.py @@ -0,0 +1,249 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Sequence, Union + +import torch +import torch.nn as nn + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Norm +from monai.utils import alias, export + +__all__ = ["AttentionUnet"] + + +class ConvBlock(nn.Module): + def __init__( + self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, strides: int = 1, dropout=0.0 + ): + super().__init__() + layers = [ + Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=strides, + padding="same" if strides == 1 else None, + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + strides=1, + padding="same", + adn_ordering="NDA", + act="relu", + norm=Norm.BATCH, + dropout=dropout, + ), + ] + self.conv = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class UpConv(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0): + super().__init__() + self.up = Convolution( + spatial_dims, + in_channels, + out_channels, + strides=strides, + kernel_size=kernel_size, + act="relu", + adn_ordering="NDA", + norm=Norm.BATCH, + dropout=dropout, + is_transposed=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.up(x) + + +class AttentionBlock(nn.Module): 
+ def __init__(self, spatial_dims: int, F_int: int, F_g: int, F_l: int, dropout=0.0): + super().__init__() + self.W_g = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=F_g, + out_channels=F_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](F_int), + ) + + self.W_x = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=F_l, + out_channels=F_int, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](F_int), + ) + + self.psi = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=F_int, + out_channels=1, + kernel_size=1, + strides=1, + padding=0, + dropout=dropout, + conv_only=True, + ), + Norm[Norm.BATCH, spatial_dims](1), + nn.Sigmoid(), + ) + + self.relu = nn.ReLU() + + def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + g1 = self.W_g(g) + x1 = self.W_x(x) + psi = self.relu(g1 + x1) + psi = self.psi(psi) + + return x * psi + + +@export("monai.networks.nets") +@alias("Attentionunet") +@alias("AttentionUNet") +class AttentionUnet(nn.Module): + """ + Attention Unet based on + Oktay et al. "Attention U-Net: Learning Where to Look for the Pancreas" + https://arxiv.org/abs/1804.03999 + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of the input channel. + out_channels: number of the output classes. + channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides (Sequence[int]): stride to use for convolutions. + kernel_size: convolution kernel size. + up_kernel_size: convolution kernel size for transposed convolution layers. + dropout: dropout ratio. Defaults to no dropout.
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + channels: Sequence[int], + strides: Sequence[int], + kernel_size: Union[Sequence[int], int] = 3, + up_kernel_size: Union[Sequence[int], int] = 3, + dropout: float = 0.0, + ): + super().__init__() + self.dimensions = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.strides = strides + self.kernel_size = kernel_size + self.dropout = dropout + + head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + reduce_channels = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=out_channels, + kernel_size=1, + strides=1, + padding=0, + conv_only=True, + ) + self.up_kernel_size = up_kernel_size + + def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module: + if len(channels) > 2: + subblock = _create_block(channels[1:], strides[1:], level=level + 1) + return AttentionLayer( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + submodule=nn.Sequential( + ConvBlock( + spatial_dims=spatial_dims, + in_channels=channels[0], + out_channels=channels[1], + strides=strides[0], + dropout=self.dropout, + ), + subblock, + ), + dropout=dropout, + ) + else: + # the next layer is the bottom so stop recursion, + # create the bottom layer as the sublock for this layer + return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1) + + encdec = _create_block(self.channels, self.strides) + self.model = nn.Sequential(head, encdec, reduce_channels) + + def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module: + return AttentionLayer( + spatial_dims=self.dimensions, + in_channels=in_channels, + out_channels=out_channels, + submodule=ConvBlock( + spatial_dims=self.dimensions, + in_channels=in_channels, + 
out_channels=out_channels, + strides=strides, + dropout=self.dropout, + ), + dropout=self.dropout, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +class AttentionLayer(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): + super().__init__() + self.attention = AttentionBlock(spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2) + self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) + self.merge = Convolution( + spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout + ) + self.submodule = submodule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + fromlower = self.upconv(self.submodule(x)) + att = self.attention(g=fromlower, x=x) + return self.merge(torch.cat((att, fromlower), dim=1)) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py new file mode 100644 index 0000000000..ebe94f76c4 --- /dev/null +++ b/tests/test_attentionunet.py @@ -0,0 +1,64 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +import monai.networks.nets.attentionunet as att + + +class TestAttentionUnet(unittest.TestCase): + def test_attention_block(self): + for dims in [2, 3]: + block = att.AttentionBlock(dims, F_int=16, F_g=64, F_l=64) + shape = (4, 64) + (30,) * dims + x = torch.rand(*shape, dtype=torch.float32) + output = block(x, x) + assert output.shape == x.shape + + block = att.AttentionBlock(dims, F_int=16, F_g=32, F_l=64) + xshape = (4, 64) + (30,) * dims + x = torch.rand(*xshape, dtype=torch.float32) + gshape = (4, 32) + (30,) * dims + g = torch.rand(*gshape, dtype=torch.float32) + output = block(g, x) + assert output.shape == x.shape + + def test_attentionunet(self): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape) + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2) + ) + output = model(input) + assert output.shape[2:] == input.shape[2:] + assert output.shape[0] == input.shape[0] + assert output.shape[1] == 2 + + def test_attentionunet_gpu(self): + if torch.cuda.is_available(): + for dims in [2, 3]: + shape = (3, 1) + (92,) * dims + input = torch.rand(*shape).to("cuda:0") + model = att.AttentionUnet( + spatial_dims=dims, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2) + ).to("cuda:0") + with torch.no_grad(): + output = model(input) + assert output.shape[2:] == input.shape[2:] + assert output.shape[0] == input.shape[0] + assert output.shape[1] == 2 + + +if __name__ == "__main__": + unittest.main() From 960e71994b549595761d04c7083f10451054fdcc Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 15 Mar 2022 21:14:13 +0000 Subject: [PATCH 2/5] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/networks/nets/attentionunet.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index e9f861208f..7e125996f0 
100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -23,7 +23,13 @@ class ConvBlock(nn.Module): def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, strides: int = 1, dropout=0.0 + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + strides: int = 1, + dropout=0.0, ): super().__init__() layers = [ @@ -236,7 +242,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AttentionLayer(nn.Module): def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): super().__init__() - self.attention = AttentionBlock(spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2) + self.attention = AttentionBlock( + spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2 + ) self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) self.merge = Convolution( spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout From 415febfdf5c0c2bca3e4d8688d27e41d7727e8ce Mon Sep 17 00:00:00 2001 From: Juan Pablo de la Cruz Gutierrez Date: Wed, 16 Mar 2022 21:21:14 +0100 Subject: [PATCH 3/5] fixed issues and added suggested improvements Signed-off-by: Juan Pablo de la Cruz Gutierrez --- docs/source/networks.rst | 2 -- monai/networks/nets/attentionunet.py | 49 +++++++++++++++------------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 25157165aa..7607cd2701 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -477,8 +477,6 @@ Nets ~~~~~~~~~~~~~~~ .. autoclass:: AttentionUnet :members: -.. autoclass:: AttentionUNet -.. 
autoclass:: Attentionunet `UNETR` ~~~~~~~ diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index e9f861208f..c0b5bbd5fa 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -23,7 +23,13 @@ class ConvBlock(nn.Module): def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, strides: int = 1, dropout=0.0 + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + strides: int = 1, + dropout=0.0, ): super().__init__() layers = [ @@ -33,7 +39,7 @@ def __init__( out_channels=out_channels, kernel_size=kernel_size, strides=strides, - padding="same" if strides == 1 else None, + padding=None, adn_ordering="NDA", act="relu", norm=Norm.BATCH, @@ -45,7 +51,7 @@ def __init__( out_channels=out_channels, kernel_size=kernel_size, strides=1, - padding="same", + padding=None, adn_ordering="NDA", act="relu", norm=Norm.BATCH, @@ -135,9 +141,24 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: return x * psi -@export("monai.networks.nets") -@alias("Attentionunet") -@alias("AttentionUNet") +class AttentionLayer(nn.Module): + def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): + super().__init__() + self.attention = AttentionBlock( + spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2 + ) + self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) + self.merge = Convolution( + spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout + ) + self.submodule = submodule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + fromlower = self.upconv(self.submodule(x)) + att = self.attention(g=fromlower, x=x) + return self.merge(torch.cat((att, fromlower), dim=1)) + + class AttentionUnet(nn.Module): """ Attention Unet based on @@ 
-231,19 +252,3 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) - - -class AttentionLayer(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): - super().__init__() - self.attention = AttentionBlock(spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2) - self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) - self.merge = Convolution( - spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout - ) - self.submodule = submodule - - def forward(self, x: torch.Tensor) -> torch.Tensor: - fromlower = self.upconv(self.submodule(x)) - att = self.attention(g=fromlower, x=x) - return self.merge(torch.cat((att, fromlower), dim=1)) From 02b34506777158b4e45797d41ec8608350471679 Mon Sep 17 00:00:00 2001 From: Juampa Date: Wed, 16 Mar 2022 21:45:55 +0100 Subject: [PATCH 4/5] fixed flake issues --- monai/networks/nets/attentionunet.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index c0b5bbd5fa..2b1219101a 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -85,40 +85,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AttentionBlock(nn.Module): - def __init__(self, spatial_dims: int, F_int: int, F_g: int, F_l: int, dropout=0.0): + def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0): super().__init__() self.W_g = nn.Sequential( Convolution( spatial_dims=spatial_dims, - in_channels=F_g, - out_channels=F_int, + in_channels=f_g, + out_channels=f_int, kernel_size=1, strides=1, padding=0, dropout=dropout, conv_only=True, ), - Norm[Norm.BATCH, spatial_dims](F_int), + Norm[Norm.BATCH, 
spatial_dims](f_int), ) self.W_x = nn.Sequential( Convolution( spatial_dims=spatial_dims, - in_channels=F_l, - out_channels=F_int, + in_channels=f_l, + out_channels=f_int, kernel_size=1, strides=1, padding=0, dropout=dropout, conv_only=True, ), - Norm[Norm.BATCH, spatial_dims](F_int), + Norm[Norm.BATCH, spatial_dims](f_int), ) self.psi = nn.Sequential( Convolution( spatial_dims=spatial_dims, - in_channels=F_int, + in_channels=f_int, out_channels=1, kernel_size=1, strides=1, @@ -145,7 +145,7 @@ class AttentionLayer(nn.Module): def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): super().__init__() self.attention = AttentionBlock( - spatial_dims=spatial_dims, F_g=in_channels, F_l=in_channels, F_int=in_channels // 2 + spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2 ) self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) self.merge = Convolution( From 8417b93aff4426a50cdf9491e57db5a99b3772aa Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 17 Mar 2022 09:00:35 +0000 Subject: [PATCH 5/5] smaller tests Signed-off-by: Wenqi Li --- tests/test_attentionunet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 7c47ea6e9b..b2f53f9c16 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -20,16 +20,16 @@ class TestAttentionUnet(unittest.TestCase): def test_attention_block(self): for dims in [2, 3]: - block = att.AttentionBlock(dims, f_int=16, f_g=64, f_l=64) - shape = (4, 64) + (30,) * dims + block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6) + shape = (4, 6) + (30,) * dims x = torch.rand(*shape, dtype=torch.float32) output = block(x, x) self.assertEqual(output.shape, x.shape) - block = att.AttentionBlock(dims, f_int=16, f_g=32, f_l=64) - xshape = (4, 64) + (30,) * dims + block = att.AttentionBlock(dims, 
f_int=2, f_g=3, f_l=6) + xshape = (4, 6) + (30,) * dims x = torch.rand(*xshape, dtype=torch.float32) - gshape = (4, 32) + (30,) * dims + gshape = (4, 3) + (30,) * dims g = torch.rand(*gshape, dtype=torch.float32) output = block(g, x) self.assertEqual(output.shape, x.shape) @@ -39,7 +39,7 @@ def test_attentionunet(self): shape = (3, 1) + (92,) * dims input = torch.rand(*shape) model = att.AttentionUnet( - spatial_dims=dims, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2) + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) ) output = model(input) self.assertEqual(output.shape[2:], input.shape[2:]) @@ -52,7 +52,7 @@ def test_attentionunet_gpu(self): shape = (3, 1) + (92,) * dims input = torch.rand(*shape).to("cuda:0") model = att.AttentionUnet( - spatial_dims=dims, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2) + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) ).to("cuda:0") with torch.no_grad(): output = model(input)