From 552bc42099763a82448ff9e118e04a59c62979d9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 5 Mar 2021 18:03:15 +0000 Subject: [PATCH 1/3] with allow_missing_keys_mode Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 +- monai/transforms/compose.py | 3 +- monai/transforms/transform.py | 29 +++++++++++- monai/transforms/utils.py | 67 +++++++++++++++++---------- tests/test_with_allow_missing_keys.py | 58 +++++++++++++++++++++++ 5 files changed, 130 insertions(+), 31 deletions(-) create mode 100644 tests/test_with_allow_missing_keys.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 14fe71728b..402ab819ba 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -244,7 +244,7 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, Transform +from .transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform from .utility.array import ( AddChannel, AddExtremePointsChannel, @@ -348,7 +348,6 @@ ToTensorDict, ) from .utils import ( - apply_transform, copypaste_arrays, create_control_grid, create_grid, @@ -373,4 +372,5 @@ resize_center, weighted_patch_samples, zero_margins, + allow_missing_keys_mode, ) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index a9f66b12a0..8d49361d50 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,8 +18,7 @@ import numpy as np # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 -from monai.transforms.utils import apply_transform +from monai.transforms.transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 7a09efa6d5..02b3ee6c71 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,16 +13,41 @@ """ from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union import numpy as np from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple -__all__ = ["Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] + +def apply_transform(transform: Callable, data, map_items: bool = True): + """ + Transform `data` with `transform`. + If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed + and this method returns a list of outcomes. + otherwise transform will be applied once with `data` as the argument. + + Args: + transform: a callable to be used to transform `data` + data: an object to be transformed. + map_items: whether to apply transform to each item in `data`, + if `data` is a list or tuple. Defaults to True. + + Raises: + Exception: When ``transform`` raises an exception. + + """ + try: + if isinstance(data, (list, tuple)) and map_items: + return [transform(item) for item in data] + return transform(data) + except Exception as e: + raise RuntimeError(f"applying transform {transform}") from e + class Randomizable(ABC): """ An interface for handling random state locally, currently based on a class variable `R`, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9a84eb00d9..e9a3d3b303 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,7 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager import itertools +from monai.transforms.transform import MapTransform +from monai.transforms.compose import Compose import random import warnings from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -49,6 +52,7 @@ "get_extreme_points", "extreme_points_to_image", "map_spatial_axes", + "allow_missing_keys_mode", ] @@ -363,31 +367,6 @@ def _correct_centers( return centers -def apply_transform(transform: Callable, data, map_items: bool = True): - """ - Transform `data` with `transform`. - If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed - and this method returns a list of outcomes. - otherwise transform will be applied once with `data` as the argument. - - Args: - transform: a callable to be used to transform `data` - data: an object to be transformed. - map_items: whether to apply transform to each item in `data`, - if `data` is a list or tuple. Defaults to True. - - Raises: - Exception: When ``transform`` raises an exception. - - """ - try: - if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) - except Exception as e: - raise RuntimeError(f"applying transform {transform}") from e - - def create_grid( spatial_size: Sequence[int], spacing: Optional[Sequence[float]] = None, @@ -730,3 +709,41 @@ def map_spatial_axes( spatial_axes_.append(a - 1 if a < 0 else a) return spatial_axes_ + +@contextmanager +def allow_missing_keys_mode(transform: Union[MapTransform, Compose]): + """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. + + Args: + transform: either MapTransform or a Compose + + Example: + + .. code-block:: python + + data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False) + _ = t(data) # would raise exception + with allow_missing_keys_mode(t): + _ = t(data) # OK! + """ + if isinstance(transform, MapTransform): + transforms = [transform] + elif isinstance(transform, Compose): + # Only keep contained MapTransforms + transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)] + else: + transforms = [] + + # Get the state of each `allow_missing_keys` + orig_states = [t.allow_missing_keys for t in transforms] + + try: + # Set all to True + for t in transforms: + t.allow_missing_keys = True + yield + finally: + # Revert + for t, o_s in zip(transforms, orig_states): + t.allow_missing_keys = o_s diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py new file mode 100644 index 0000000000..93753ac1bb --- /dev/null +++ b/tests/test_with_allow_missing_keys.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +from monai.transforms import Compose, SpatialPadd, SpatialPad, allow_missing_keys_mode + +class TestWithAllowMissingKeysMode(unittest.TestCase): + def setUp(self): + self.data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)} + + def test_map_transform(self): + for amk in [True, False]: + t = SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) + with allow_missing_keys_mode(t): + # check state is True + self.assertTrue(t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check it has returned to original state + self.assertEqual(t.allow_missing_keys, amk) + if not amk: + # should fail because amks==False and key is missing + with self.assertRaises(KeyError): + _ = t(self.data) + + def test_compose(self): + amks = [True, False, True] + t = Compose([SpatialPadd(["image", "label"], 10, allow_missing_keys=amk) for amk in amks]) + with allow_missing_keys_mode(t): + # check states are all True + for _t in t.transforms: + self.assertTrue(_t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + # check they've returned to original state + for _t, amk in zip(t.transforms, amks): + self.assertEqual(_t.allow_missing_keys, amk) + # should fail because not all amks==True and key is missing + with self.assertRaises((KeyError, RuntimeError)): + _ = t(self.data) + + def test_array_transform(self): + for t in [SpatialPad(10), Compose([SpatialPad(10)])]: + with allow_missing_keys_mode(t): + # should work as nothing should have changed + _ = t(self.data["image"]) + +if __name__ == "__main__": + unittest.main() From d169253e3bbfd96e05318d64a8b1ff019990cacf Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 10:56:11 +0000 Subject: [PATCH 2/3] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 ++-- monai/transforms/compose.py | 8 +++++++- monai/transforms/transform.py | 4 ++-- monai/transforms/utils.py | 8 ++++---- tests/test_with_allow_missing_keys.py | 8 ++++++-- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 402ab819ba..796804df24 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -244,7 +244,7 @@ ZoomD, ZoomDict, ) -from .transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform +from .transform import MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddChannel, AddExtremePointsChannel, @@ -348,6 +348,7 @@ ToTensorDict, ) from .utils import ( + allow_missing_keys_mode, copypaste_arrays, create_control_grid, create_grid, @@ -372,5 +373,4 @@ resize_center, weighted_patch_samples, zero_margins, - allow_missing_keys_mode, ) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 8d49361d50..21e7da068c 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -18,7 +18,13 @@ import numpy as np # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) -from monai.transforms.transform import apply_transform, MapTransform, Randomizable, RandomizableTransform, Transform # noqa: F401 +from monai.transforms.transform import ( # noqa: F401 + MapTransform, + Randomizable, + RandomizableTransform, + Transform, + apply_transform, +) from monai.utils import MAX_SEED, ensure_tuple, get_seed __all__ = ["Compose"] diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 02b3ee6c71..2a79b2edf2 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -13,7 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np @@ -23,7 +23,6 @@ __all__ = ["apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] - def apply_transform(transform: Callable, data, map_items: bool = True): """ Transform `data` with `transform`. @@ -48,6 +47,7 @@ def apply_transform(transform: Callable, data, map_items: bool = True): except Exception as e: raise RuntimeError(f"applying transform {transform}") from e + class Randomizable(ABC): """ An interface for handling random state locally, currently based on a class variable `R`, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e9a3d3b303..4ade9a4e9f 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -9,12 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager import itertools -from monai.transforms.transform import MapTransform -from monai.transforms.compose import Compose import random import warnings +from contextlib import contextmanager from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -22,6 +20,8 @@ from monai.config import DtypeLike, IndexSelection from monai.networks.layers import GaussianFilter +from monai.transforms.compose import Compose +from monai.transforms.transform import MapTransform from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -40,7 +40,6 @@ "map_binary_to_indices", "weighted_patch_samples", "generate_pos_neg_label_crop_centers", - "apply_transform", "create_grid", "create_control_grid", "create_rotate", @@ -710,6 +709,7 @@ def map_spatial_axes( return spatial_axes_ + @contextmanager def allow_missing_keys_mode(transform: Union[MapTransform, Compose]): """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index 93753ac1bb..c644d1d9c4 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import unittest -from monai.transforms import Compose, SpatialPadd, SpatialPad, allow_missing_keys_mode + +import numpy as np + +from monai.transforms import Compose, SpatialPad, SpatialPadd, allow_missing_keys_mode + class TestWithAllowMissingKeysMode(unittest.TestCase): def setUp(self): @@ -54,5 +57,6 @@ def test_array_transform(self): # should work as nothing should have changed _ = t(self.data["image"]) + if __name__ == "__main__": unittest.main() From 5015c9b67dfaec04a1e7ec3a95a05a61e81258c6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 8 Mar 2021 11:33:41 +0000 Subject: [PATCH 3/3] allow iterable of transforms Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utils.py | 15 ++++++++++++--- tests/test_with_allow_missing_keys.py | 17 ++++++++++++++--- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 4ade9a4e9f..eb1b194c96 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -23,6 +23,7 @@ from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import +from monai.utils.misc import issequenceiterable measure, _ = optional_import("skimage.measure", "0.14.2", min_version) @@ -711,7 +712,7 @@ def map_spatial_axes( @contextmanager -def allow_missing_keys_mode(transform: Union[MapTransform, Compose]): +def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]): """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states. Args: @@ -727,13 +728,21 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose]): with allow_missing_keys_mode(t): _ = t(data) # OK! """ + # If given a sequence of transforms, Compose them to get a single list + if issequenceiterable(transform): + transform = Compose(transform) + + # Get list of MapTransforms + transforms = [] if isinstance(transform, MapTransform): transforms = [transform] elif isinstance(transform, Compose): # Only keep contained MapTransforms transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)] - else: - transforms = [] + if len(transforms) == 0: + raise TypeError( + "allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)" + ) # Get the state of each `allow_missing_keys` orig_states = [t.allow_missing_keys for t in transforms] diff --git a/tests/test_with_allow_missing_keys.py b/tests/test_with_allow_missing_keys.py index c644d1d9c4..68c5ad30c4 100644 --- a/tests/test_with_allow_missing_keys.py +++ b/tests/test_with_allow_missing_keys.py @@ -53,9 +53,20 @@ def test_compose(self): def test_array_transform(self): for t in [SpatialPad(10), Compose([SpatialPad(10)])]: - with allow_missing_keys_mode(t): - # should work as nothing should have changed - _ = t(self.data["image"]) + with self.assertRaises(TypeError): + with allow_missing_keys_mode(t): + pass + + def test_multiple(self): + orig_states = [True, False] + ts = [SpatialPadd(["image", "label"], 10, allow_missing_keys=i) for i in orig_states] + with allow_missing_keys_mode(ts): + for t in ts: + self.assertTrue(t.allow_missing_keys) + # and that transform works even though key is missing + _ = t(self.data) + for t, o_s in zip(ts, orig_states): + self.assertEqual(t.allow_missing_keys, o_s) if __name__ == "__main__":