diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index ce0105bbe1..5c2fe7475a 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -383,6 +383,7 @@ def __init__( roi_start: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_end: Union[Sequence[int], NdarrayOrTensor, None] = None, roi_slices: Optional[Sequence[slice]] = None, + allow_smaller: bool = False, ) -> None: """ Args: @@ -393,6 +394,9 @@ def __init__( roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, use the end coordinate of image. roi_slices: list of slices for each of the spatial dimensions. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. """ roi_start_torch: torch.Tensor @@ -880,6 +884,9 @@ class RandCropByPosNegLabel(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices` and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `FgBgToIndices` transform first and cache the results. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When ``pos`` or ``neg`` are negative. @@ -900,6 +907,7 @@ def __init__( image_threshold: float = 0.0, fg_indices: Optional[NdarrayOrTensor] = None, bg_indices: Optional[NdarrayOrTensor] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.label = label @@ -914,6 +922,7 @@ def __init__( self.centers: Optional[List[List[int]]] = None self.fg_indices = fg_indices self.bg_indices = bg_indices + self.allow_smaller = allow_smaller def randomize( self, @@ -933,7 +942,14 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) def __call__( @@ -1031,6 +1047,9 @@ class RandCropByLabelClasses(Randomizable, Transform): `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first and cache the results for better performance. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. """ @@ -1046,6 +1065,7 @@ def __init__( image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, indices: Optional[List[NdarrayOrTensor]] = None, + allow_smaller: bool = False, ) -> None: self.spatial_size = ensure_tuple(spatial_size) self.ratios = ratios @@ -1056,6 +1076,7 @@ def __init__( self.image_threshold = image_threshold self.centers: Optional[List[List[int]]] = None self.indices = indices + self.allow_smaller = allow_smaller def randomize( self, @@ -1073,7 +1094,7 @@ def randomize( else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) def __call__( diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bf0b9ef04d..1f3f5fa3e3 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1065,6 +1065,9 @@ class RandCropByPosNegLabeld(Randomizable, MapTransform, InvertibleTransform): meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). allow_missing_keys: don't raise exception if key is missing. Raises: @@ -1089,6 +1092,7 @@ def __init__( bg_indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1109,6 +1113,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, @@ -1124,7 +1129,14 @@ def randomize( fg_indices_ = fg_indices bg_indices_ = bg_indices self.centers = generate_pos_neg_label_crop_centers( - self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R + self.spatial_size, + self.num_samples, + self.pos_ratio, + label.shape[1:], + fg_indices_, + bg_indices_, + self.R, + self.allow_smaller, ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]: @@ -1257,6 +1269,9 @@ class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according to the key data, default is `meta_dict`, the meta data is a dictionary object. used to add `patch_index` to the meta dict. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will remain + unchanged. allow_missing_keys: don't raise exception if key is missing. """ @@ -1276,6 +1291,7 @@ def __init__( indices_key: Optional[str] = None, meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", + allow_smaller: bool = False, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) @@ -1292,6 +1308,7 @@ def __init__( raise ValueError("meta_keys should have the same length as keys.") self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) self.centers: Optional[List[List[int]]] = None + self.allow_smaller = allow_smaller def randomize( self, @@ -1305,7 +1322,7 @@ def randomize( else: indices_ = indices self.centers = generate_label_classes_crop_centers( - self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R, self.allow_smaller ) def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, NdarrayOrTensor]]: diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 75260af5b6..17dd873b16 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -406,6 +406,7 @@ def correct_crop_centers( centers: List[Union[int, torch.Tensor]], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int], + allow_smaller: bool = False, ): """ Utility to correct the crop center if the crop size is bigger than the image size. @@ -414,11 +415,16 @@ def correct_crop_centers( centers: pre-computed crop centers of every dim, will correct based on the valid region. spatial_size: spatial size of the ROIs to be sampled. label_spatial_shape: spatial shape of the original label data to compare with ROI. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) - if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all(): - raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + if any(np.subtract(label_spatial_shape, spatial_size) < 0): + if not allow_smaller: + raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size)) # Select subregion to assure valid roi valid_start = np.floor_divide(spatial_size, 2) @@ -450,6 +456,7 @@ def generate_pos_neg_label_crop_centers( fg_indices: NdarrayOrTensor, bg_indices: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, + allow_smaller: bool = False, ) -> List[List[int]]: """ Generate valid sample locations based on the label with option for specifying foreground ratio @@ -463,6 +470,9 @@ def generate_pos_neg_label_crop_centers( fg_indices: pre-computed foreground indices in 1 dimension. bg_indices: pre-computed background indices in 1 dimension. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). Raises: ValueError: When the proposed roi is larger than the image. @@ -491,7 +501,7 @@ def generate_pos_neg_label_crop_centers( idx = indices_to_use[random_int] center = unravel_index(idx, label_spatial_shape) # shift center to range of valid centers - centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller)) return centers @@ -503,6 +513,7 @@ def generate_label_classes_crop_centers( indices: Sequence[NdarrayOrTensor], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, + allow_smaller: bool = False, ) -> List[List[int]]: """ Generate valid sample locations based on the specified ratios of label classes. @@ -516,6 +527,9 @@ def generate_label_classes_crop_centers( ratios: ratios of every class in the label to generate crop centers, including background class. if None, every class will have the same ratio to generate crop centers. rand_state: numpy randomState object to align with other modules. + allow_smaller: if `False`, an exception will be raised if the image is smaller than + the requested ROI in any dimension. If `True`, any smaller dimensions will be set to + match the cropped size (i.e., no cropping in that dimension). """ if rand_state is None: @@ -543,7 +557,7 @@ def generate_label_classes_crop_centers( center = unravel_index(indices_to_use[random_int], label_spatial_shape) # shift center to range of valid centers center_ori = list(center) - centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape, allow_smaller)) return centers diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py index d562a44a6d..c987c3f0fd 100644 --- a/tests/test_rand_crop_by_label_classes.py +++ b/tests/test_rand_crop_by_label_classes.py @@ -76,6 +76,50 @@ (3, 2, 2, 2), ] ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + TESTS_SHAPE.append( + [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) class TestRandCropByLabelClasses(unittest.TestCase): diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py index 27fe3425dd..e51413a8d0 100644 --- a/tests/test_rand_crop_by_label_classesd.py +++ b/tests/test_rand_crop_by_label_classesd.py @@ -65,6 +65,54 @@ ] ) + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 2), + ] + ) + + TESTS.append( + [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [4, 4, 4], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + "allow_smaller": True, + }, + { + "img": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "image": p(np.random.randint(0, 2, size=[3, 3, 3, 3])), + "label": p(np.random.randint(0, 2, size=[1, 3, 3, 3])), + }, + list, + (3, 3, 3, 3), + ] + ) + class TestRandCropByLabelClassesd(unittest.TestCase): @parameterized.expand(TESTS) diff --git a/tests/test_rand_crop_by_pos_neg_label.py b/tests/test_rand_crop_by_pos_neg_label.py index a81976dea1..42a72ccf2b 100644 --- a/tests/test_rand_crop_by_pos_neg_label.py +++ b/tests/test_rand_crop_by_pos_neg_label.py @@ -68,6 +68,36 @@ (3, 2, 2, 2), ] ) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 2), + ] +) +TESTS.append( + [ + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "allow_smaller": True, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + (3, 3, 3, 3), + ] +) class TestRandCropByPosNegLabel(unittest.TestCase): diff --git a/tests/test_rand_crop_by_pos_neg_labeld.py b/tests/test_rand_crop_by_pos_neg_labeld.py index 6d2f39cc54..c200b8acac 100644 --- a/tests/test_rand_crop_by_pos_neg_labeld.py +++ b/tests/test_rand_crop_by_pos_neg_labeld.py @@ -76,6 +76,46 @@ }, (3, 2, 2, 2), ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 2], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 2), + ], + [ + { + "keys": ["image", "extra", "label"], + "label_key": "label", + "spatial_size": [4, 4, 4], + "pos": 1, + "neg": 1, + "num_samples": 2, + "image_key": None, + "image_threshold": 0, + "allow_smaller": True, + }, + { + "image": np.zeros([3, 3, 3, 3]) - 1, + "extra": np.zeros([3, 3, 3, 3]), + "label": np.ones([3, 3, 3, 3]), + "extra_meta_dict": {"affine": np.eye(3), "shape": "CHWD"}, + }, + (3, 3, 3, 3), + ], ]