From d1c33f9cc7834d344b4b62a12b4227f96e1406a2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 10 Nov 2021 23:40:59 +0800 Subject: [PATCH 1/2] [DLMED] fix keys issue Signed-off-by: Nic Ma --- monai/apps/deepgrow/transforms.py | 12 ++++++++++-- monai/transforms/croppad/dictionary.py | 8 ++++---- monai/transforms/intensity/dictionary.py | 10 +++++----- monai/transforms/spatial/dictionary.py | 16 +++++++++------- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 8c9eb884dd..80acb881aa 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -19,7 +19,15 @@ from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import +from monai.utils import ( + InterpolateMode, + deprecated_arg, + ensure_tuple, + ensure_tuple_rep, + first, + min_version, + optional_import, +) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") @@ -645,7 +653,7 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) guidance = d[self.guidance] - original_spatial_shape = d[self.keys[0]].shape[1:] + original_spatial_shape = d[first(self.key_iterator(d))].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) center = list(np.mean([box_start, box_end], axis=0).astype(int)) spatial_size = self.spatial_size diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 7348451f25..58e40a3e3b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -51,7 +51,7 @@ weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first from monai.utils.enums import InverseKeys __all__ = [ @@ -481,7 +481,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = data[self.keys[0]].shape[1:] + img_size = d[first(self.key_iterator(d))].shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -575,7 +575,7 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key + self.randomize(d[first(self.key_iterator(d))].shape[1:]) # image shape from the first data key if self._size is None: raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): @@ -669,7 +669,7 @@ def __init__( self.max_roi_scale = max_roi_scale def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - img_size = data[self.keys[0]].shape[1:] + img_size = data[first(self.key_iterator(data))].shape[1:] # type: ignore ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index c33cad85a5..ca047890e8 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -53,7 +53,7 @@ ) from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import is_positive -from monai.utils import ensure_tuple, ensure_tuple_rep +from monai.utils import ensure_tuple, ensure_tuple_rep, first from monai.utils.deprecate_utils import deprecated_arg __all__ = [ @@ -187,7 +187,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random noise - self.rand_gaussian_noise.randomize(d[self.keys[0]]) + self.rand_gaussian_noise.randomize(d[first(self.key_iterator(d))]) for key in self.key_iterator(d): d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -621,7 +621,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random bias factor - self.rand_bias_field.randomize(img_size=d[self.keys[0]].shape[1:]) + self.rand_bias_field.randomize(img_size=d[first(self.key_iterator(d))].shape[1:]) for key in self.key_iterator(d): d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -1466,7 +1466,7 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - self.dropper.randomize(d[self.keys[0]].shape[1:]) + self.dropper.randomize(d[first(self.key_iterator(d))].shape[1:]) for key in self.key_iterator(d): d[key] = self.dropper(img=d[key], randomize=False) @@ -1531,7 +1531,7 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - self.shuffle.randomize(d[self.keys[0]].shape[1:]) + self.shuffle.randomize(d[first(self.key_iterator(d))].shape[1:]) for key in self.key_iterator(d): d[key] = self.shuffle(img=d[key], randomize=False) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4baf35c569..f8813f5ece 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -60,6 +60,7 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, + first, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import InverseKeys @@ -815,9 +816,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N device = self.rand_affine.resampler.device - sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) + spatial_size = d[first(self.key_iterator(d))].shape[1:] + sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform - do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:])) + do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device) # converting affine to tensor because the resampler currently only support torch backend grid = None @@ -975,7 +977,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d = dict(data) self.randomize(None) - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:]) + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:]) # all the keys share the same random elastic factor self.rand_2d_elastic.randomize(sp_size) @@ -1107,7 +1109,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d = dict(data) self.randomize(None) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:]) + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:]) # all the keys share the same random elastic factor self.rand_3d_elastic.randomize(sp_size) @@ -1258,7 +1260,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.randomize(None) # all the keys share the same random selected axis - self.flipper.randomize(d[self.keys[0]]) + self.flipper.randomize(d[first(self.key_iterator(d))]) for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key], randomize=False) @@ -1682,7 +1684,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.randomize(None) # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[self.keys[0]]) + self.rand_zoom.randomize(d[first(self.key_iterator(d))]) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): @@ -1864,7 +1866,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - self.rand_grid_distortion.randomize(d[self.keys[0]].shape[1:]) + self.rand_grid_distortion.randomize(d[first(self.key_iterator(d))].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d From 737c6b8fe843eb06b7a3dad9955513bb03b65852 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 10 Nov 2021 23:51:10 +0800 Subject: [PATCH 2/2] [DLMED] add unit test Signed-off-by: Nic Ma --- tests/test_rand_scale_cropd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_rand_scale_cropd.py b/tests/test_rand_scale_cropd.py index f78a81d339..3dc4578f62 100644 --- a/tests/test_rand_scale_cropd.py +++ b/tests/test_rand_scale_cropd.py @@ -24,7 +24,8 @@ ] TEST_CASE_2 = [ - {"keys": "img", "roi_scale": [1.0, 1.0, 1.0], "random_center": False}, + # test `allow_missing_keys` with key "label" + {"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True}, {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, (3, 3, 3, 3), ]