From 1acaaa7fb397c4f62f93b801ddd9b1f3479cd7c0 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 22 Mar 2022 11:40:11 +0000 Subject: [PATCH 1/3] Minor change to UNet to permit greater modifiability Signed-off-by: Eric Kerfoot --- monai/networks/nets/unet.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 21259936e7..4d658dcc70 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -177,7 +177,7 @@ def _create_block( if len(channels) > 2: subblock = _create_block(c, c, channels[1:], strides[1:], False) # continue recursion down - upc = c * 2 + upc = c * s else: # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer subblock = self._get_bottom_layer(c, channels[1]) @@ -186,12 +186,29 @@ def _create_block( down = self._get_down_layer(inc, c, s, is_top) # create layer in downsampling path up = self._get_up_layer(upc, outc, s, is_top) # create layer in upsampling path - return nn.Sequential(down, SkipConnection(subblock), up) + return self._get_connection_block(down, up, subblock) self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + """ + Returns the block object defining a layer of the UNet structure including the implementation of the skip + between encoding (down) and and decoding (up) sides of the network. + + Args: + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` + """ + return nn.Sequential(down_path, SkipConnection(subblock), up_path) + def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point + in its structure. Its output is used as input to the next layer down and is concatenated with output from the + next layer to form the input for the decode (up) part of the layer. + Args: in_channels: number of input channels. out_channels: number of output channels. @@ -229,6 +246,8 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ + Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves. + Args: in_channels: number of input channels. out_channels: number of output channels. @@ -237,6 +256,9 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point + in its structure. Its output is used as input to the next layer up. + Args: in_channels: number of input channels. out_channels: number of output channels. From 84f55cecefa13f97080682b47960604c9e710f05 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 22 Mar 2022 12:20:36 +0000 Subject: [PATCH 2/3] Type fix Signed-off-by: Eric Kerfoot --- monai/networks/nets/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 4d658dcc70..e31868f01e 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -158,7 +158,7 @@ def __init__( def _create_block( inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool - ) -> nn.Sequential: + ) -> nn.Module: """ Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential blocks containing the downsample path, a skip connection around the previous block, and the upsample path. From 1ba0645481afb95b10a7194f8fed82ce05b921bc Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 22 Mar 2022 16:23:35 +0000 Subject: [PATCH 3/3] Reverting minor change Signed-off-by: Eric Kerfoot --- monai/networks/nets/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index e31868f01e..25ce61ab3a 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -177,7 +177,7 @@ def _create_block( if len(channels) > 2: subblock = _create_block(c, c, channels[1:], strides[1:], False) # continue recursion down - upc = c * s + upc = c * 2 else: # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer subblock = self._get_bottom_layer(c, channels[1])