Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ class RandWeightedCropd(Randomizable, MapTransform):
If its components have non-positive values, the corresponding size of `img` will be used.
num_samples: number of samples (image patches) to take in the returned list.
center_coord_key: if specified, the actual sampling location will be stored with the corresponding key.
meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data,
default is `meta_dict`, the meta data is a dictionary object.
used to add `patch_index` to the meta dict.
allow_missing_keys: don't raise exception if key is missing.

See Also:
Expand All @@ -677,13 +680,15 @@ def __init__(
spatial_size: Union[Sequence[int], int],
num_samples: int = 1,
center_coord_key: Optional[str] = None,
meta_key_postfix: str = "meta_dict",
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.spatial_size = ensure_tuple(spatial_size)
self.w_key = w_key
self.num_samples = int(num_samples)
self.center_coord_key = center_coord_key
self.meta_key_postfix = meta_key_postfix
self.centers: List[np.ndarray] = []

def randomize(self, weight_map: np.ndarray) -> None:
Expand All @@ -710,9 +715,15 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
if self.center_coord_key:
results[i][self.center_coord_key] = center
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
for i in range(self.num_samples):
results[i][key] = data[key]
for i in range(self.num_samples):
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = deepcopy(data[key])
# add `patch_index` to the meta data
for key in self.key_iterator(d):
meta_data_key = f"{key}_{self.meta_key_postfix}"
if meta_data_key not in results[i]:
results[i][meta_data_key] = {} # type: ignore
results[i][meta_data_key][Key.PATCH_INDEX] = i

return results

Expand Down Expand Up @@ -829,7 +840,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
results[i][key] = cropper(img)
# fill in the extra keys with unmodified data
for key in set(data.keys()).difference(set(self.keys)):
results[i][key] = data[key]
results[i][key] = deepcopy(data[key])
# add `patch_index` to the meta data
for key in self.key_iterator(d):
meta_data_key = f"{key}_{self.meta_key_postfix}"
Expand Down
25 changes: 12 additions & 13 deletions tests/test_rand_crop_by_pos_neg_labeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

TEST_CASE_0 = [
{
"keys": ["image", "extral", "label"],
"keys": ["image", "extra", "label"],
"label_key": "label",
"spatial_size": [-1, 2, 2],
"pos": 1,
Expand All @@ -29,18 +29,17 @@
},
{
"image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"extral": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"affine": np.eye(3),
"shape": "CHWD",
"image_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
},
list,
(3, 3, 2, 2),
]

TEST_CASE_1 = [
{
"keys": ["image", "extral", "label"],
"keys": ["image", "extra", "label"],
"label_key": "label",
"spatial_size": [2, 2, 2],
"pos": 1,
Expand All @@ -51,18 +50,17 @@
},
{
"image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"extral": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"extra": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"affine": np.eye(3),
"shape": "CHWD",
"label_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
},
list,
(3, 2, 2, 2),
]

TEST_CASE_2 = [
{
"keys": ["image", "extral", "label"],
"keys": ["image", "extra", "label"],
"label_key": "label",
"spatial_size": [2, 2, 2],
"pos": 1,
Expand All @@ -73,10 +71,9 @@
},
{
"image": np.zeros([3, 3, 3, 3]) - 1,
"extral": np.zeros([3, 3, 3, 3]),
"extra": np.zeros([3, 3, 3, 3]),
"label": np.ones([3, 3, 3, 3]),
"affine": np.eye(3),
"shape": "CHWD",
"extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"},
},
list,
(3, 2, 2, 2),
Expand All @@ -89,10 +86,12 @@ def test_type_shape(self, input_param, input_data, expected_type, expected_shape
result = RandCropByPosNegLabeld(**input_param)(input_data)
self.assertIsInstance(result, expected_type)
self.assertTupleEqual(result[0]["image"].shape, expected_shape)
self.assertTupleEqual(result[0]["extral"].shape, expected_shape)
self.assertTupleEqual(result[0]["extra"].shape, expected_shape)
self.assertTupleEqual(result[0]["label"].shape, expected_shape)
for i, item in enumerate(result):
self.assertEqual(item["image_meta_dict"]["patch_index"], i)
self.assertEqual(item["label_meta_dict"]["patch_index"], i)
self.assertEqual(item["extra_meta_dict"]["patch_index"], i)


if __name__ == "__main__":
Expand Down
18 changes: 18 additions & 0 deletions tests/test_rand_weighted_cropd.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,24 @@ def test_rand_weighted_crop_bad_w(self):
np.testing.assert_allclose(result[0]["seg"].shape, (1, 48, 64, 80))
np.testing.assert_allclose(np.asarray(crop.centers), [[24, 32, 40], [24, 32, 40], [24, 32, 40]])

def test_rand_weighted_crop_patch_index(self):
img = self.imt[0]
n_samples = 3
crop = RandWeightedCropd(("img", "seg"), "w", (10, -1, -1), n_samples)
weight = np.zeros_like(img)
weight[0, 7, 17] = 1.1
weight[0, 13, 31] = 1.1
weight[0, 24, 21] = 1
crop.set_random_state(10)
result = crop({"img": img, "seg": self.segn[0], "w": weight, "img_meta_dict": {"affine": None}})
self.assertTrue(len(result) == n_samples)
np.testing.assert_allclose(np.asarray(crop.centers), [[14, 32, 40], [41, 32, 40], [20, 32, 40]])
for i in range(n_samples):
Comment thread
Nic-Ma marked this conversation as resolved.
np.testing.assert_allclose(result[i]["img"].shape, (1, 10, 64, 80))
np.testing.assert_allclose(result[i]["seg"].shape, (1, 10, 64, 80))
np.testing.assert_allclose(result[i]["img_meta_dict"]["patch_index"], i)
np.testing.assert_allclose(result[i]["seg_meta_dict"]["patch_index"], i)


if __name__ == "__main__":
unittest.main()