diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 14fe71728b..796804df24 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -244,7 +244,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, Transform +from .transform import MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddChannel, AddExtremePointsChannel, @@ -348,7 +348,7 @@ ToTensorDict, ) from .utils import ( - apply_transform, + allow_missing_keys_mode, copypaste_arrays, create_control_grid, create_grid, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index a9f66b12a0..21e7da068c 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,8 +18,13 @@ import numpy as np # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 -from monai.transforms.utils import apply_transform +from monai.transforms.transform import ( # noqa: F401 + MapTransform, + Randomizable, + RandomizableTransform, + Transform, + apply_transform, +) from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 7a09efa6d5..2a79b2edf2 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,14 +13,39 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] + + +def apply_transform(transform: Callable, data, map_items: bool = True): + """ + Transform `data` with `transform`. + If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed + and this method returns a list of outcomes. + otherwise transform will be applied once with `data` as the argument. + + Args: + transform: a callable to be used to transform `data` + data: an object to be transformed. + map_items: whether to apply transform to each item in `data`, + if `data` is a list or tuple. Defaults to True. + + Raises: + Exception: When ``transform`` raises an exception. + + """ + try: + if isinstance(data, (list, tuple)) and map_items: + return [transform(item) for item in data] + return transform(data) + except Exception as e: + raise RuntimeError(f"applying transform {transform}") from e class Randomizable(ABC): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9a84eb00d9..eb1b194c96 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -12,6 +12,7 @@ import itertools import random import warnings +from contextlib import contextmanager from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -19,7 +20,10 @@ from monai.config import DtypeLike, IndexSelection from monai.networks.layers import GaussianFilter +from monai.transforms.compose import Compose +from monai.transforms.transform import MapTransform from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import +from monai.utils.misc import issequenceiterable measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -37,7 +41,6 @@ "map_binary_to_indices", "weighted_patch_samples", "generate_pos_neg_label_crop_centers", - "apply_transform", "create_grid", "create_control_grid", "create_rotate", @@ -49,6 +52,7 @@ "get_extreme_points", "extreme_points_to_image", "map_spatial_axes", + "allow_missing_keys_mode", ] @@ -363,31 +367,6 @@ def _correct_centers( return centers -def apply_transform(transform: Callable, data, map_items: bool = True): - """ - Transform `data` with `transform`. - If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed - and this method returns a list of outcomes. - otherwise transform will be applied once with `data` as the argument. - - Args: - transform: a callable to be used to transform `data` - data: an object to be transformed. - map_items: whether to apply transform to each item in `data`, - if `data` is a list or tuple. Defaults to True. - - Raises: - Exception: When ``transform`` raises an exception. - - """ - try: - if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) - except Exception as e: - raise RuntimeError(f"applying transform {transform}") from e - - def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, @@ -730,3 +709,50 @@ def map_spatial_axes( spatial_axes_.append(a - 1 if a < 0 else a) return spatial_axes_ + + +@contextmanager +def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]): + """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. + + Args: + transform: either MapTransform or a Compose + + Example: + + .. code-block:: python + + data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False) + _ = t(data) # would raise exception + with allow_missing_keys_mode(t): + _ = t(data) # OK! + """ + # If given a sequence of transforms, Compose them to get a single list + if issequenceiterable(transform): + transform = Compose(transform) + + # Get list of MapTransforms + transforms = [] + if isinstance(transform, MapTransform): + transforms = [transform] + elif isinstance(transform, Compose): + # Only keep contained MapTransforms + transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)] + if len(transforms) == 0: + raise TypeError( + "allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)" + ) + + # Get the state of each `allow_missing_keys` + orig_states = [t.allow_missing_keys for t in transforms] + + try: + # Set all to True + for t in transforms: + t.allow_missing_keys = True + yield + finally: + # Revert + for t, o_s in zip(transforms, orig_states): + t.allow_missing_keys = o_s diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py new file mode 100644 index 0000000000..68c5ad30c4 --- /dev/null +++ b/tests/test_with_allow_missing_keys.py @@ -0,0 +1,73 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from monai.transforms import Compose, SpatialPad, SpatialPadd, allow_missing_keys_mode + + +class TestWithAllowMissingKeysMode(unittest.TestCase): + def setUp(self): + self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + + def test_map_transform(self): + for amk in [True, False]: + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) + with allow_missing_keys_mode(t): + # check state is True + self.assertTrue(t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check it has returned to original state + self.assertEqual(t.allow_missing_keys, amk) + if not amk: + # should fail because amks==False and key is missing + with self.assertRaises(KeyError): + _ = t(self.data) + + def test_compose(self): + amks = [True, False, True] + t = Compose([SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) for amk in amks]) + with allow_missing_keys_mode(t): + # check states are all True + for _t in t.transforms: + self.assertTrue(_t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check they've returned to original state + for _t, amk in zip(t.transforms, amks): + self.assertEqual(_t.allow_missing_keys, amk) + # should fail because not all amks==True and key is missing + with self.assertRaises((KeyError, RuntimeError)): + _ = t(self.data) + + def test_array_transform(self): + for t in [SpatialPad(10), Compose([SpatialPad(10)])]: + with self.assertRaises(TypeError): + with allow_missing_keys_mode(t): + pass + + def test_multiple(self): + orig_states = [True, False] + ts = [SpatialPadd(["image", "label"], 10, allow_missing_keys=i) for i in orig_states] + with allow_missing_keys_mode(ts): + for t in ts: + self.assertTrue(t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + for t, o_s in zip(ts, orig_states): + self.assertEqual(t.allow_missing_keys, o_s) + + +if __name__ == "__main__": + unittest.main()