Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, RandomizableTransform, Transform
from .transform import MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
from .utility.array import (
AddChannel,
AddExtremePointsChannel,
Expand Down Expand Up @@ -348,7 +348,7 @@
ToTensorDict,
)
from .utils import (
apply_transform,
allow_missing_keys_mode,
copypaste_arrays,
create_control_grid,
create_grid,
Expand Down
9 changes: 7 additions & 2 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
import numpy as np

# For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401
from monai.transforms.utils import apply_transform
from monai.transforms.transform import ( # noqa: F401
MapTransform,
Randomizable,
RandomizableTransform,
Transform,
apply_transform,
)
from monai.utils import MAX_SEED, ensure_tuple, get_seed

__all__ = ["Compose"]
Expand Down
29 changes: 27 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,39 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple

import numpy as np

from monai.config import KeysCollection
from monai.utils import MAX_SEED, ensure_tuple

__all__ = ["Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
__all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]


def apply_transform(transform: Callable, data, map_items: bool = True):
"""
Transform `data` with `transform`.
If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed
and this method returns a list of outcomes.
otherwise transform will be applied once with `data` as the argument.

Args:
transform: a callable to be used to transform `data`
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.

Raises:
Exception: When ``transform`` raises an exception.

"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [transform(item) for item in data]
return transform(data)
except Exception as e:
raise RuntimeError(f"applying transform {transform}") from e


class Randomizable(ABC):
Expand Down
78 changes: 52 additions & 26 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
import itertools
import random
import warnings
from contextlib import contextmanager
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from monai.config import DtypeLike, IndexSelection
from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform
from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import
from monai.utils.misc import issequenceiterable

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)

Expand All @@ -37,7 +41,6 @@
"map_binary_to_indices",
"weighted_patch_samples",
"generate_pos_neg_label_crop_centers",
"apply_transform",
"create_grid",
"create_control_grid",
"create_rotate",
Expand All @@ -49,6 +52,7 @@
"get_extreme_points",
"extreme_points_to_image",
"map_spatial_axes",
"allow_missing_keys_mode",
]


Expand Down Expand Up @@ -363,31 +367,6 @@ def _correct_centers(
return centers


def apply_transform(transform: Callable, data, map_items: bool = True):
"""
Transform `data` with `transform`.
If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed
and this method returns a list of outcomes.
otherwise transform will be applied once with `data` as the argument.

Args:
transform: a callable to be used to transform `data`
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.

Raises:
Exception: When ``transform`` raises an exception.

"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [transform(item) for item in data]
return transform(data)
except Exception as e:
raise RuntimeError(f"applying transform {transform}") from e


def create_grid(
spatial_size: Sequence[int],
spacing: Optional[Sequence[float]] = None,
Expand Down Expand Up @@ -730,3 +709,50 @@ def map_spatial_axes(
spatial_axes_.append(a - 1 if a < 0 else a)

return spatial_axes_


@contextmanager
def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]):
"""Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states.

Args:
transform: either MapTransform or a Compose

Example:

.. code-block:: python

data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False)
_ = t(data) # would raise exception
with allow_missing_keys_mode(t):
_ = t(data) # OK!
"""
# If given a sequence of transforms, Compose them to get a single list
if issequenceiterable(transform):
transform = Compose(transform)

# Get list of MapTransforms
transforms = []
if isinstance(transform, MapTransform):
transforms = [transform]
Comment thread
Nic-Ma marked this conversation as resolved.
elif isinstance(transform, Compose):
# Only keep contained MapTransforms
transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)]
if len(transforms) == 0:
raise TypeError(
"allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)"
)

# Get the state of each `allow_missing_keys`
orig_states = [t.allow_missing_keys for t in transforms]

try:
# Set all to True
for t in transforms:
t.allow_missing_keys = True
yield
finally:
# Revert
for t, o_s in zip(transforms, orig_states):
t.allow_missing_keys = o_s
73 changes: 73 additions & 0 deletions tests/test_with_allow_missing_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from monai.transforms import Compose, SpatialPad, SpatialPadd, allow_missing_keys_mode


class TestWithAllowMissingKeysMode(unittest.TestCase):
def setUp(self):
self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}

def test_map_transform(self):
for amk in [True, False]:
t = SpatialPadd(["image", "label"], 10, allow_missing_keys=amk)
with allow_missing_keys_mode(t):
# check state is True
self.assertTrue(t.allow_missing_keys)
# and that transform works even though key is missing
_ = t(self.data)
# check it has returned to original state
self.assertEqual(t.allow_missing_keys, amk)
if not amk:
# should fail because amks==False and key is missing
with self.assertRaises(KeyError):
_ = t(self.data)

def test_compose(self):
amks = [True, False, True]
t = Compose([SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) for amk in amks])
with allow_missing_keys_mode(t):
# check states are all True
for _t in t.transforms:
self.assertTrue(_t.allow_missing_keys)
# and that transform works even though key is missing
_ = t(self.data)
# check they've returned to original state
for _t, amk in zip(t.transforms, amks):
self.assertEqual(_t.allow_missing_keys, amk)
# should fail because not all amks==True and key is missing
with self.assertRaises((KeyError, RuntimeError)):
_ = t(self.data)

def test_array_transform(self):
for t in [SpatialPad(10), Compose([SpatialPad(10)])]:
with self.assertRaises(TypeError):
with allow_missing_keys_mode(t):
pass

def test_multiple(self):
orig_states = [True, False]
ts = [SpatialPadd(["image", "label"], 10, allow_missing_keys=i) for i in orig_states]
with allow_missing_keys_mode(ts):
for t in ts:
self.assertTrue(t.allow_missing_keys)
# and that transform works even though key is missing
_ = t(self.data)
for t, o_s in zip(ts, orig_states):
self.assertEqual(t.allow_missing_keys, o_s)


if __name__ == "__main__":
unittest.main()