From ea31dd771beccccac660d6453dd4204c2dde7b00 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 12:30:06 +0000 Subject: [PATCH 1/4] lossless inverse Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/array.py | 8 +- monai/transforms/spatial/dictionary.py | 135 +++++++++++++++++++++++-- tests/test_inverse.py | 102 +++++++++++++++++++ 3 files changed, 232 insertions(+), 13 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2867361b8e..33b8da3ebb 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -281,7 +281,7 @@ def __call__( ornt[:, 0] += 1 # skip channel dim ornt = np.concatenate([np.array([[0, 1]]), ornt]) shape = data_array.shape[1:] - data_array = nib.orientations.apply_orientation(data_array, ornt) + data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt)) new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape) new_affine = to_affine_nd(affine, new_affine) return data_array, affine, new_affine @@ -590,7 +590,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: If axis is negative it counts from the last to the first axis. """ self.k = k - spatial_axes_ = ensure_tuple(spatial_axes) + spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ @@ -620,7 +620,7 @@ def __init__(self, prob: float = 0.1, max_k: int = 3, spatial_axes: Tuple[int, i spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes. Default: (0, 1), this is the first two axis in spatial dimensions. """ - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.max_k = max_k self.spatial_axes = spatial_axes @@ -758,7 +758,7 @@ class RandFlip(RandomizableTransform): """ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: - RandomizableTransform.__init__(self, min(max(prob, 0.0), 1.0)) + RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) def __call__(self, img: np.ndarray) -> np.ndarray: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d9d38242fb..7cc10ccbfd 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -15,16 +15,20 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +from copy import deepcopy from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch from monai.config import DtypeLike, KeysCollection +from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.inverse import InvertibleTransform, NonRigidTransform from monai.transforms.spatial.array import ( Affine, + AffineGrid, Flip, Orientation, Rand2DElastic, @@ -47,6 +51,7 @@ ensure_tuple_rep, fall_back_tuple, ) +from monai.utils.enums import InverseKeys __all__ = [ "Spacingd", @@ -204,7 +209,7 @@ def __call__( return d -class Orientationd(MapTransform): +class Orientationd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. @@ -259,13 +264,36 @@ def __call__( ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: d: Dict = dict(data) for key in self.key_iterator(d): - meta_data = d[f"{key}_{self.meta_key_postfix}"] - d[key], _, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + meta_data_key = f"{key}_{self.meta_key_postfix}" + meta_data = d[meta_data_key] + d[key], old_affine, new_affine = self.ornt_transform(d[key], affine=meta_data["affine"]) + self.push_transform(d, key, extra_info={"meta_data_key": meta_data_key, "old_affine": old_affine}) + d[meta_data_key]["affine"] = new_affine + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[InverseKeys.EXTRA_INFO.value]["meta_data_key"]] + orig_affine = transform[InverseKeys.EXTRA_INFO.value]["old_affine"] + orig_axcodes = nib.orientations.aff2axcodes(orig_affine) + inverse_transform = Orientation( + axcodes=orig_axcodes, + as_closest_canonical=self.ornt_transform.as_closest_canonical, + labels=self.ornt_transform.labels, + ) + # Apply inverse + d[key], _, new_affine = inverse_transform(d[key], affine=meta_data["affine"]) meta_data["affine"] = new_affine + # Remove the applied transform + self.pop_transform(d, key) + return d -class Rotate90d(MapTransform): +class Rotate90d(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -286,11 +314,31 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.rotator(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Create inverse transform + spatial_axes = self.rotator.spatial_axes + num_times_rotated = self.rotator.k + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class RandRotate90d(RandomizableTransform, MapTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -337,6 +385,27 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Mapping[Hashable, np. for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Create inverse transform + num_times_rotated = transform[InverseKeys.EXTRA_INFO.value]["rand_k"] + num_times_to_rotate = 4 - num_times_rotated + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Apply inverse + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + return d @@ -789,7 +858,7 @@ def __call__( return d -class Flipd(MapTransform): +class Flipd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -814,11 +883,26 @@ def __init__( def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: d = dict(data) for key in self.key_iterator(d): + self.push_transform(d, key) d[key] = self.flipper(d[key]) return d + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + _ = self.get_most_recent_transform(d, key) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d -class RandFlipd(RandomizableTransform, MapTransform): + +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. @@ -851,10 +935,26 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key]) + self.push_transform(d, key) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = self.flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d -class RandAxisFlipd(RandomizableTransform, MapTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. @@ -885,6 +985,23 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda for key in self.key_iterator(d): if self._do_transform: d[key] = flipper(d[key]) + self.push_transform(d, key, extra_info={"axis": self._axis}) + return d + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Check if random transform was actually performed (based on `prob`) + if transform[InverseKeys.DO_TRANSFORM.value]: + flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO.value]["axis"]) + # Might need to convert to numpy + if isinstance(d[key], torch.Tensor): + d[key] = torch.Tensor(d[key]).cpu().numpy() + # Inverse is same as forward + d[key] = flipper(d[key]) + # Remove the applied transform + self.pop_transform(d, key) return d diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 6635a4126f..1f81d71b58 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import sys import unittest from functools import partial @@ -28,11 +29,18 @@ Compose, CropForegroundd, DivisiblePadd, + Flipd, InvertibleTransform, LoadImaged, + Orientationd, + RandAxisFlipd, + RandFlipd, + RandRotate90d, + RandRotated, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, + Rotate90d, SpatialCropd, SpatialPadd, allow_missing_keys_mode, @@ -206,6 +214,100 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) +TESTS.append( + ( + "RandRotated, prob 0", + "2D", + 0, + RandRotated(KEYS, prob=0), + ) +) + +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "Flipd 3d", + "3D", + 0, + Flipd(KEYS, [1, 2]), + ) +) + +TESTS.append( + ( + "RandFlipd 3d", + "3D", + 0, + RandFlipd(KEYS, 1, [1, 2]), + ) +) + +TESTS.append( + ( + "RandAxisFlipd 3d", + "3D", + 0, + RandAxisFlipd(KEYS, 1), + ) +) + +TESTS.append( + ( + "RandRotated 3d", + "3D", + 1e-1, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore + ) +) + +TESTS.append( + ( + "Orientationd 3d", + "3D", + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, 110), + Orientationd(KEYS, "RAS"), + ) +) + +TESTS.append( + ( + "Rotate90d 2d", + "2D", + 0, + Rotate90d(KEYS), + ) +) + +TESTS.append( + ( + "Rotate90d 3d", + "3D", + 0, + Rotate90d(KEYS, k=2, spatial_axes=(1, 2)), + ) +) + +TESTS.append( + ( + "RandRotate90d 3d", + "3D", + 0, + # For data loader, output needs to be same size, so input must be square/cubic + SpatialPadd(KEYS, 110), + RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), + ) +) + TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore From 28f5a36615275978d0a7c71d236d26cb23beb048 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 12:42:28 +0000 Subject: [PATCH 2/4] code format Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 7cc10ccbfd..68cb18ae18 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -22,13 +22,11 @@ import torch from monai.config import DtypeLike, KeysCollection -from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse import InvertibleTransform, NonRigidTransform +from monai.transforms.croppad.array import CenterSpatialCrop +from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, - AffineGrid, Flip, Orientation, Rand2DElastic, @@ -52,6 +50,9 @@ fall_back_tuple, ) from monai.utils.enums import InverseKeys +from monai.utils.module import optional_import + +nib, _ = optional_import("nibabel") __all__ = [ "Spacingd", From fed515e2c2ad4c0bd348711df2b98fe1b430bf7e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 15 Mar 2021 15:32:47 +0000 Subject: [PATCH 3/4] remove extra tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_inverse.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 1f81d71b58..d3a5a533cd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random import sys import unittest from functools import partial @@ -36,7 +35,6 @@ RandAxisFlipd, RandFlipd, RandRotate90d, - RandRotated, RandSpatialCropd, ResizeWithPadOrCrop, ResizeWithPadOrCropd, @@ -214,15 +212,6 @@ TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) -TESTS.append( - ( - "RandRotated, prob 0", - "2D", - 0, - RandRotated(KEYS, prob=0), - ) -) - TESTS.append( ( "Flipd 3d", @@ -259,15 +248,6 @@ ) ) -TESTS.append( - ( - "RandRotated 3d", - "3D", - 1e-1, - RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1), # type: ignore - ) -) - TESTS.append( ( "Orientationd 3d", From c72407ab83432b1e7ebbe3b1e8beb691c4d052d4 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 16 Mar 2021 08:56:06 +0000 Subject: [PATCH 4/4] update tests Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/spatial/dictionary.py | 2 +- tests/test_inverse.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 68cb18ae18..170006ed2b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -282,7 +282,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, - as_closest_canonical=self.ornt_transform.as_closest_canonical, + as_closest_canonical=False, labels=self.ornt_transform.labels, ) # Apply inverse diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 53c95f11b7..0c29ea7b08 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -249,16 +249,15 @@ ) ) -TESTS.append( - ( - "Orientationd 3d", - "3D", - 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), - Orientationd(KEYS, "RAS"), +for acc in [True, False]: + TESTS.append( + ( + "Orientationd 3d", + "3D", + 0, + Orientationd(KEYS, "RAS", as_closest_canonical=acc), + ) ) -) TESTS.append( ( @@ -283,8 +282,6 @@ "RandRotate90d 3d", "3D", 0, - # For data loader, output needs to be same size, so input must be square/cubic - SpatialPadd(KEYS, 110), RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)), ) )