diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 21259936e7..25ce61ab3a 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -158,7 +158,7 @@ def __init__( def _create_block( inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool - ) -> nn.Sequential: + ) -> nn.Module: """ Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential blocks containing the downsample path, a skip connection around the previous block, and the upsample path. @@ -186,12 +186,29 @@ def _create_block( down = self._get_down_layer(inc, c, s, is_top) # create layer in downsampling path up = self._get_up_layer(upc, outc, s, is_top) # create layer in upsampling path - return nn.Sequential(down, SkipConnection(subblock), up) + return self._get_connection_block(down, up, subblock) self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) + def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: + """ + Returns the block object defining a layer of the UNet structure including the implementation of the skip + between encoding (down) and decoding (up) sides of the network. + + Args: + down_path: encoding half of the layer + up_path: decoding half of the layer + subblock: block defining the next layer in the network. + Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` + """ + return nn.Sequential(down_path, SkipConnection(subblock), up_path) + def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point + in its structure. Its output is used as input to the next layer down and is concatenated with output from the + next layer to form the input for the decode (up) part of the layer. 
+ Args: in_channels: number of input channels. out_channels: number of output channels. @@ -229,6 +246,8 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: """ + Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves. + Args: in_channels: number of input channels. out_channels: number of output channels. @@ -237,6 +256,9 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: """ + Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point + in its structure. Its output is used as input to the next layer up. + Args: in_channels: number of input channels. out_channels: number of output channels.