Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from monai.data.dataset import Dataset
from monai.data.utils import iter_patch
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple
from monai.utils import NumpyPadMode, ensure_tuple, look_up_option

__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"]

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
"""
self.patch_size = (None,) + tuple(patch_size)
self.start_pos = ensure_tuple(start_pos)
self.mode: NumpyPadMode = NumpyPadMode(mode)
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.pad_opts = pad_opts

def __call__(self, array):
Expand Down
4 changes: 2 additions & 2 deletions monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from monai.data.png_writer import write_png
from monai.data.utils import create_file_basename
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode
from monai.utils import InterpolateMode, look_up_option
from monai.utils.enums import DataObjects


Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
self.output_postfix = output_postfix
self.output_ext = output_ext
self.resample = resample
self.mode: InterpolateMode = InterpolateMode(mode)
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
self.scale = scale
self.data_root_dir = data_root_dir
self.separate_folder = separate_folder
Expand Down
4 changes: 2 additions & 2 deletions monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from monai.transforms.spatial.array import Resize
from monai.utils import InterpolateMode, ensure_tuple_rep, optional_import
from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import
from monai.utils.enums import DataObjects
from monai.utils.misc import convert_data_type

Expand Down Expand Up @@ -58,7 +58,7 @@ def write_png(
data_np = data_np.squeeze(2)
if output_spatial_shape is not None:
output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)
mode = InterpolateMode(mode)
mode = look_up_option(mode, InterpolateMode)
align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False
xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)
_min, _max = np.min(data_np), np.max(data_np)
Expand Down
5 changes: 3 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
fall_back_tuple,
first,
issequenceiterable,
look_up_option,
optional_import,
)
from monai.utils.enums import DataObjects, Method
Expand Down Expand Up @@ -217,7 +218,7 @@ def iter_patch(
start_pos = ensure_tuple_size(start_pos, arr.ndim)

# pad image by maximum values needed to ensure patches are taken from inside an image
arrpad = np.pad(arr, tuple((p, p) for p in patch_size_), NumpyPadMode(mode).value, **pad_opts)
arrpad = np.pad(arr, tuple((p, p) for p in patch_size_), look_up_option(mode, NumpyPadMode).value, **pad_opts)

# choose a start position in the padded image
start_pos_padded = tuple(s + p for s, p in zip(start_pos, patch_size_))
Expand Down Expand Up @@ -770,7 +771,7 @@ def compute_importance_map(
Tensor of size patch_size.

"""
mode = BlendMode(mode)
mode = look_up_option(mode, BlendMode)
device = torch.device(device) # type: ignore[arg-type]
if mode == BlendMode.CONSTANT:
importance_map = torch.ones(patch_size, device=device).float()
Expand Down
3 changes: 2 additions & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.transforms import Transform
from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.module import look_up_option

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -109,7 +110,7 @@ def __init__(
event_to_attr=event_to_attr,
decollate=decollate,
)
mode = ForwardMode(mode)
self.mode = look_up_option(mode, ForwardMode)
if mode == ForwardMode.EVAL:
self.mode = eval_mode
elif mode == ForwardMode.TRAIN:
Expand Down
5 changes: 3 additions & 2 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

from monai.config import IgniteInfo, KeysCollection
from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, min_version, optional_import
from monai.utils import deprecated, ensure_tuple, get_torch_version_tuple, look_up_option, min_version, optional_import
from monai.utils.enums import DataObjects

idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
Expand Down Expand Up @@ -214,7 +214,8 @@ class mean median max 5percentile 95percentile notnans

def _compute_op(op: str, d: np.ndarray):
if not op.endswith("percentile"):
return supported_ops[op](d)
c_op = look_up_option(op, supported_ops)
return c_op(d)

threshold = int(op.split("percentile")[0])
return supported_ops["90percentile"]((d, threshold))
Expand Down
4 changes: 2 additions & 2 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch.nn.functional as F

from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option

__all__ = ["sliding_window_inference"]

Expand Down Expand Up @@ -103,7 +103,7 @@ def sliding_window_inference(
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
inputs = F.pad(inputs, pad=pad_size, mode=PytorchPadMode(padding_mode).value, value=cval)
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)

scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

Expand Down
4 changes: 2 additions & 2 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.networks import one_hot
from monai.utils import LossReduction, Weight
from monai.utils import LossReduction, Weight, look_up_option
from monai.utils.enums import DataObjects


