diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 4ec14e0a7d..d3aae960bc 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -15,6 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ +import contextlib from copy import deepcopy from enum import Enum from itertools import chain @@ -38,6 +39,7 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( + allow_missing_keys_mode, generate_pos_neg_label_crop_centers, is_positive, map_binary_to_indices, @@ -586,7 +588,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return super().__call__(data=data) -class RandSpatialCropSamplesd(Randomizable, MapTransform): +@contextlib.contextmanager +def _nullcontext(x): + """ + This is just like contextlib.nullcontext but also works in Python 3.6. + """ + yield x + + +class RandSpatialCropSamplesd(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`. Crop image with random size or specific size ROI to generate a list of N samples. @@ -664,6 +674,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n for key in set(data.keys()).difference(set(self.keys)): d[key] = deepcopy(data[key]) cropped = self.cropper(d) + # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd + for key in self.key_iterator(cropped): + cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ + cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -673,6 +687,17 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n ret.append(cropped) return ret + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd + # Need to revert that since we're calling RandSpatialCropd's inverse + for key in self.key_iterator(d): + d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper) + context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext + with context_manager(self.cropper): + return self.cropper.inverse(d) + class CropForegroundd(MapTransform, InvertibleTransform): """ @@ -770,7 +795,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar return d -class RandWeightedCropd(Randomizable, MapTransform): +class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform): """ Samples a list of `num_samples` image patches according to the provided `weight_map`. @@ -831,6 +856,10 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n _spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:]) results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] + # fill in the extra keys with unmodified data + for i in range(self.num_samples): + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = deepcopy(data[key]) for key in self.key_iterator(d): img = d[key] if img.shape[1:] != d[self.w_key].shape[1:]: @@ -840,13 +869,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n ) for i, center in enumerate(self.centers): cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size) + orig_size = img.shape[1:] results[i][key] = cropper(img) + self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) if self.center_coord_key: results[i][self.center_coord_key] = center # fill in the extra keys with unmodified data for i in range(self.num_samples): - for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(data[key]) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -856,8 +885,30 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + current_size = np.asarray(d[key].shape[1:]) + center = transform[InverseKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) + # get required pad to start and end + pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) + pad_to_end = orig_size - current_size - pad_to_start + # interleave mins and maxes + pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) + inverse_transform = BorderPad(pad) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + -class RandCropByPosNegLabeld(Randomizable, MapTransform): +class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): """ Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`. Crop random fixed sized regions with the center being a foreground or background voxel @@ -973,13 +1024,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)] for i, center in enumerate(self.centers): + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = deepcopy(data[key]) for key in self.key_iterator(d): img = d[key] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + orig_size = img.shape[1:] results[i][key] = cropper(img) - # fill in the extra keys with unmodified data - for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = deepcopy(data[key]) + self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -989,6 +1042,28 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n return results + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + current_size = np.asarray(d[key].shape[1:]) + center = transform[InverseKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + # get required pad to start and end + pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) + pad_to_end = orig_size - current_size - pad_to_start + # interleave mins and maxes + pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) + inverse_transform = BorderPad(pad) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + class ResizeWithPadOrCropd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 5dbe5b833f..f1ce314b01 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -38,11 +38,14 @@ Orientationd, RandAffined, RandAxisFlipd, + RandCropByPosNegLabeld, RandFlipd, Randomizable, RandRotate90d, RandRotated, RandSpatialCropd, + RandSpatialCropSamplesd, + RandWeightedCropd, RandZoomd, Resized, ResizeWithPadOrCrop, @@ -440,11 +443,44 @@ ) ) +TESTS.append( + ( + "RandCropByPosNegLabeld 2d", + "2D", + 1e-7, + RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10), + ) +) + +TESTS.append( + ( + "RandSpatialCropSamplesd 2d", + "2D", + 1e-7, + RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10), + ) +) + +TESTS.append( + ( + "RandWeightedCropd 2d", + "2D", + 1e-7, + RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10), + ) +) TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], Compose(Compose(t[3:]))) for t in TESTS] TESTS = TESTS + TESTS_COMPOSE_X2 # type: ignore +NUM_SAMPLES = 5 +N_SAMPLES_TESTS = [ + [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)], + [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)], + [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)], +] + def no_collation(x): return x @@ -563,8 +599,15 @@ def test_inverse(self, _, data_name, acceptable_diff, *transforms): fwd_bck = forwards[-1].copy() for i, t in enumerate(reversed(transforms)): if isinstance(t, InvertibleTransform): - fwd_bck = t.inverse(fwd_bck) - self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + if isinstance(fwd_bck, list): + for j, _fwd_bck in enumerate(fwd_bck): + fwd_bck = t.inverse(_fwd_bck) + self.check_inverse( + name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1][j], acceptable_diff + ) + else: + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway @skipUnless(torch.multiprocessing.get_start_method(allow_none=False) == "spawn", "requires spawn") @@ -578,7 +621,8 @@ def test_fail(self): with self.assertRaises(RuntimeError): t2.inverse(data) - def test_inverse_inferred_seg(self): + @parameterized.expand(N_SAMPLES_TESTS) + def test_inverse_inferred_seg(self, extra_transform): test_data = [] for _ in range(20): @@ -588,7 +632,13 @@ def test_inverse_inferred_seg(self): batch_size = 10 # num workers = 0 for mac num_workers = 2 if sys.platform != "darwin" else 0 - transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), CenterSpatialCropd(KEYS, (110, 99))]) + transforms = Compose( + [ + AddChanneld(KEYS), + SpatialPadd(KEYS, (150, 153)), + extra_transform, + ] + ) num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform)) dataset = CacheDataset(test_data, transform=transforms, progress=False) @@ -604,6 +654,9 @@ def test_inverse_inferred_seg(self): ).to(device) data = first(loader) + self.assertEqual(len(data["label_transforms"]), num_invertible_transforms) + self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES) + labels = data["label"].to(device) segs = model(labels).detach().cpu() label_transform_key = "label" + InverseKeys.KEY_SUFFIX