Skip to content
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ Generic Interfaces
:members:
:special-members: __call__

`InvertibleTransform`
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: InvertibleTransform
:members:


Vanilla Transforms
------------------

Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
ThresholdIntensityD,
ThresholdIntensityDict,
)
from .inverse import InvertibleTransform
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .post.array import (
Expand Down
14 changes: 13 additions & 1 deletion monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import numpy as np

from monai.transforms.inverse import InvertibleTransform

# For backwards compatiblity (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import ( # noqa: F401
MapTransform,
Expand All @@ -30,7 +32,7 @@
__all__ = ["Compose"]


class Compose(RandomizableTransform):
class Compose(RandomizableTransform, InvertibleTransform):
"""
``Compose`` provides the ability to chain a series of calls together in a
sequence. Each transform in the sequence must take a single argument and
Expand Down Expand Up @@ -141,3 +143,13 @@ def __call__(self, input_):
for _transform in self.transforms:
input_ = apply_transform(_transform, input_)
return input_

def inverse(self, data):
invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]
if len(invertible_transforms) == 0:
warnings.warn("inverse has been called but no invertible transforms have been supplied")

# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(t.inverse, data)
return data
27 changes: 26 additions & 1 deletion monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from copy import deepcopy
from math import floor
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
Expand All @@ -30,6 +32,7 @@
SpatialCrop,
SpatialPad,
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
Expand All @@ -38,6 +41,7 @@
weighted_patch_samples,
)
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import InverseKeys

__all__ = [
"NumpyPadModeSequence",
Expand Down Expand Up @@ -82,7 +86,7 @@
NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str]


class SpatialPadd(MapTransform):
class SpatialPadd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`.
Performs padding to the data, symmetric for all sides or all on one side for each dimension.
Expand Down Expand Up @@ -119,9 +123,30 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key)
d[key] = self.padder(d[key], mode=m)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = transform[InverseKeys.ORIG_SIZE.value]
if self.padder.method == Method.SYMMETRIC:
current_size = d[key].shape[1:]
roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)]
else:
roi_center = [floor(r / 2) if r % 2 == 0 else (r - 1) // 2 for r in orig_size]

inverse_transform = SpatialCrop(roi_center, orig_size)
# Apply inverse transform
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class BorderPadd(MapTransform):
"""
Expand Down
113 changes: 113 additions & 0 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Hashable, Optional, Tuple

import numpy as np

from monai.transforms.transform import RandomizableTransform, Transform
from monai.utils.enums import InverseKeys

__all__ = ["InvertibleTransform"]


class InvertibleTransform(Transform):
"""Classes for invertible transforms.

This class exists so that an ``invert`` method can be implemented. This allows, for
example, images to be cropped, rotated, padded, etc., during training and inference,
and after be returned to their original size before saving to file for comparison in
an external viewer.

When the ``__call__`` method is called, the transformation information for each key is
stored. If the transforms were applied to keys "image" and "label", there will be two
extra keys in the dictionary: "image_transforms" and "label_transforms". Each list
contains a list of the transforms applied to that key. When the ``inverse`` method is
called, the inverse is called on each key individually, which allows for different
parameters being passed to each label (e.g., different interpolation for image and
label).

When the ``inverse`` method is called, the inverse transforms are applied in a last-
in-first-out order. As the inverse is applied, its entry is removed from the list
detailing the applied transformations. That is to say that during the forward pass,
the list of applied transforms grows, and then during the inverse it shrinks back
down to an empty list.

The information in ``data[key_transform]`` will be compatible with the default collate
since it only stores strings, numbers and arrays.

We currently check that the ``id()`` of the transform is the same in the forward and
inverse directions. This is a useful check to ensure that the inverses are being
processed in the correct order. However, this may cause issues if the ``id()`` of the
object changes (such as multiprocessing on Windows). If you feel this issue affects
you, please raise a GitHub issue.

Note to developers: When converting a transform to an invertible transform, you need to:

#. Inherit from this class.
#. In ``__call__``, add a call to ``push_transform``.
#. Any extra information that might be needed for the inverse can be included with the
dictionary ``extra_info``. This dictionary should have the same keys regardless of
whether ``do_transform`` was `True` or `False` and can only contain objects that are
accepted in pytorch data loader's collate function (e.g., `None` is not allowed).
#. Implement an ``inverse`` method. Make sure that after performing the inverse,
``pop_transform`` is called.

"""

def push_transform(
self,
data: dict,
key: Hashable,
extra_info: Optional[dict] = None,
orig_size: Optional[Tuple] = None,
) -> None:
"""Append to list of applied transforms for that key."""
key_transform = str(key) + InverseKeys.KEY_SUFFIX.value
info = {
InverseKeys.CLASS_NAME.value: self.__class__.__name__,
InverseKeys.ID.value: id(self),
InverseKeys.ORIG_SIZE.value: orig_size or data[key].shape[1:],
}
if extra_info is not None:
info[InverseKeys.EXTRA_INFO.value] = extra_info
# If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
if isinstance(self, RandomizableTransform):
info[InverseKeys.DO_TRANSFORM.value] = self._do_transform
# If this is the first, create list
if key_transform not in data:
data[key_transform] = []
data[key_transform].append(info)

def check_transforms_match(self, transform: dict) -> None:
"""Check transforms are of same instance."""
if transform[InverseKeys.ID.value] != id(self):
raise RuntimeError("Should inverse most recently applied invertible transform first")

def get_most_recent_transform(self, data: dict, key: Hashable) -> dict:
"""Get most recent transform."""
transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX.value][-1])
self.check_transforms_match(transform)
return transform

