diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d6472ddeca..90ba2d601a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -664,6 +664,9 @@ class RandWeightedCropd(Randomizable, MapTransform): If its components have non-positive values, the corresponding size of `img` will be used. num_samples: number of samples (image patches) to take in the returned list. center_coord_key: if specified, the actual sampling location will be stored with the corresponding key. + meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, + default is `meta_dict`, the meta data is a dictionary object. + used to add `patch_index` to the meta dict. allow_missing_keys: don't raise exception if key is missing. See Also: @@ -677,6 +680,7 @@ def __init__( spatial_size: Union[Sequence[int], int], num_samples: int = 1, center_coord_key: Optional[str] = None, + meta_key_postfix: str = "meta_dict", allow_missing_keys: bool = False, ): MapTransform.__init__(self, keys, allow_missing_keys) @@ -684,6 +688,7 @@ def __init__( self.w_key = w_key self.num_samples = int(num_samples) self.center_coord_key = center_coord_key + self.meta_key_postfix = meta_key_postfix self.centers: List[np.ndarray] = [] def randomize(self, weight_map: np.ndarray) -> None: @@ -710,9 +715,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n if self.center_coord_key: results[i][self.center_coord_key] = center # fill in the extra keys with unmodified data - for key in set(data.keys()).difference(set(self.keys)): - for i in range(self.num_samples): - results[i][key] = data[key] + for i in range(self.num_samples): + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = deepcopy(data[key]) + # add `patch_index` to the meta data + for key in self.key_iterator(d): + meta_data_key = f"{key}_{self.meta_key_postfix}" + if meta_data_key not in results[i]: + results[i][meta_data_key] = {} # type: ignore + results[i][meta_data_key][Key.PATCH_INDEX] = i return results @@ -829,7 +840,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n results[i][key] = cropper(img) # fill in the extra keys with unmodified data for key in set(data.keys()).difference(set(self.keys)): - results[i][key] = data[key] + results[i][key] = deepcopy(data[key]) # add `patch_index` to the meta data for key in self.key_iterator(d): meta_data_key = f"{key}_{self.meta_key_postfix}" diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index d52ba900ac..91bc6724ee 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -18,7 +18,7 @@ TEST_CASE_0 = [ { - "keys": ["image", "extral", "label"], + "keys": ["image", "extra", "label"], "label_key": "label", "spatial_size": [-1, 2, 2], "pos": 1, @@ -29,10 +29,9 @@ }, { "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extral": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "affine": np.eye(3), - "shape": "CHWD", + "image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, list, (3, 3, 2, 2), @@ -40,7 +39,7 @@ TEST_CASE_1 = [ { - "keys": ["image", "extral", "label"], + "keys": ["image", "extra", "label"], "label_key": "label", "spatial_size": [2, 2, 2], "pos": 1, @@ -51,10 +50,9 @@ }, { "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "extral": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "extra": np.random.randint(0, 2, size=[3, 3, 3, 3]), "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), - "affine": np.eye(3), - "shape": "CHWD", + "label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, list, (3, 2, 2, 2), @@ -62,7 +60,7 @@ TEST_CASE_2 = [ { - "keys": ["image", "extral", "label"], + "keys": ["image", "extra", "label"], "label_key": "label", "spatial_size": [2, 2, 2], "pos": 1, @@ -73,10 +71,9 @@ }, { "image": np.zeros([3, 3, 3, 3]) - 1, - "extral": np.zeros([3, 3, 3, 3]), + "extra": np.zeros([3, 3, 3, 3]), "label": np.ones([3, 3, 3, 3]), - "affine": np.eye(3), - "shape": "CHWD", + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, }, list, (3, 2, 2, 2), @@ -89,10 +86,12 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape result = RandCropByPosNegLabeld(**input_param)(input_data) self.assertIsInstance(result, expected_type) self.assertTupleEqual(result[0]["image"].shape, expected_shape) - self.assertTupleEqual(result[0]["extral"].shape, expected_shape) + self.assertTupleEqual(result[0]["extra"].shape, expected_shape) self.assertTupleEqual(result[0]["label"].shape, expected_shape) for i, item in enumerate(result): self.assertEqual(item["image_meta_dict"]["patch_index"], i) + self.assertEqual(item["label_meta_dict"]["patch_index"], i) + self.assertEqual(item["extra_meta_dict"]["patch_index"], i) if __name__ == "__main__": diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 0edb1d732d..367ce3beb9 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -139,6 +139,24 @@ def test_rand_weighted_crop_bad_w(self): np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80)) np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]]) + def test_rand_weighted_crop_patch_index(self): + img = self.imt[0] + n_samples = 3 + crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples) + weight = np.zeros_like(img) + weight[0, 7, 17] = 1.1 + weight[0, 13, 31] = 1.1 + weight[0, 24, 21] = 1 + crop.set_random_state(10) + result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}}) + self.assertTrue(len(result) == n_samples) + np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]]) + for i in range(n_samples): + np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80)) + np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i) + np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i) + if __name__ == "__main__": unittest.main()