diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2867361b8e..33b8da3ebb 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -281,7 +281,7 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine @@ -590,7 +590,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ @@ -620,7 +620,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.max_k = max_k self.spatial_axes = spatial_axes @@ -758,7 +758,7 @@ class RandFlip(RandomizableTransform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, img: np.ndarray) -> np.ndarray: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d9d38242fb..170006ed2b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,6 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -23,6 +24,7 @@ from monai.config import DtypeLike, KeysCollection from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, Flip, @@ -47,6 +49,10 @@ ensure_tuple_rep, fall_back_tuple, ) +from monai.utils.enums import InverseKeys +from monai.utils.module import optional_import + +nib, _ = optional_import("nibabel") __all__ = [ "Spacingd", @@ -204,7 +210,7 @@ def __call__( return d -class Orientationd(MapTransform): +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. @@ -259,13 +265,36 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.key_iterator(d): - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + d[meta_data_key]["affine"] = new_affine + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=False, + labels=self.ornt_transform.labels, + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine + # Remove the applied transform + self.pop_transform(d, key) + return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -286,11 +315,31 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.rotator(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Create inverse transform + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class RandRotate90d(RandomizableTransform, MapTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -337,6 +386,27 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Create inverse transform + num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + return d @@ -789,7 +859,7 @@ def __call__( return d -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -814,11 +884,26 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.flipper(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) -class RandFlipd(RandomizableTransform, MapTransform): + return d + + +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -851,10 +936,26 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) + self.push_transform(d, key) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d -class RandAxisFlipd(RandomizableTransform, MapTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. @@ -885,6 +986,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) + self.push_transform(d, key, extra_info={"axis": self._axis}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"]) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index bb2d997eb5..0c29ea7b08 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -28,12 +28,18 @@ Compose, CropForegroundd, DivisiblePadd, + Flipd, InvertibleTransform, LoadImaged, + Orientationd, + RandAxisFlipd, + RandFlipd, Randomizable, + RandRotate90d, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, + Rotate90d, SpatialCropd, SpatialPadd, allow_missing_keys_mode, @@ -207,6 +213,79 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + "3D", + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "RandAxisFlipd 3d", + "3D", + 0, + RandAxisFlipd(KEYS, 1), + ) +) + +for acc in [True, False]: + TESTS.append( + ( + "Orientationd 3d", + "3D", + 0, + Orientationd(KEYS, "RAS", as_closest_canonical=acc), + ) + ) + +TESTS.append( + ( + "Rotate90d 2d", + "2D", + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + "3D", + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + "3D", + 0, + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore