Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions monai/networks/nets/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(

def _create_block(
inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
) -> nn.Sequential:
) -> nn.Module:
"""
Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential
blocks containing the downsample path, a skip connection around the previous block, and the upsample path.
Expand Down Expand Up @@ -186,12 +186,29 @@ def _create_block(
down = self._get_down_layer(inc, c, s, is_top) # create layer in downsampling path
up = self._get_up_layer(upc, outc, s, is_top) # create layer in upsampling path

return nn.Sequential(down, SkipConnection(subblock), up)
return self._get_connection_block(down, up, subblock)

self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True)

def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
"""
Returns the block object defining a layer of the UNet structure including the implementation of the skip
between encoding (down) and and decoding (up) sides of the network.

Args:
down_path: encoding half of the layer
up_path: decoding half of the layer
subblock: block defining the next layer in the network.
Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`
"""
return nn.Sequential(down_path, SkipConnection(subblock), up_path)

def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
"""
Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point
in its structure. Its output is used as input to the next layer down and is concatenated with output from the
next layer to form the input for the decode (up) part of the layer.

Args:
in_channels: number of input channels.
out_channels: number of output channels.
Expand Down Expand Up @@ -229,6 +246,8 @@ def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_

def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:
"""
Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves.

Args:
in_channels: number of input channels.
out_channels: number of output channels.
Expand All @@ -237,6 +256,9 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:

def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
"""
Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point
in its structure. Its output is used as input to the next layer up.

Args:
in_channels: number of input channels.
out_channels: number of output channels.
Expand Down