diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 3bc8d0899a..dcdeab1ac8 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -38,6 +38,12 @@ Generic Interfaces :members: :special-members: __call__ +`InvertibleTransform` +^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: InvertibleTransform + :members: + + Vanilla Transforms ------------------ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 796804df24..5b12da4d21 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -138,6 +138,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) +from .inverse import InvertibleTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 21e7da068c..d509ea33a1 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -17,6 +17,8 @@ import numpy as np +from monai.transforms.inverse import InvertibleTransform + # For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import ( # noqa: F401 MapTransform, @@ -30,7 +32,7 @@ __all__ = ["Compose"] -class Compose(RandomizableTransform): +class Compose(RandomizableTransform, InvertibleTransform): """ ``Compose`` provides the ability to chain a series of calls together in a sequence. 
# --- monai/transforms/compose.py: method added to ``Compose`` ---
def inverse(self, data):
    """
    Apply the inverses of the composed transforms in reverse order.

    ``self.flatten()`` is called first (presumably to unroll nested ``Compose``
    objects - confirm against ``flatten``); only the resulting transforms that
    are instances of ``InvertibleTransform`` take part, all others are skipped.

    Args:
        data: the transformed data. It is handed to each transform's ``inverse``
            through ``apply_transform``.

    Returns:
        The result of the last inverse applied. When there are no invertible
        transforms a warning is issued and ``data`` is returned unchanged.
    """
    invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]
    if not invertible_transforms:
        warnings.warn("inverse has been called but no invertible transforms have been supplied")

    # Inverses are applied last-in-first-out, so walk the chain backwards.
    for t in reversed(invertible_transforms):
        data = apply_transform(t.inverse, data)
    return data
# --- monai/transforms/croppad/dictionary.py: methods of ``SpatialPadd`` ---
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
    """
    Pad every selected entry of ``data`` and record the operation for later inversion.

    Args:
        data: mapping from key to array. A shallow ``dict`` copy of it is padded
            and returned.

    Returns:
        The copied dictionary holding the padded arrays.
    """
    d = dict(data)
    for key, m in self.key_iterator(d, self.mode):
        # Record before padding: with no explicit ``orig_size``, ``push_transform``
        # stores ``d[key].shape[1:]``, which must be the pre-pad size.
        self.push_transform(d, key)
        d[key] = self.padder(d[key], mode=m)
    return d


