Skip to content
8 changes: 4 additions & 4 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __call__(
ornt[:, 0] += 1 # skip channel dim
ornt = np.concatenate([np.array([[0, 1]]), ornt])
shape = data_array.shape[1:]
data_array = nib.orientations.apply_orientation(data_array, ornt)
data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt))
new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape)
new_affine = to_affine_nd(affine, new_affine)
return data_array, affine, new_affine
Expand Down Expand Up @@ -590,7 +590,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None:
If axis is negative it counts from the last to the first axis.
"""
self.k = k
spatial_axes_ = ensure_tuple(spatial_axes)
spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore
if len(spatial_axes_) != 2:
raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.")
self.spatial_axes = spatial_axes_
Expand Down Expand Up @@ -620,7 +620,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i
spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
Default: (0, 1), this is the first two axis in spatial dimensions.
"""
RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0))
RandomizableTransform.__init__(self, prob)
self.max_k = max_k
self.spatial_axes = spatial_axes

Expand Down Expand Up @@ -758,7 +758,7 @@ class RandFlip(RandomizableTransform):
"""

def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0))
RandomizableTransform.__init__(self, prob)
self.flipper = Flip(spatial_axis=spatial_axis)

def __call__(self, img: np.ndarray) -> np.ndarray:
Expand Down
134 changes: 126 additions & 8 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from copy import deepcopy
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
Expand All @@ -23,6 +24,7 @@
from monai.config import DtypeLike, KeysCollection
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
Flip,
Expand All @@ -47,6 +49,10 @@
ensure_tuple_rep,
fall_back_tuple,
)
from monai.utils.enums import InverseKeys
from monai.utils.module import optional_import

nib, _ = optional_import("nibabel")

__all__ = [
"Spacingd",
Expand Down Expand Up @@ -204,7 +210,7 @@ def __call__(
return d


class Orientationd(MapTransform):
class Orientationd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`.

Expand Down Expand Up @@ -259,13 +265,36 @@ def __call__(
) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]:
d: Dict = dict(data)
for key in self.key_iterator(d):
meta_data = d[f"{key}_{self.meta_key_postfix}"]
d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"])
meta_data_key = f"{key}_{self.meta_key_postfix}"
meta_data = d[meta_data_key]
d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"])
self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine})
d[meta_data_key]["affine"] = new_affine
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]]
orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"]
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
inverse_transform = Orientation(
axcodes=orig_axcodes,
as_closest_canonical=False,
labels=self.ornt_transform.labels,
)
# Apply inverse
d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"])
meta_data["affine"] = new_affine
# Remove the applied transform
self.pop_transform(d, key)

return d


class Rotate90d(MapTransform):
class Rotate90d(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`.
"""
Expand All @@ -286,11 +315,31 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.rotator(d[key])
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
_ = self.get_most_recent_transform(d, key)
# Create inverse transform
spatial_axes = self.rotator.spatial_axes
Comment thread
wyli marked this conversation as resolved.
num_times_rotated = self.rotator.k
num_times_to_rotate = 4 - num_times_rotated
inverse_transform = Rotate90(num_times_to_rotate, spatial_axes)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class RandRotate90d(RandomizableTransform, MapTransform):
class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandRotate90`.
With probability `prob`, input arrays are rotated by 90 degrees
Expand Down Expand Up @@ -337,6 +386,27 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np.
for key in self.key_iterator(d):
if self._do_transform:
d[key] = rotator(d[key])
self.push_transform(d, key, extra_info={"rand_k": self._rand_k})
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
# Create inverse transform
num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"]
num_times_to_rotate = 4 - num_times_rotated
Comment thread
wyli marked this conversation as resolved.
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


Expand Down Expand Up @@ -789,7 +859,7 @@ def __call__(
return d


class Flipd(MapTransform):
class Flipd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Flip`.

Expand All @@ -814,11 +884,26 @@ def __init__(
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.flipper(d[key])
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
_ = self.get_most_recent_transform(d, key)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = self.flipper(d[key])
# Remove the applied transform
self.pop_transform(d, key)

class RandFlipd(RandomizableTransform, MapTransform):
return d


class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandFlip`.

Expand Down Expand Up @@ -851,10 +936,26 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
for key in self.key_iterator(d):
if self._do_transform:
d[key] = self.flipper(d[key])
self.push_transform(d, key)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = self.flipper(d[key])
# Remove the applied transform
self.pop_transform(d, key)
return d


class RandAxisFlipd(RandomizableTransform, MapTransform):
class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`.

Expand Down Expand Up @@ -885,6 +986,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
for key in self.key_iterator(d):
if self._do_transform:
d[key] = flipper(d[key])
self.push_transform(d, key, extra_info={"axis": self._axis})
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"])
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = flipper(d[key])
# Remove the applied transform
self.pop_transform(d, key)
return d


Expand Down
79 changes: 79 additions & 0 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@
Compose,
CropForegroundd,
DivisiblePadd,
Flipd,
InvertibleTransform,
LoadImaged,
Orientationd,
RandAxisFlipd,
RandFlipd,
Randomizable,
RandRotate90d,
RandSpatialCropd,
ResizeWithPadOrCrop,
ResizeWithPadOrCropd,
Rotate90d,
SpatialCropd,
SpatialPadd,
allow_missing_keys_mode,
Expand Down Expand Up @@ -207,6 +213,79 @@

TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105])))

TESTS.append(
(
"Flipd 3d",
"3D",
0,
Flipd(KEYS, [1, 2]),
)
)

TESTS.append(
(
"Flipd 3d",
"3D",
0,
Flipd(KEYS, [1, 2]),
)
)

TESTS.append(
(
"RandFlipd 3d",
"3D",
0,
RandFlipd(KEYS, 1, [1, 2]),
)
)

TESTS.append(
(
"RandAxisFlipd 3d",
"3D",
0,
RandAxisFlipd(KEYS, 1),
)
)

for acc in [True, False]:
TESTS.append(
(
"Orientationd 3d",
"3D",
0,
Orientationd(KEYS, "RAS", as_closest_canonical=acc),
)
)

TESTS.append(
(
"Rotate90d 2d",
"2D",
0,
Rotate90d(KEYS),
)
)

TESTS.append(
(
"Rotate90d 3d",
"3D",
0,
Rotate90d(KEYS, k=2, spatial_axes=(1, 2)),
)
)

TESTS.append(
(
"RandRotate90d 3d",
"3D",
0,
RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)),
)
)

TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore
Expand Down