Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

import contextlib
from copy import deepcopy
from enum import Enum
from itertools import chain
Expand All @@ -38,6 +39,7 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utils import (
allow_missing_keys_mode,
generate_pos_neg_label_crop_centers,
is_positive,
map_binary_to_indices,
Expand Down Expand Up @@ -586,7 +588,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return super().__call__(data=data)


class RandSpatialCropSamplesd(Randomizable, MapTransform):
@contextlib.contextmanager
def _nullcontext(x):
"""
This is just like contextlib.nullcontext but also works in Python 3.6.
"""
yield x


class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`.
Crop image with random size or specific size ROI to generate a list of N samples.
Expand Down Expand Up @@ -664,6 +674,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
for key in set(data.keys()).difference(set(self.keys)):
d[key] = deepcopy(data[key])
cropped = self.cropper(d)
# self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd
for key in self.key_iterator(cropped):
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__
cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self)
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
Expand All @@ -673,6 +687,17 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
ret.append(cropped)
return ret

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
# We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd
# Need to revert that since we're calling RandSpatialCropd's inverse
for key in self.key_iterator(d):
d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__
d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper)
context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext
with context_manager(self.cropper):
return self.cropper.inverse(d)


class CropForegroundd(MapTransform, InvertibleTransform):
"""
Expand Down Expand Up @@ -770,7 +795,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
return d


class RandWeightedCropd(Randomizable, MapTransform):
class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform):
"""
Samples a list of `num_samples` image patches according to the provided `weight_map`.

Expand Down Expand Up @@ -831,6 +856,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
_spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:])

results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]
# fill in the extra keys with unmodified data
for i in range(self.num_samples):
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(data[key])
for key in self.key_iterator(d):
img = d[key]
if img.shape[1:] != d[self.w_key].shape[1:]:
Expand All @@ -840,13 +869,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
)
for i, center in enumerate(self.centers):
cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size)
orig_size = img.shape[1:]
results[i][key] = cropper(img)
self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
if self.center_coord_key:
results[i][self.center_coord_key] = center
# fill in the extra keys with unmodified data
for i in range(self.num_samples):
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(data[key])
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
Expand All @@ -856,8 +885,30 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n

return results

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
center = transform[InverseKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
pad_to_end = orig_size - current_size - pad_to_start
# interleave mins and maxes
pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
inverse_transform = BorderPad(pad)
# Apply inverse transform
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class RandCropByPosNegLabeld(Randomizable, MapTransform):
class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.
Crop random fixed sized regions with the center being a foreground or background voxel
Expand Down Expand Up @@ -973,13 +1024,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]

for i, center in enumerate(self.centers):
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(data[key])
for key in self.key_iterator(d):
img = d[key]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
orig_size = img.shape[1:]
results[i][key] = cropper(img)
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(data[key])
self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
Expand All @@ -989,6 +1042,28 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n

return results

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
center = transform[InverseKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
pad_to_end = orig_size - current_size - pad_to_start
# interleave mins and maxes
pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
inverse_transform = BorderPad(pad)
# Apply inverse transform
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)

return d


class ResizeWithPadOrCropd(MapTransform, InvertibleTransform):
"""
Expand Down
61 changes: 57 additions & 4 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@
Orientationd,
RandAffined,
RandAxisFlipd,
RandCropByPosNegLabeld,
RandFlipd,
Randomizable,
RandRotate90d,
RandRotated,
RandSpatialCropd,
RandSpatialCropSamplesd,
RandWeightedCropd,
RandZoomd,
Resized,
ResizeWithPadOrCrop,
Expand Down Expand Up @@ -440,11 +443,44 @@
)
)

TESTS.append(
(
"RandCropByPosNegLabeld 2d",
"2D",
1e-7,
RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10),
)
)

TESTS.append(
(
"RandSpatialCropSamplesd 2d",
"2D",
1e-7,
RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10),
)
)

TESTS.append(
(
"RandWeightedCropd 2d",
"2D",
1e-7,
RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10),
)
)

TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore

NUM_SAMPLES = 5
N_SAMPLES_TESTS = [
[RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)],
[RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)],
[RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)],
]


def no_collation(x):
return x
Expand Down Expand Up @@ -563,8 +599,15 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms):
fwd_bck = forwards[-1].copy()
for i, t in enumerate(reversed(transforms)):
if isinstance(t, InvertibleTransform):
fwd_bck = t.inverse(fwd_bck)
self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff)
if isinstance(fwd_bck, list):
for j, _fwd_bck in enumerate(fwd_bck):
fwd_bck = t.inverse(_fwd_bck)
self.check_inverse(
name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1][j], acceptable_diff
)
else:
fwd_bck = t.inverse(fwd_bck)
self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff)

# skip this test if multiprocessing uses 'spawn', as the check is only basic anyway
@skipUnless(torch.multiprocessing.get_start_method(allow_none=False) == "spawn", "requires spawn")
Expand All @@ -578,7 +621,8 @@ def test_fail(self):
with self.assertRaises(RuntimeError):
t2.inverse(data)

def test_inverse_inferred_seg(self):
@parameterized.expand(N_SAMPLES_TESTS)
def test_inverse_inferred_seg(self, extra_transform):

test_data = []
for _ in range(20):
Expand All @@ -588,7 +632,13 @@ def test_inverse_inferred_seg(self):
batch_size = 10
# num workers = 0 for mac
num_workers = 2 if sys.platform != "darwin" else 0
transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))])
transforms = Compose(
[
AddChanneld(KEYS),
SpatialPadd(KEYS, (150, 153)),
extra_transform,
]
)
num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

dataset = CacheDataset(test_data, transform=transforms, progress=False)
Expand All @@ -604,6 +654,9 @@ def test_inverse_inferred_seg(self):
).to(device)

data = first(loader)
self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

labels = data["label"].to(device)
segs = model(labels).detach().cpu()
label_transform_key = "label" + InverseKeys.KEY_SUFFIX
Expand Down