def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
    """
    Crop every selected entry back to the size recorded before it was padded.

    Args:
        data: dictionary produced by ``__call__``. For each selected key the most
            recent record stored by ``push_transform`` must belong to this instance.

    Returns:
        A deep copy of ``data`` in which each selected array has been cropped to
        its recorded original size and this transform's record has been removed.

    Raises:
        RuntimeError: raised by ``get_most_recent_transform`` when the most recent
            record for a key was not created by this instance.
    """
    d = deepcopy(dict(data))
    for key in self.key_iterator(d):
        transform = self.get_most_recent_transform(d, key)
        orig_size = transform[InverseKeys.ORIG_SIZE.value]
        if self.padder.method == Method.SYMMETRIC:
            # The crop centre is derived from the padded size. For an odd original
            # length it is ``(i - 1) // 2``, i.e. one lower than ``i // 2`` when the
            # padded length is even.
            # NOTE(review): presumably this mirrors how the symmetric pad splits an
            # odd amount of padding - confirm against ``SpatialPad``.
            current_size = d[key].shape[1:]
            roi_center = [i // 2 if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)]
        else:
            # Both arms of the former even/odd conditional reduce to ``r // 2``.
            # NOTE(review): assumes the non-symmetric method leaves the original data
            # at the start of each axis - confirm against ``SpatialPad``.
            roi_center = [r // 2 for r in orig_size]

        inverse_transform = SpatialCrop(roi_center, orig_size)
        d[key] = inverse_transform(d[key])
        # Drop the record only after the inverse has been applied successfully.
        self.pop_transform(d, key)

    return d
# --- monai/transforms/inverse.py (new module; license header precedes this) ---
from typing import Dict, Hashable, Optional, Tuple

import numpy as np

from monai.transforms.transform import RandomizableTransform, Transform
from monai.utils.enums import InverseKeys

__all__ = ["InvertibleTransform"]


class InvertibleTransform(Transform):
    """Base class for invertible transforms.

    This class exists so that an ``inverse`` method can be implemented. This allows, for
    example, images to be cropped, rotated, padded, etc., during training and inference,
    and afterwards be returned to their original size before saving to file for
    comparison in an external viewer.

    When the ``__call__`` method is called, the transformation information for each key is
    stored. If the transforms were applied to keys "image" and "label", there will be two
    extra keys in the dictionary: "image_transforms" and "label_transforms". Each of these
    holds a list of the transforms applied to that key. When the ``inverse`` method is
    called, the inverse is called on each key individually, which allows for different
    parameters being passed to each label (e.g., different interpolation for image and
    label).

    When the ``inverse`` method is called, the inverse transforms are applied in a last-
    in-first-out order. As the inverse is applied, its entry is removed from the list
    detailing the applied transformations. That is to say that during the forward pass,
    the list of applied transforms grows, and then during the inverse it shrinks back
    down to an empty list.

    The information in ``data[key_transform]`` will be compatible with the default collate
    since it only stores strings, numbers and arrays.

    We currently check that the ``id()`` of the transform is the same in the forward and
    inverse directions. This is a useful check to ensure that the inverses are being
    processed in the correct order. However, this may cause issues if the ``id()`` of the
    object changes (such as multiprocessing on Windows). If you feel this issue affects
    you, please raise a GitHub issue.

    Note to developers: When converting a transform to an invertible transform, you need to:

        #. Inherit from this class.
        #. In ``__call__``, add a call to ``push_transform``.
        #. Any extra information that might be needed for the inverse can be included with the
           dictionary ``extra_info``. This dictionary should have the same keys regardless of
           whether ``do_transform`` was `True` or `False` and can only contain objects that are
           accepted in pytorch data loader's collate function (e.g., `None` is not allowed).
        #. Implement an ``inverse`` method. Make sure that after performing the inverse,
           ``pop_transform`` is called.

    """

    def push_transform(
        self,
        data: dict,
        key: Hashable,
        extra_info: Optional[dict] = None,
        orig_size: Optional[Tuple] = None,
    ) -> None:
        """Append a record of this transform to the list of transforms applied to ``key``.

        Args:
            data: dictionary being transformed; it is modified in place.
            key: key of the entry this transform is being applied to.
            extra_info: optional extra information for the inverse, stored under
                ``InverseKeys.EXTRA_INFO`` when given.
            orig_size: size to record as the pre-transform size. When ``None``,
                ``data[key].shape[1:]`` is recorded instead.
        """
        key_transform = str(key) + InverseKeys.KEY_SUFFIX.value
        # Fall back to the array's shape only when no size was supplied; a truthiness
        # test would also throw away an explicitly passed empty size.
        if orig_size is None:
            orig_size = data[key].shape[1:]
        info = {
            InverseKeys.CLASS_NAME.value: self.__class__.__name__,
            InverseKeys.ID.value: id(self),
            InverseKeys.ORIG_SIZE.value: orig_size,
        }
        if extra_info is not None:
            info[InverseKeys.EXTRA_INFO.value] = extra_info
        # If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
        if isinstance(self, RandomizableTransform):
            info[InverseKeys.DO_TRANSFORM.value] = self._do_transform
        # Creates the list on first use, then appends.
        data.setdefault(key_transform, []).append(info)

    def check_transforms_match(self, transform: dict) -> None:
        """Check that ``transform`` was recorded by this very instance.

        Args:
            transform: a record created by ``push_transform``.

        Raises:
            RuntimeError: when the recorded ``id`` differs from ``id(self)``.
        """
        if transform[InverseKeys.ID.value] != id(self):
            raise RuntimeError("Should inverse most recently applied invertible transform first")

    def get_most_recent_transform(self, data: dict, key: Hashable) -> dict:
        """Return a shallow copy of the most recent record stored for ``key``.

        Args:
            data: dictionary holding the list of records for ``key``.
            key: key whose most recent record is wanted.

        Raises:
            RuntimeError: when that record was not created by this instance.
        """
        transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1])
        self.check_transforms_match(transform)
        return transform

    def pop_transform(self, data: dict, key: Hashable) -> None:
        """Remove the most recent record stored for ``key`` (modifies ``data`` in place)."""
        data[str(key) + InverseKeys.KEY_SUFFIX.value].pop()

    def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:
        """
        Inverse of ``__call__``.

        Raises:
            NotImplementedError: When the subclass does not override this method.

        """
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


