Skip to content
Merged
55 changes: 17 additions & 38 deletions monai/networks/nets/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
# =========================================================================

from collections import OrderedDict
from enum import Enum
from typing import Callable, Dict, List, Optional, Sequence, Type, Union

import torch
Expand All @@ -37,8 +36,8 @@
from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv, Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import InterpolateMode, UpsampleMode, export
from monai.utils.enums import StrEnum
from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode
from monai.utils.module import export, look_up_option

__all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"]

Expand Down Expand Up @@ -380,40 +379,21 @@ class HoVerNet(nn.Module):
Medical Image Analysis 2019

Args:
mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
in_channels: number of the input channel.
out_classes: number of the nuclear type classes.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
dropout_prob: dropout rate after each dense layer.
"""

Comment thread
bhashemian marked this conversation as resolved.
class Mode(Enum):
FAST: int = 0
ORIGINAL: int = 1

class Branch(StrEnum):
"""
Three branches of HoVerNet model, which results in three outputs:
`HOVER` is horizontal and vertical regressed gradient map of each nucleus,
`NUCLEUS` is the segmentation of all nuclei, and
`TYPE` is the type of each nucleus.

"""

HV = "horizontal_vertical"
NP = "nucleus_prediction"
NC = "type_prediction"

def _mode_to_int(self, mode) -> int:

if mode == self.Mode.FAST:
return 0
else:
return 1
Mode = HoVerNetMode
Branch = HoVerNetBranch

def __init__(
self,
mode: Mode = Mode.FAST,
mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST,
Comment thread
bhashemian marked this conversation as resolved.
in_channels: int = 3,
out_classes: int = 0,
act: Union[str, tuple] = ("relu", {"inplace": True}),
Expand All @@ -423,10 +403,9 @@ def __init__(

super().__init__()

self.mode: int = self._mode_to_int(mode)

if mode not in [self.Mode.ORIGINAL, self.Mode.FAST]:
raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
if isinstance(mode, str):
mode = mode.upper()
self.mode = look_up_option(mode, HoVerNetMode)

if out_classes > 128:
raise ValueError("Number of nuclear types classes exceeds maximum (128)")
Expand All @@ -441,7 +420,7 @@ def __init__(
# number of layers in each pooling block.
_block_config: Sequence[int] = (3, 4, 6, 3)

if mode == self.Mode.FAST:
if self.mode == HoVerNetMode.FAST:
_ksize = 3
_pad = 3
else:
Expand Down Expand Up @@ -510,12 +489,12 @@ def __init__(

def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:

if self.mode == 1:
if self.mode == HoVerNetMode.ORIGINAL.value:
Comment thread
wyli marked this conversation as resolved.
if x.shape[-1] != 270 or x.shape[-2] != 270:
raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL")
else:
if x.shape[-1] != 256 or x.shape[-2] != 256:
raise ValueError("Input size should be 256 x 256 when using Mode.FAST")
raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST")

x = x / 255.0 # to 0-1 range to match XY
x = self.input_features(x)
Expand All @@ -531,11 +510,11 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
x = self.upsample(x)

output = {
HoVerNet.Branch.NP.value: self.nucleus_prediction(x, short_cuts),
HoVerNet.Branch.HV.value: self.horizontal_vertical(x, short_cuts),
HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
}
if self.type_prediction is not None:
output[HoVerNet.Branch.NC.value] = self.type_prediction(x, short_cuts)
output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)

return output

Expand Down
2 changes: 2 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
GridPatchSort,
GridSampleMode,
GridSamplePadMode,
HoVerNetBranch,
HoVerNetMode,
InterpolateMode,
InverseKeys,
JITMetadataKeys,
Expand Down
27 changes: 27 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"ImageStatsKeys",
"LabelStatsKeys",
"AlgoEnsembleKeys",
"HoVerNetMode",
"HoVerNetBranch",
]


Expand Down Expand Up @@ -587,3 +589,28 @@ class AlgoEnsembleKeys(StrEnum):
ID = "identifier"
ALGO = "infer_algo"
SCORE = "best_metric"


class HoVerNetMode(StrEnum):
"""
Modes for HoVerNet model:
`FAST`: a faster implementation (than original)
`ORIGINAL`: the original implementation
"""

FAST = "FAST"
ORIGINAL = "ORIGINAL"


class HoVerNetBranch(StrEnum):
"""
Three branches of HoVerNet model, which results in three outputs:
`HV` is horizontal and vertical gradient map of each nucleus (regression),
`NP` is the pixel prediction of all nuclei (segmentation), and
`NC` is the type of each nucleus (classification).

"""

HV = "horizontal_vertical"
NP = "nucleus_prediction"
NC = "type_prediction"
1 change: 1 addition & 0 deletions tests/test_hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

ILL_CASES = [
[{"out_classes": 6, "mode": 3}],
[{"out_classes": 6, "mode": "Wrong"}],
[{"out_classes": 1000, "mode": HoVerNet.Mode.ORIGINAL}],
[{"out_classes": 1, "mode": HoVerNet.Mode.ORIGINAL}],
[{"out_classes": 6, "mode": HoVerNet.Mode.ORIGINAL, "dropout_prob": 100}],
Expand Down