diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index d7c876c848..c024463348 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -28,7 +28,6 @@ # ========================================================================= from collections import OrderedDict -from enum import Enum from typing import Callable, Dict, List, Optional, Sequence, Type, Union import torch @@ -37,8 +36,8 @@ from monai.networks.blocks import UpSample from monai.networks.layers.factories import Conv, Dropout from monai.networks.layers.utils import get_act_layer, get_norm_layer -from monai.utils import InterpolateMode, UpsampleMode, export -from monai.utils.enums import StrEnum +from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode +from monai.utils.module import export, look_up_option __all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"] @@ -380,6 +379,8 @@ class HoVerNet(nn.Module): Medical Image Analysis 2019 Args: + mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or + a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`. in_channels: number of the input channel. out_classes: number of the nuclear type classes. act: activation type and arguments. Defaults to relu. @@ -387,33 +388,12 @@ class HoVerNet(nn.Module): dropout_prob: dropout rate after each dense layer. """ - class Mode(Enum): - FAST: int = 0 - ORIGINAL: int = 1 - - class Branch(StrEnum): - """ - Three branches of HoVerNet model, which results in three outputs: - `HOVER` is horizontal and vertical regressed gradient map of each nucleus, - `NUCLEUS` is the segmentation of all nuclei, and - `TYPE` is the type of each nucleus. - - """ - - HV = "horizontal_vertical" - NP = "nucleus_prediction" - NC = "type_prediction" - - def _mode_to_int(self, mode) -> int: - - if mode == self.Mode.FAST: - return 0 - else: - return 1 + Mode = HoVerNetMode + Branch = HoVerNetBranch def __init__( self, - mode: Mode = Mode.FAST, + mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST, in_channels: int = 3, out_classes: int = 0, act: Union[str, tuple] = ("relu", {"inplace": True}), @@ -423,10 +403,9 @@ def __init__( super().__init__() - self.mode: int = self._mode_to_int(mode) - - if mode not in [self.Mode.ORIGINAL, self.Mode.FAST]: - raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL") + if isinstance(mode, str): + mode = mode.upper() + self.mode = look_up_option(mode, HoVerNetMode) if out_classes > 128: raise ValueError("Number of nuclear types classes exceeds maximum (128)") @@ -441,7 +420,7 @@ def __init__( # number of layers in each pooling block. _block_config: Sequence[int] = (3, 4, 6, 3) - if mode == self.Mode.FAST: + if self.mode == HoVerNetMode.FAST: _ksize = 3 _pad = 3 else: @@ -510,12 +489,12 @@ def __init__( def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - if self.mode == 1: + if self.mode == HoVerNetMode.ORIGINAL.value: if x.shape[-1] != 270 or x.shape[-2] != 270: - raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL") + raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL") else: if x.shape[-1] != 256 or x.shape[-2] != 256: - raise ValueError("Input size should be 256 x 256 when using Mode.FAST") + raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST") x = x / 255.0 # to 0-1 range to match XY x = self.input_features(x) @@ -531,11 +510,11 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: x = self.upsample(x) output = { - HoVerNet.Branch.NP.value: self.nucleus_prediction(x, short_cuts), - HoVerNet.Branch.HV.value: self.horizontal_vertical(x, short_cuts), + HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts), + HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts), } if self.type_prediction is not None: - output[HoVerNet.Branch.NC.value] = self.type_prediction(x, short_cuts) + output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts) return output diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 2428da88a2..8eccac8f70 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -29,6 +29,8 @@ GridPatchSort, GridSampleMode, GridSamplePadMode, + HoVerNetBranch, + HoVerNetMode, InterpolateMode, InverseKeys, JITMetadataKeys, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index d69c184dae..12e82cd378 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -52,6 +52,8 @@ "ImageStatsKeys", "LabelStatsKeys", "AlgoEnsembleKeys", + "HoVerNetMode", + "HoVerNetBranch", ] @@ -587,3 +589,28 @@ class AlgoEnsembleKeys(StrEnum): ID = "identifier" ALGO = "infer_algo" SCORE = "best_metric" + + +class HoVerNetMode(StrEnum): + """ + Modes for HoVerNet model: + `FAST`: a faster implementation (than original) + `ORIGINAL`: the original implementation + """ + + FAST = "FAST" + ORIGINAL = "ORIGINAL" + + +class HoVerNetBranch(StrEnum): + """ + Three branches of HoVerNet model, which results in three outputs: + `HV` is horizontal and vertical gradient map of each nucleus (regression), + `NP` is the pixel prediction of all nuclei (segmentation), and + `NC` is the type of each nucleus (classification). + + """ + + HV = "horizontal_vertical" + NP = "nucleus_prediction" + NC = "type_prediction" diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py index 45a6bb55b9..2365210f55 100644 --- a/tests/test_hovernet.py +++ b/tests/test_hovernet.py @@ -54,6 +54,7 @@ ILL_CASES = [ [{"out_classes": 6, "mode": 3}], + [{"out_classes": 6, "mode": "Wrong"}], [{"out_classes": 1000, "mode": HoVerNet.Mode.ORIGINAL}], [{"out_classes": 1, "mode": HoVerNet.Mode.ORIGINAL}], [{"out_classes": 6, "mode": HoVerNet.Mode.ORIGINAL, "dropout_prob": 100}],