# --- monai/utils/enums.py: class appended to the existing module, which already uses ``Enum`` ---
class InverseKeys(Enum):
    """Extra meta data keys used for inverse transforms."""

    CLASS_NAME = "class"
    ID = "id"
    ORIG_SIZE = "orig_size"
    EXTRA_INFO = "extra_info"
    DO_TRANSFORM = "do_transforms"
    KEY_SUFFIX = "_transforms"
# --- tests/test_decollate.py: helper method of the decollate test case ---
def check_match(self, in1, in2):
    """
    Recursively assert that ``in1`` and ``in2`` hold matching content.

    Dictionaries are compared item by item in iteration order, lists and tuples
    element by element, strings and ints with ``assertEqual`` and tensors/arrays
    with ``np.testing.assert_array_equal``.

    Note that ``zip`` stops at the shorter input, so a difference in length
    between two dictionaries or sequences is not detected here.

    Raises:
        RuntimeError: when ``in1`` is of a type this helper cannot compare.
    """
    if isinstance(in1, dict):
        self.assertIsInstance(in2, dict)
        for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()):
            # Enum keys are compared through their values.
            if isinstance(k1, Enum) and isinstance(k2, Enum):
                k1, k2 = k1.value, k2.value
            self.check_match(k1, k2)
            # Transform ids are not compared on macOS ("darwin") or Windows ("win32"),
            # where they won't match when multiprocessing is used.
            if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]:
                continue
            self.check_match(v1, v2)
    elif isinstance(in1, (list, tuple)):
        for l1, l2 in zip(in1, in2):
            self.check_match(l1, l2)
    elif isinstance(in1, (str, int)):
        self.assertEqual(in1, in2)
    elif isinstance(in1, (torch.Tensor, np.ndarray)):
        np.testing.assert_array_equal(in1, in2)
    else:
        raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}")
# --- tests/test_inverse.py (new module; license header precedes this) ---
import sys
import unittest
from typing import TYPE_CHECKING, List, Tuple

import numpy as np
import torch
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d
from monai.data.utils import decollate_batch
from monai.networks.nets import UNet
from monai.transforms import (
    AddChannel,
    AddChanneld,
    Compose,
    InvertibleTransform,
    LoadImaged,
    ResizeWithPadOrCrop,
    SpatialPadd,
    allow_missing_keys_mode,
)
from monai.utils import first, optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image, make_rand_affine

if TYPE_CHECKING:

    has_nib = True
else:
    _, has_nib = optional_import("nibabel")

KEYS = ["image", "label"]

# Each case is (name, key into ``all_data``, acceptable mean difference, *transforms).
TESTS: List[Tuple] = []

TESTS.append(
    (
        "SpatialPadd (x2) 2d",
        "2D",
        0.0,
        SpatialPadd(KEYS, spatial_size=[111, 113], method="end"),
        SpatialPadd(KEYS, spatial_size=[118, 117]),
    )
)

TESTS.append(
    (
        "SpatialPadd 3d",
        "3D",
        0.0,
        SpatialPadd(KEYS, spatial_size=[112, 113, 116]),
    )
)

# Repeat every case with its transforms wrapped in a doubly nested ``Compose``.
TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2  # type: ignore


