Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from monai.config import DtypeLike, KeysCollection
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
Expand Down Expand Up @@ -1215,7 +1215,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class Zoomd(MapTransform):
class Zoomd(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`.

Expand Down Expand Up @@ -1261,6 +1261,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
self.push_transform(d, key)
d[key] = self.zoomer(
d[key],
mode=mode,
Expand All @@ -1269,8 +1270,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
zoom = np.array(self.zoomer.zoom)
inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.zoomer.keep_size)
# Apply inverse
d[key] = inverse_transform(
d[key],
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d

class RandZoomd(RandomizableTransform, MapTransform):

class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform):
"""
Dict-based version :py:class:`monai.transforms.RandZoom`.

Expand Down Expand Up @@ -1338,6 +1362,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
self.randomize()
d = dict(data)
if not self._do_transform:
for key in self.keys:
self.push_transform(d, key, extra_info={"zoom": self._zoom})

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps extra_info self._zoom could have different sizes (on/off _do_transform) in a batch loader, then they are not compatible with the collate #1798

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you're right. You can see there's a few ensure_tuple(self._zoom) that come a few lines later, so they just need to be moved ahead.

return d

img_dims = data[self.keys[0]].ndim
Expand All @@ -1351,6 +1377,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
self.push_transform(d, key, extra_info={"zoom": self._zoom})
d[key] = zoomer(
d[key],
mode=mode,
Expand All @@ -1359,6 +1386,31 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM.value]:
# Create inverse transform
zoom = np.array(transform[InverseKeys.EXTRA_INFO.value]["zoom"])
inverse_transform = Zoom(zoom=1 / zoom, keep_size=self.keep_size)
# Apply inverse
d[key] = inverse_transform(
d[key],
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
# Size might be out by 1 voxel so pad
d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE.value])(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


SpacingD = SpacingDict = Spacingd
OrientationD = OrientationDict = Orientationd
Expand Down
32 changes: 32 additions & 0 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@
Randomizable,
RandRotate90d,
RandSpatialCropd,
RandZoomd,
Resized,
ResizeWithPadOrCrop,
ResizeWithPadOrCropd,
Rotate90d,
Spacingd,
SpatialCropd,
SpatialPadd,
Zoomd,
allow_missing_keys_mode,
)
from monai.utils import first, get_seed, optional_import, set_determinism
Expand Down Expand Up @@ -294,6 +296,36 @@

TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78])))


TESTS.append(
(
"Zoomd 1d",
"1D odd",
0,
Zoomd(KEYS, zoom=2, keep_size=False),
)
)

TESTS.append(
(
"Zoomd 2d",
"2D",
2e-1,
Zoomd(KEYS, zoom=0.9),
)
)

TESTS.append(
(
"Zoomd 3d",
"3D",
3e-2,
Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False),
)
)

TESTS.append(("RandZoom 3d", "3D", 9e-2, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True)))

TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore
Expand Down