Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,11 @@ Nets
.. autoclass:: Unet
.. autoclass:: unet

`AttentionUnet`
~~~~~~~~~~~~~~~
.. autoclass:: AttentionUnet
:members:

`UNETR`
~~~~~~~
.. autoclass:: UNETR
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

from .ahnet import AHnet, Ahnet, AHNet
from .attentionunet import AttentionUnet
from .autoencoder import AutoEncoder
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
from .classifier import Classifier, Critic, Discriminator
Expand Down
257 changes: 257 additions & 0 deletions monai/networks/nets/attentionunet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Union

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Norm

__all__ = ["AttentionUnet"]


class ConvBlock(nn.Module):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
strides: int = 1,
dropout=0.0,
):
super().__init__()
layers = [
Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=strides,
padding=None,
adn_ordering="NDA",
act="relu",
norm=Norm.BATCH,
dropout=dropout,
),
Convolution(
spatial_dims=spatial_dims,
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=1,
padding=None,
adn_ordering="NDA",
act="relu",
norm=Norm.BATCH,
dropout=dropout,
),
]
self.conv = nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_c: torch.Tensor = self.conv(x)
return x_c


class UpConv(nn.Module):
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0):
super().__init__()
self.up = Convolution(
spatial_dims,
in_channels,
out_channels,
strides=strides,
kernel_size=kernel_size,
act="relu",
adn_ordering="NDA",
norm=Norm.BATCH,
dropout=dropout,
is_transposed=True,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_u: torch.Tensor = self.up(x)
return x_u


class AttentionBlock(nn.Module):
def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0):
super().__init__()
self.W_g = nn.Sequential(
Convolution(
spatial_dims=spatial_dims,
in_channels=f_g,
out_channels=f_int,
kernel_size=1,
strides=1,
padding=0,
dropout=dropout,
conv_only=True,
),
Norm[Norm.BATCH, spatial_dims](f_int),
)

self.W_x = nn.Sequential(
Convolution(
spatial_dims=spatial_dims,
in_channels=f_l,
out_channels=f_int,
kernel_size=1,
strides=1,
padding=0,
dropout=dropout,
conv_only=True,
),
Norm[Norm.BATCH, spatial_dims](f_int),
)

self.psi = nn.Sequential(
Convolution(
spatial_dims=spatial_dims,
in_channels=f_int,
out_channels=1,
kernel_size=1,
strides=1,
padding=0,
dropout=dropout,
conv_only=True,
),
Norm[Norm.BATCH, spatial_dims](1),
nn.Sigmoid(),
)

self.relu = nn.ReLU()

def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
g1 = self.W_g(g)
x1 = self.W_x(x)
psi: torch.Tensor = self.relu(g1 + x1)
psi = self.psi(psi)

return x * psi


class AttentionLayer(nn.Module):
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
super().__init__()
self.attention = AttentionBlock(
spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
)
self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
self.merge = Convolution(
spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
)
self.submodule = submodule

def forward(self, x: torch.Tensor) -> torch.Tensor:
fromlower = self.upconv(self.submodule(x))
att = self.attention(g=fromlower, x=x)
att_m: torch.Tensor = self.merge(torch.cat((att, fromlower), dim=1))
return att_m


class AttentionUnet(nn.Module):
Comment thread
wyli marked this conversation as resolved.
"""
Attention Unet based on
Otkay et al. "Attention U-Net: Learning Where to Look for the Pancreas"
https://arxiv.org/abs/1804.03999

Args:
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of the input channel.
out_channels: number of the output classes.
channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
strides (Sequence[int]): stride to use for convolutions.
kernel_size: convolution kernel size.
upsample_kernel_size: convolution kernel size for transposed convolution layers.
dropout: dropout ratio. Defaults to no dropout.
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
channels: Sequence[int],
strides: Sequence[int],
kernel_size: Union[Sequence[int], int] = 3,
up_kernel_size: Union[Sequence[int], int] = 3,
dropout: float = 0.0,
):
super().__init__()
self.dimensions = spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.channels = channels
self.strides = strides
self.kernel_size = kernel_size
self.dropout = dropout

head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
reduce_channels = Convolution(
spatial_dims=spatial_dims,
in_channels=channels[0],
out_channels=out_channels,
kernel_size=1,
strides=1,
padding=0,
conv_only=True,
)
self.up_kernel_size = up_kernel_size

def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
if len(channels) > 2:
subblock = _create_block(channels[1:], strides[1:], level=level + 1)
return AttentionLayer(
spatial_dims=spatial_dims,
in_channels=channels[0],
out_channels=channels[1],
submodule=nn.Sequential(
ConvBlock(
spatial_dims=spatial_dims,
in_channels=channels[0],
out_channels=channels[1],
strides=strides[0],
dropout=self.dropout,
),
subblock,
),
dropout=dropout,
)
else:
# the next layer is the bottom so stop recursion,
# create the bottom layer as the sublock for this layer
return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)

encdec = _create_block(self.channels, self.strides)
self.model = nn.Sequential(head, encdec, reduce_channels)

def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
return AttentionLayer(
spatial_dims=self.dimensions,
in_channels=in_channels,
out_channels=out_channels,
submodule=ConvBlock(
spatial_dims=self.dimensions,
in_channels=in_channels,
out_channels=out_channels,
strides=strides,
dropout=self.dropout,
),
dropout=self.dropout,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_m: torch.Tensor = self.model(x)
return x_m
65 changes: 65 additions & 0 deletions tests/test_attentionunet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

import monai.networks.nets.attentionunet as att
from tests.utils import skip_if_no_cuda


class TestAttentionUnet(unittest.TestCase):
def test_attention_block(self):
for dims in [2, 3]:
block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6)
shape = (4, 6) + (30,) * dims
x = torch.rand(*shape, dtype=torch.float32)
output = block(x, x)
self.assertEqual(output.shape, x.shape)

block = att.AttentionBlock(dims, f_int=2, f_g=3, f_l=6)
xshape = (4, 6) + (30,) * dims
x = torch.rand(*xshape, dtype=torch.float32)
gshape = (4, 3) + (30,) * dims
g = torch.rand(*gshape, dtype=torch.float32)
output = block(g, x)
self.assertEqual(output.shape, x.shape)

def test_attentionunet(self):
for dims in [2, 3]:
shape = (3, 1) + (92,) * dims
input = torch.rand(*shape)
model = att.AttentionUnet(
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
)
output = model(input)
self.assertEqual(output.shape[2:], input.shape[2:])
self.assertEqual(output.shape[0], input.shape[0])
self.assertEqual(output.shape[1], 2)

@skip_if_no_cuda
def test_attentionunet_gpu(self):
for dims in [2, 3]:
shape = (3, 1) + (92,) * dims
input = torch.rand(*shape).to("cuda:0")
model = att.AttentionUnet(
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
).to("cuda:0")
with torch.no_grad():
output = model(input)
self.assertEqual(output.shape[2:], input.shape[2:])
self.assertEqual(output.shape[0], input.shape[0])
self.assertEqual(output.shape[1], 2)


if __name__ == "__main__":
unittest.main()