class TestInverse(unittest.TestCase):
    """Test inverse methods.

    If tests are failing, the following function might be useful for displaying
    `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`.

    .. code-block:: python

        def plot_im(orig, fwd_bck, fwd):
            import matplotlib.pyplot as plt
            diff_orig_fwd_bck = orig - fwd_bck
            ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck]
            titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"]
            fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]})
            vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd])
            vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd])
            for im, title, ax in zip(ims_to_show, titles, axes):
                _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None)
                im = np.squeeze(np.array(im))
                while im.ndim > 2:
                    im = im[..., im.shape[-1] // 2]
                im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax)
                ax.set_title(title, fontsize=25)
                ax.axis("off")
                fig.colorbar(im_show, ax=ax)
            plt.show()

    This can then be added to the exception:

    .. code-block:: python

        except AssertionError:
            print(
                f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}"
            )
            if orig[0].ndim > 1:
                plot_im(orig, fwd_bck, unmodified)
    """

    def setUp(self):
        """Build the 1D/2D/3D input dictionaries shared by the tests."""
        if not has_nib:
            self.skipTest("nibabel required for test_inverse")

        set_determinism(seed=0)

        self.all_data = {}

        affine = make_rand_affine()
        affine[0] *= 2

        im_1d = AddChannel()(np.arange(0, 10))
        self.all_data["1D"] = {"image": im_1d, "label": im_1d, "other": im_1d}

        im_2d_fname, seg_2d_fname = [make_nifti_image(i) for i in create_test_image_2d(101, 100)]
        im_3d_fname, seg_3d_fname = [make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)]

        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)])
        self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname})
        self.all_data["3D"] = load_ims({"image": im_3d_fname, "label": seg_3d_fname})

    def tearDown(self):
        """Undo the seeding done in ``setUp``."""
        set_determinism(seed=None)

    def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff):
        """
        Assert that, for every array entry, the inverted data is close to the original.

        Args:
            name: test case name, used only in the failure message.
            keys: keys to compare; entries of ``orig_d`` that are not numpy arrays are skipped.
            orig_d: data before the transform under test was applied.
            fwd_bck_d: data after applying the transform and then its inverse.
            unmodified_d: fully transformed data, used only for the failure diagnostic.
            acceptable_diff: largest accepted mean absolute difference.
        """
        for key in keys:
            orig = orig_d[key]
            if not isinstance(orig, np.ndarray):
                continue
            fwd_bck = fwd_bck_d[key]
            if isinstance(fwd_bck, torch.Tensor):
                fwd_bck = fwd_bck.cpu().numpy()
            mean_diff = np.mean(np.abs(orig - fwd_bck))
            try:
                self.assertLessEqual(mean_diff, acceptable_diff)
            except AssertionError:
                # The comparison against the un-inverted data is purely diagnostic,
                # so it is only computed once the check has already failed.
                unmodified = unmodified_d[key]
                unmodded_diff = np.mean(np.abs(orig - ResizeWithPadOrCrop(orig.shape[1:])(unmodified)))
                print(
                    f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}"
                )
                raise

    @parameterized.expand(TESTS)
    def test_inverse(self, name, data_name, acceptable_diff, *transforms):
        """Apply the transforms forwards, then invert them one by one in reverse order."""
        data = self.all_data[data_name]

        forwards = [data.copy()]

        # Apply forwards
        for t in transforms:
            forwards.append(t(forwards[-1]))

        # Check that an error is thrown when inverses are used out of order:
        # this instance never recorded itself in the data.
        unapplied = SpatialPadd("image", [10, 5])
        with self.assertRaises(RuntimeError):
            unapplied.inverse(forwards[-1])

        # Apply inverses
        fwd_bck = forwards[-1].copy()
        for i, t in enumerate(reversed(transforms)):
            if isinstance(t, InvertibleTransform):
                fwd_bck = t.inverse(fwd_bck)
                # forwards[-i - 2] is the data as it was just before ``t`` was applied.
                self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff)

    def test_inverse_inferred_seg(self):
        """Invert a network output using the transform records of its input label."""
        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153))])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2,),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
        # The segmentation borrows the label's transform records so it can be inverted.
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}
        segs_dict_decollated = decollate_batch(segs_dict)

        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data[label_transform_key]), num_invertible_transforms)
        self.assertEqual(len(seg_dict[label_transform_key]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)


if __name__ == "__main__":
    unittest.main()