Expand Down Expand Up @@ -267,7 +267,7 @@ def __init__(
self.softmax = softmax
self.other_act = other_act

self.w_type = Weight(w_type)
self.w_type = look_up_option(w_type, Weight)

self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
Expand Down
5 changes: 5 additions & 0 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def compute_hausdorff_distance(
hd = np.empty((batch_size, n_class))
for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile)
if directed:
hd[b, c] = distance_1
Expand Down
4 changes: 2 additions & 2 deletions monai/metrics/rocauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
import torch

from monai.utils import Average
from monai.utils import Average, look_up_option

from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -146,7 +146,7 @@ def compute_roc_auc(
if y.shape != y_pred.shape:
raise AssertionError("data shapes of y_pred and y do not match.")

average = Average(average)
average = look_up_option(average, Average)
if average == Average.MICRO:
return _calculate(y_pred.flatten(), y.flatten())
y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
Expand Down
5 changes: 5 additions & 0 deletions monai/metrics/surface_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def compute_average_surface_distance(

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
if surface_distance.shape == (0,):
avg_surface_distance = np.nan
Expand Down
11 changes: 6 additions & 5 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple, Union

import numpy as np
import torch

from monai.transforms.croppad.array import SpatialCrop
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import MetricReduction, optional_import
from monai.utils import MetricReduction, look_up_option, optional_import
from monai.utils.enums import DataObjects

binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
Expand Down Expand Up @@ -71,7 +70,7 @@ def do_metric_reduction(
not_nans = (~nans).float()

t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
reduction = MetricReduction(reduction)
reduction = look_up_option(reduction, MetricReduction)
if reduction == MetricReduction.NONE:
return f, not_nans

Expand Down Expand Up @@ -189,15 +188,17 @@ def get_surface_distance(
- ``"euclidean"``, uses Exact Euclidean distance transform.
- ``"chessboard"``, uses `chessboard` metric in chamfer type of transform.
- ``"taxicab"``, uses `taxicab` metric in chamfer type of transform.

Note:
If seg_pred or seg_gt is all 0, may result in nan/inf distance.

"""

if not np.any(seg_gt):
dis = np.inf * np.ones_like(seg_gt)
warnings.warn("ground truth is all 0, this may result in nan/inf distance.")
else:
if not np.any(seg_pred):
dis = np.inf * np.ones_like(seg_gt)
warnings.warn("prediction is all 0, this may result in nan/inf distance.")
return np.asarray(dis[seg_gt])
if distance_metric == "euclidean":
dis = distance_transform_edt(~seg_gt)
Expand Down
4 changes: 2 additions & 2 deletions monai/networks/blocks/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from monai.networks.layers.factories import Conv, Pad, Pool
from monai.networks.utils import icnr_init, pixelshuffle
from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep
from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option

__all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]

Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
"""
super().__init__()
scale_factor_ = ensure_tuple_rep(scale_factor, dimensions)
up_mode = UpsampleMode(mode)
up_mode = look_up_option(mode, UpsampleMode)
if up_mode == UpsampleMode.DECONV:
if not in_channels:
raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
Expand Down
6 changes: 4 additions & 2 deletions monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def use_factory(fact_args):

import torch.nn as nn

from monai.utils import look_up_option

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]


Expand Down Expand Up @@ -120,8 +122,8 @@ def get_constructor(self, factory_name: str, *args) -> Any:
if not isinstance(factory_name, str):
raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")

fact = self.factories[factory_name.upper()]
return fact(*args)
func = look_up_option(factory_name.upper(), self.factories)
return func(*args)

def __getitem__(self, args) -> Any:
"""
Expand Down
5 changes: 3 additions & 2 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
InvalidPyTorchVersionError,
SkipMode,
ensure_tuple_rep,
look_up_option,
optional_import,
)

Expand Down Expand Up @@ -75,7 +76,7 @@ def __init__(
self.pad = None
if in_channels == out_channels:
return
mode = ChannelMatching(mode)
mode = look_up_option(mode, ChannelMatching)
if mode == ChannelMatching.PROJECT:
conv_type = Conv[Conv.CONV, spatial_dims]
self.project = conv_type(in_channels, out_channels, kernel_size=1)
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat")
super().__init__()
self.submodule = submodule
self.dim = dim
self.mode = SkipMode(mode).value
self.mode = look_up_option(mode, SkipMode).value

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.submodule(x)
Expand Down
6 changes: 3 additions & 3 deletions monai/networks/layers/spatial_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch.nn as nn

from monai.networks import to_norm_affine
from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, optional_import
from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, look_up_option, optional_import

_C, _ = optional_import("monai._C")

Expand Down Expand Up @@ -455,8 +455,8 @@ def __init__(
super().__init__()
self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None
self.normalized = normalized
self.mode: GridSampleMode = GridSampleMode(mode)
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
self.mode: GridSampleMode = look_up_option(mode, GridSampleMode)
self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode)
self.align_corners = align_corners
self.reverse_indexing = reverse_indexing

Expand Down
Loading