def pop_transform(self, data: dict, key: Hashable) -> None:
"""Remove most recent transform."""
data[str(key) + InverseKeys.KEY_SUFFIX.value].pop()
Comment thread
wyli marked this conversation as resolved.

def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:
"""
Inverse of ``__call__``.

Raises:
NotImplementedError: When the subclass does not override this method.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GridSampleMode,
GridSamplePadMode,
InterpolateMode,
InverseKeys,
LossReduction,
Method,
MetricReduction,
Expand Down
12 changes: 12 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"ChannelMatching",
"SkipMode",
"Method",
"InverseKeys",
]


Expand Down Expand Up @@ -214,3 +215,14 @@ class Method(Enum):

SYMMETRIC = "symmetric"
END = "end"


class InverseKeys(Enum):
"""Extra meta data keys used for inverse transforms."""

CLASS_NAME = "class"
ID = "id"
ORIG_SIZE = "orig_size"
EXTRA_INFO = "extra_info"
DO_TRANSFORM = "do_transforms"
KEY_SUFFIX = "_transforms"
19 changes: 14 additions & 5 deletions tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from enum import Enum

import numpy as np
import torch
Expand All @@ -20,6 +22,7 @@
from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord
from monai.transforms.post.dictionary import Decollated
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")
Expand All @@ -46,14 +49,20 @@ def tearDown(self) -> None:
def check_match(self, in1, in2):
if isinstance(in1, dict):
self.assertTrue(isinstance(in2, dict))
self.check_match(list(in1.keys()), list(in2.keys()))
self.check_match(list(in1.values()), list(in2.values()))
elif any(isinstance(in1, i) for i in [list, tuple]):
for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()):
if isinstance(k1, Enum) and isinstance(k2, Enum):
k1, k2 = k1.value, k2.value
self.check_match(k1, k2)
# Transform ids won't match for windows with multiprocessing, so don't check values
if k1 == InverseKeys.ID.value and sys.platform in ["darwin", "win32"]:
continue
self.check_match(v1, v2)
elif isinstance(in1, (list, tuple)):
for l1, l2 in zip(in1, in2):
self.check_match(l1, l2)
elif any(isinstance(in1, i) for i in [str, int]):
elif isinstance(in1, (str, int)):
self.assertEqual(in1, in2)
elif any(isinstance(in1, i) for i in [torch.Tensor, np.ndarray]):
elif isinstance(in1, (torch.Tensor, np.ndarray)):
np.testing.assert_array_equal(in1, in2)
else:
raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}")
Expand Down
Loading