Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def __init__(
roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None,
roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None,
roi_slices: Optional[Sequence[slice]] = None,
allow_smaller: bool = False,
) -> None:
"""
Args:
Expand All @@ -393,6 +394,9 @@ def __init__(
roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
use the end coordinate of image.
roi_slices: list of slices for each of the spatial dimensions.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will remain
unchanged.
"""
roi_start_torch: torch.Tensor

Expand Down Expand Up @@ -880,6 +884,9 @@ class RandCropByPosNegLabel(Randomizable, Transform):
`image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices`
and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening.
a typical usage is to call `FgBgToIndices` transform first and cache the results.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).

Raises:
ValueError: When ``pos`` or ``neg`` are negative.
Expand All @@ -900,6 +907,7 @@ def __init__(
image_threshold: float = 0.0,
fg_indices: Optional[NdarrayOrTensor] = None,
bg_indices: Optional[NdarrayOrTensor] = None,
allow_smaller: bool = False,
) -> None:
self.spatial_size = ensure_tuple(spatial_size)
self.label = label
Expand All @@ -914,6 +922,7 @@ def __init__(
self.centers: Optional[List[List[int]]] = None
self.fg_indices = fg_indices
self.bg_indices = bg_indices
self.allow_smaller = allow_smaller

def randomize(
self,
Expand All @@ -933,7 +942,14 @@ def randomize(
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
self.spatial_size,
self.num_samples,
self.pos_ratio,
label.shape[1:],
fg_indices_,
bg_indices_,
self.R,
self.allow_smaller,
)

def __call__(
Expand Down Expand Up @@ -1031,6 +1047,9 @@ class RandCropByLabelClasses(Randomizable, Transform):
`image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array
of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first
and cache the results for better performance.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will remain
unchanged.

"""

Expand All @@ -1046,6 +1065,7 @@ def __init__(
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
indices: Optional[List[NdarrayOrTensor]] = None,
allow_smaller: bool = False,
) -> None:
self.spatial_size = ensure_tuple(spatial_size)
self.ratios = ratios
Expand All @@ -1056,6 +1076,7 @@ def __init__(
self.image_threshold = image_threshold
self.centers: Optional[List[List[int]]] = None
self.indices = indices
self.allow_smaller = allow_smaller

def randomize(
self,
Expand All @@ -1073,7 +1094,7 @@ def randomize(
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller
)

def __call__(
Expand Down
21 changes: 19 additions & 2 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,9 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform):
meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according
to the key data, default is `meta_dict`, the meta data is a dictionary object.
used to add `patch_index` to the meta dict.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).
allow_missing_keys: don't raise exception if key is missing.

Raises:
Expand All @@ -1089,6 +1092,7 @@ def __init__(
bg_indices_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
allow_smaller: bool = False,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
Expand All @@ -1109,6 +1113,7 @@ def __init__(
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.centers: Optional[List[List[int]]] = None
self.allow_smaller = allow_smaller

def randomize(
self,
Expand All @@ -1124,7 +1129,14 @@ def randomize(
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
self.spatial_size,
self.num_samples,
self.pos_ratio,
label.shape[1:],
fg_indices_,
bg_indices_,
self.R,
self.allow_smaller,
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
Expand Down Expand Up @@ -1257,6 +1269,9 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform):
meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according
to the key data, default is `meta_dict`, the meta data is a dictionary object.
used to add `patch_index` to the meta dict.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will remain
unchanged.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand All @@ -1276,6 +1291,7 @@ def __init__(
indices_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
allow_smaller: bool = False,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
Expand All @@ -1292,6 +1308,7 @@ def __init__(
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.centers: Optional[List[List[int]]] = None
self.allow_smaller = allow_smaller

def randomize(
self,
Expand All @@ -1305,7 +1322,7 @@ def randomize(
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller
)

def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]:
Expand Down
22 changes: 18 additions & 4 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def correct_crop_centers(
centers: List[Union[int, torch.Tensor]],
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
allow_smaller: bool = False,
):
"""
Utility to correct the crop center if the crop size is bigger than the image size.
Expand All @@ -414,11 +415,16 @@ def correct_crop_centers(
centers: pre-computed crop centers of every dim, will correct based on the valid region.
spatial_size: spatial size of the ROIs to be sampled.
label_spatial_shape: spatial shape of the original label data to compare with ROI.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).

"""
spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape)
if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all():
raise ValueError("The size of the proposed random crop ROI is larger than the image size.")
if any(np.subtract(label_spatial_shape, spatial_size) < 0):
if not allow_smaller:
raise ValueError("The size of the proposed random crop ROI is larger than the image size.")
spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size))

# Select subregion to assure valid roi
valid_start = np.floor_divide(spatial_size, 2)
Expand Down Expand Up @@ -450,6 +456,7 @@ def generate_pos_neg_label_crop_centers(
fg_indices: NdarrayOrTensor,
bg_indices: NdarrayOrTensor,
rand_state: Optional[np.random.RandomState] = None,
allow_smaller: bool = False,
) -> List[List[int]]:
"""
Generate valid sample locations based on the label with option for specifying foreground ratio
Expand All @@ -463,6 +470,9 @@ def generate_pos_neg_label_crop_centers(
fg_indices: pre-computed foreground indices in 1 dimension.
bg_indices: pre-computed background indices in 1 dimension.
rand_state: numpy randomState object to align with other modules.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).

Raises:
ValueError: When the proposed roi is larger than the image.
Expand Down Expand Up @@ -491,7 +501,7 @@ def generate_pos_neg_label_crop_centers(
idx = indices_to_use[random_int]
center = unravel_index(idx, label_spatial_shape)
# shift center to range of valid centers
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape))
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))

return centers

Expand All @@ -503,6 +513,7 @@ def generate_label_classes_crop_centers(
indices: Sequence[NdarrayOrTensor],
ratios: Optional[List[Union[float, int]]] = None,
rand_state: Optional[np.random.RandomState] = None,
allow_smaller: bool = False,
) -> List[List[int]]:
"""
Generate valid sample locations based on the specified ratios of label classes.
Expand All @@ -516,6 +527,9 @@ def generate_label_classes_crop_centers(
ratios: ratios of every class in the label to generate crop centers, including background class.
if None, every class will have the same ratio to generate crop centers.
rand_state: numpy randomState object to align with other modules.
allow_smaller: if `False`, an exception will be raised if the image is smaller than
the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
match the cropped size (i.e., no cropping in that dimension).

"""
if rand_state is None:
Expand Down Expand Up @@ -543,7 +557,7 @@ def generate_label_classes_crop_centers(
center = unravel_index(indices_to_use[random_int], label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape, allow_smaller))

return centers

Expand Down
44 changes: 44 additions & 0 deletions tests/test_rand_crop_by_label_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,50 @@
(3, 2, 2, 2),
]
)
TESTS_SHAPE.append(
[
# provide label at runtime
{
"label": None,
"num_classes": 2,
"spatial_size": [4, 4, 2],
"ratios": [1, 1],
"num_samples": 2,
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"image_threshold": 0,
"allow_smaller": True,
},
{
"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
},
list,
(3, 3, 3, 2),
]
)
TESTS_SHAPE.append(
[
# provide label at runtime
{
"label": None,
"num_classes": 2,
"spatial_size": [4, 4, 4],
"ratios": [1, 1],
"num_samples": 2,
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"image_threshold": 0,
"allow_smaller": True,
},
{
"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
},
list,
(3, 3, 3, 3),
]
)


class TestRandCropByLabelClasses(unittest.TestCase):
Expand Down
48 changes: 48 additions & 0 deletions tests/test_rand_crop_by_label_classesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,54 @@
]
)

TESTS.append(
[
# Argmax label
{
"keys": "img",
"label_key": "label",
"num_classes": 2,
"spatial_size": [4, 4, 2],
"ratios": [1, 1],
"num_samples": 2,
"image_key": "image",
"image_threshold": 0,
"allow_smaller": True,
},
{
"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
},
list,
(3, 3, 3, 2),
]
)

TESTS.append(
[
# Argmax label
{
"keys": "img",
"label_key": "label",
"num_classes": 2,
"spatial_size": [4, 4, 4],
"ratios": [1, 1],
"num_samples": 2,
"image_key": "image",
"image_threshold": 0,
"allow_smaller": True,
},
{
"img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),
"label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),
},
list,
(3, 3, 3, 3),
]
)


class TestRandCropByLabelClassesd(unittest.TestCase):
@parameterized.expand(TESTS)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_rand_crop_by_pos_neg_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@
(3, 2, 2, 2),
]
)
TESTS.append(
[
{
"label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"spatial_size": [4, 4, 2],
"pos": 1,
"neg": 1,
"num_samples": 2,
"image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"allow_smaller": True,
},
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
(3, 3, 3, 2),
]
)
TESTS.append(
[
{
"label": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"spatial_size": [4, 4, 4],
"pos": 1,
"neg": 1,
"num_samples": 2,
"image": np.random.randint(0, 2, size=[3, 3, 3, 3]),
"allow_smaller": True,
},
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
(3, 3, 3, 3),
]
)


class TestRandCropByPosNegLabel(unittest.TestCase):
Expand Down
Loading