diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8a890192c8..636d77d187 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -120,6 +120,12 @@ Crop and Pad :members: :special-members: __call__ +`RandCropByLabelClasses` +"""""""""""""""""""""""" +.. autoclass:: RandCropByLabelClasses + :members: + :special-members: __call__ + `ResizeWithPadOrCrop` """"""""""""""""""""" .. autoclass:: ResizeWithPadOrCrop @@ -604,6 +610,12 @@ Utility :members: :special-members: __call__ +`ClassesToIndices` +"""""""""""""""""" +.. autoclass:: ClassesToIndices + :members: + :special-members: __call__ + `ConvertToMultiChannelBasedOnBratsClasses` """""""""""""""""""""""""""""""""""""""""" .. autoclass:: ConvertToMultiChannelBasedOnBratsClasses @@ -700,6 +712,12 @@ Crop and Pad (Dict) :members: :special-members: __call__ +`RandCropByLabelClassesd` +""""""""""""""""""""""""" +.. autoclass:: RandCropByLabelClassesd + :members: + :special-members: __call__ + `ResizeWithPadOrCropd` """""""""""""""""""""" .. autoclass:: ResizeWithPadOrCropd @@ -1183,6 +1201,12 @@ Utility (Dict) :members: :special-members: __call__ +`ClassesToIndicesd` +""""""""""""""""""" +.. autoclass:: ClassesToIndicesd + :members: + :special-members: __call__ + `ConvertToMultiChannelBasedOnBratsClassesd` """"""""""""""""""""""""""""""""""""""""""" .. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index dbcb3daa72..21cfce2b82 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -18,6 +18,7 @@ CenterSpatialCrop, CropForeground, DivisiblePad, + RandCropByLabelClasses, RandCropByPosNegLabel, RandScaleCrop, RandSpatialCrop, @@ -48,6 +49,9 @@ DivisiblePadD, DivisiblePadDict, NumpyPadModeSequence, + RandCropByLabelClassesd, + RandCropByLabelClassesD, + RandCropByLabelClassesDict, RandCropByPosNegLabeld, RandCropByPosNegLabelD, RandCropByPosNegLabelDict, @@ -305,6 +309,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, DataStats, EnsureChannelFirst, @@ -342,6 +347,9 @@ CastToTyped, CastToTypeD, CastToTypeDict, + ClassesToIndicesd, + ClassesToIndicesD, + ClassesToIndicesDict, ConcatItemsd, ConcatItemsD, ConcatItemsDict, @@ -435,6 +443,7 @@ create_shear, create_translate, extreme_points_to_image, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, get_extreme_points, @@ -444,6 +453,7 @@ is_empty, is_positive, map_binary_to_indices, + map_classes_to_indices, map_spatial_axes, rand_choice, rescale_array, diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a37e9b9791..0b08f5099a 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -25,10 +25,12 @@ from monai.transforms.transform import Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, is_positive, map_binary_to_indices, + map_classes_to_indices, weighted_patch_samples, ) from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple @@ -46,6 +48,7 @@ "CropForeground", "RandWeightedCrop", "RandCropByPosNegLabel", + "RandCropByLabelClasses", "ResizeWithPadOrCrop", "BoundingRect", ] @@ -766,7 +769,11 @@ def randomize( ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) + if self.fg_indices is not None and self.bg_indices is not None: + fg_indices_ = self.fg_indices + bg_indices_ = self.bg_indices + else: + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) else: fg_indices_ = fg_indices bg_indices_ = bg_indices @@ -802,12 +809,7 @@ def __call__( raise ValueError("label should be provided.") if image is None: image = self.image - if fg_indices is None or bg_indices is None: - if self.fg_indices is not None and self.bg_indices is not None: - fg_indices = self.fg_indices - bg_indices = self.bg_indices - else: - fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) + self.randomize(label, fg_indices, bg_indices, image) results: List[np.ndarray] = [] if self.centers is not None: @@ -818,6 +820,139 @@ def __call__( return results +class RandCropByLabelClasses(Randomizable, Transform): + """ + Crop random fixed sized regions with the center being a class based on the specified ratios of every class. + The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the + cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`:: + + image = np.array([ + [[0.0, 0.3, 0.4, 0.2, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.4], + [0.0, 0.3, 0.5, 0.2, 0.0], + [0.1, 0.2, 0.1, 0.1, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.0]] + ]) + label = np.array([ + [[0, 0, 0, 0, 0], + [0, 1, 2, 1, 0], + [0, 1, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ]) + cropper = RandCropByLabelClasses( + spatial_size=[3, 3], + ratios=[1, 2, 3, 1], + num_classes=4, + num_samples=2, + ) + label_samples = cropper(img=label, label=label, image=image) + + The 2 randomly cropped samples of `label` can be: + [[0, 1, 2], [[0, 0, 0], + [0, 1, 3], [1, 2, 1], + [0, 0, 0]] [1, 3, 0]] + + If a dimension of the expected spatial size is bigger than the input image size, + will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped + results of several images may not have exactly same shape. + + Args: + spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. + if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. + if its components have non-positive values, the corresponding size of `label` will be used. + for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`, + the spatial size of output data will be [32, 40, 40]. + ratios: specified ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + label: the label image that is used for finding every classes, if None, must set at `self.__call__`. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + num_samples: number of samples (crop regions) to take in each list. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + indices: if provided pre-computed indices of every class, will ignore above `image` and + `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array + of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first + and cache the results for better performance. + + """ + + def __init__( + self, + spatial_size: Union[Sequence[int], int], + ratios: Optional[List[Union[float, int]]] = None, + label: Optional[np.ndarray] = None, + num_classes: Optional[int] = None, + num_samples: int = 1, + image: Optional[np.ndarray] = None, + image_threshold: float = 0.0, + indices: Optional[List[np.ndarray]] = None, + ) -> None: + self.spatial_size = ensure_tuple(spatial_size) + self.ratios = ratios + self.label = label + self.num_classes = num_classes + self.num_samples = num_samples + self.image = image + self.image_threshold = image_threshold + self.centers: Optional[List[List[np.ndarray]]] = None + self.indices = indices + + def randomize( + self, + label: np.ndarray, + indices: Optional[List[np.ndarray]] = None, + image: Optional[np.ndarray] = None, + ) -> None: + self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) + indices_: List[np.ndarray] + if indices is None: + if self.indices is not None: + indices_ = self.indices + else: + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + else: + indices_ = indices + self.centers = generate_label_classes_crop_centers( + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + ) + + def __call__( + self, + img: np.ndarray, + label: Optional[np.ndarray] = None, + image: Optional[np.ndarray] = None, + indices: Optional[List[np.ndarray]] = None, + ) -> List[np.ndarray]: + """ + Args: + img: input data to crop samples from based on the ratios of every class, assumes `img` is a + channel-first array. + label: the label image that is used for finding indices of every class, if None, use `self.label`. + image: optional image data to help select valid area, can be same as `img` or another image array. + use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`. + indices: list of indices for every class in the image, used to randomly select crop centers. + + """ + if label is None: + label = self.label + if label is None: + raise ValueError("label should be provided.") + if image is None: + image = self.image + + self.randomize(label, indices, image) + results: List[np.ndarray] = [] + if self.centers is not None: + for center in self.centers: + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + results.append(cropper(img)) + + return results + + class ResizeWithPadOrCrop(Transform): """ Resize an image to a target spatial size by either centrally cropping the image or diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index d7f40233e5..0717ea2cc9 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -40,9 +40,11 @@ from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( allow_missing_keys_mode, + generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, is_positive, map_binary_to_indices, + map_classes_to_indices, weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key @@ -1091,9 +1093,188 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n self.randomize(label, fg_indices, bg_indices, image) if not isinstance(self.spatial_size, tuple): - raise AssertionError + raise ValueError("spatial_size must be a valid tuple.") if self.centers is None: - raise AssertionError + raise ValueError("no available ROI centers to crop.") + + # initialize returned list with shallow copy to preserve key ordering + results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] + + for i, center in enumerate(self.centers): + # fill in the extra keys with unmodified data + for key in set(data.keys()).difference(set(self.keys)): + results[i][key] = deepcopy(data[key]) + for key in self.key_iterator(d): + img = d[key] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + orig_size = img.shape[1:] + results[i][key] = cropper(img) + self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size) + # add `patch_index` to the meta data + for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): + meta_key = meta_key or f"{key}_{meta_key_postfix}" + if meta_key not in results[i]: + results[i][meta_key] = {} # type: ignore + results[i][meta_key][Key.PATCH_INDEX] = i + + return results + + def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + current_size = np.asarray(d[key].shape[1:]) + center = transform[InverseKeys.EXTRA_INFO]["center"] + cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore + # get required pad to start and end + pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) + pad_to_end = orig_size - current_size - pad_to_start + # interleave mins and maxes + pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) + inverse_transform = BorderPad(pad) + # Apply inverse transform + d[key] = inverse_transform(d[key]) + # Remove the applied transform + self.pop_transform(d, key) + + return d + + +class RandCropByLabelClassesd(Randomizable, MapTransform, InvertibleTransform): + """ + Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`. + Crop random fixed sized regions with the center being a class based on the specified ratios of every class. + The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the + cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`:: + + cropper = RandCropByLabelClassesd( + keys=["image", "label"], + label_key="label", + spatial_size=[3, 3], + ratios=[1, 2, 3, 1], + num_classes=4, + num_samples=2, + ) + data = { + "image": np.array([ + [[0.0, 0.3, 0.4, 0.2, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.4], + [0.0, 0.3, 0.5, 0.2, 0.0], + [0.1, 0.2, 0.1, 0.1, 0.0], + [0.0, 0.1, 0.2, 0.1, 0.0]] + ]), + "label": np.array([ + [[0, 0, 0, 0, 0], + [0, 1, 2, 1, 0], + [0, 1, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]] + ]), + } + result = cropper(data) + + The 2 randomly cropped samples of `label` can be: + [[0, 1, 2], [[0, 0, 0], + [0, 1, 3], [1, 2, 1], + [0, 0, 0]] [1, 3, 0]] + + If a dimension of the expected spatial size is bigger than the input image size, + will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped + results of several images may not have exactly same shape. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + label_key: name of key for label image, this will be used for finding indices of every class. + spatial_size: the spatial size of the crop region e.g. [224, 224, 128]. + if a dimension of ROI size is bigger than image size, will not crop that dimension of the image. + if its components have non-positive values, the corresponding size of `label` will be used. + for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`, + the spatial size of output data will be [32, 40, 40]. + ratios: specified ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + num_samples: number of samples (crop regions) to take in each list. + image_key: if image_key is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image_key`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + indices_key: if provided pre-computed indices of every class, will ignore above `image` and + `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array + of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first + and cache the results for better performance. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + used to add `patch_index` to the meta dict. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys is None, use `key_{postfix}` to to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + used to add `patch_index` to the meta dict. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + label_key: str, + spatial_size: Union[Sequence[int], int], + ratios: Optional[List[Union[float, int]]] = None, + num_classes: Optional[int] = None, + num_samples: int = 1, + image_key: Optional[str] = None, + image_threshold: float = 0.0, + indices_key: Optional[str] = None, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = "meta_dict", + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + self.label_key = label_key + self.spatial_size: Union[Tuple[int, ...], Sequence[int], int] = spatial_size + self.ratios = ratios + self.num_classes = num_classes + self.num_samples = num_samples + self.image_key = image_key + self.image_threshold = image_threshold + self.indices_key = indices_key + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.centers: Optional[List[List[np.ndarray]]] = None + + def randomize( + self, + label: np.ndarray, + indices: Optional[List[np.ndarray]] = None, + image: Optional[np.ndarray] = None, + ) -> None: + self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) + indices_: List[np.ndarray] + if indices is None: + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + else: + indices_ = indices + self.centers = generate_label_classes_crop_centers( + self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R + ) + + def __call__(self, data: Mapping[Hashable, Any]) -> List[Dict[Hashable, np.ndarray]]: + d = dict(data) + label = d[self.label_key] + image = d[self.image_key] if self.image_key else None + indices = d.get(self.indices_key) if self.indices_key is not None else None + + self.randomize(label, indices, image) + if not isinstance(self.spatial_size, tuple): + raise ValueError("spatial_size must be a valid tuple.") + if self.centers is None: + raise ValueError("no available ROI centers to crop.") # initialize returned list with shallow copy to preserve key ordering results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)] @@ -1270,5 +1451,6 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda CropForegroundD = CropForegroundDict = CropForegroundd RandWeightedCropD = RandWeightedCropDict = RandWeightedCropd RandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld +RandCropByLabelClassesD = RandCropByLabelClassesDict = RandCropByLabelClassesd ResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd BoundingRectD = BoundingRectDict = BoundingRectd diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1d82bb3a44..4a73153859 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -30,6 +30,7 @@ extreme_points_to_image, get_extreme_points, map_binary_to_indices, + map_classes_to_indices, ) from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import @@ -59,6 +60,7 @@ "Lambda", "LabelToMask", "FgBgToIndices", + "ClassesToIndices", "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", @@ -708,6 +710,54 @@ def __call__( return fg_indices, bg_indices +class ClassesToIndices(Transform): + def __init__( + self, + num_classes: Optional[int] = None, + image_threshold: float = 0.0, + output_shape: Optional[Sequence[int]] = None, + ) -> None: + """ + Compute indices of every class of the input label data, return a list of indices. + If no output_shape specified, output data will be 1 dim indices after flattening. + This transform can help pre-compute indices of the class regions for other transforms. + A typical usage is to randomly select indices of classes to crop. + The main logic is based on :py:class:`monai.transforms.utils.map_classes_to_indices`. + + Args: + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to + determine the valid image content area and select only the indices of classes in this area. + output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + + """ + self.num_classes = num_classes + self.image_threshold = image_threshold + self.output_shape = output_shape + + def __call__( + self, + label: np.ndarray, + image: Optional[np.ndarray] = None, + output_shape: Optional[Sequence[int]] = None, + ) -> List[np.ndarray]: + """ + Args: + label: input data to compute the indices of every class. + image: if image is not None, use ``image > image_threshold`` to define valid region, and only select + the indices within the valid region. + output_shape: expected shape of output indices. if None, use `self.output_shape` instead. + + """ + if output_shape is None: + output_shape = self.output_shape + indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + if output_shape is not None: + indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] + + return indices + + class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ Convert labels to multi channels based on brats18 classes: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 13cef89789..6fa672e6c4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -31,6 +31,7 @@ AsChannelFirst, AsChannelLast, CastToType, + ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, DataStats, EnsureChannelFirst, @@ -977,6 +978,49 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda return d +class ClassesToIndicesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ClassesToIndices`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + indices_postfix: postfix to save the computed indices of all classes in dict. + for example, if computed on `label` and `postfix = "_cls_indices"`, the key will be `label_cls_indices`. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image_key: if image_key is not None, use ``image > image_threshold`` to define valid region, and only select + the indices within the valid region. + image_threshold: if enabled image_key, use ``image > image_threshold`` to determine the valid image content + area and select only the indices of classes in this area. + output_shape: expected shape of output indices. if not None, unravel indices to specified shape. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + keys: KeysCollection, + indices_postfix: str = "_cls_indices", + num_classes: Optional[int] = None, + image_key: Optional[str] = None, + image_threshold: float = 0.0, + output_shape: Optional[Sequence[int]] = None, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.indices_postfix = indices_postfix + self.image_key = image_key + self.converter = ClassesToIndices(num_classes, image_threshold, output_shape) + + def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, np.ndarray]: + d = dict(data) + image = d[self.image_key] if self.image_key else None + for key in self.key_iterator(d): + d[str(key) + self.indices_postfix] = self.converter(d[key], image) + + return d + + class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`. @@ -1203,6 +1247,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda LambdaD = LambdaDict = Lambdad LabelToMaskD = LabelToMaskDict = LabelToMaskd FgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd +ClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd ConvertToMultiChannelBasedOnBratsClassesD = ( ConvertToMultiChannelBasedOnBratsClassesDict ) = ConvertToMultiChannelBasedOnBratsClassesd diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 94c3f4a238..9506619b29 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -54,8 +54,10 @@ "compute_divisible_spatial_size", "resize_center", "map_binary_to_indices", + "map_classes_to_indices", "weighted_patch_samples", "generate_pos_neg_label_crop_centers", + "generate_label_classes_crop_centers", "create_grid", "create_control_grid", "create_rotate", @@ -270,9 +272,57 @@ def map_binary_to_indices( bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] else: bg_indices = np.nonzero(~label_flat)[0] + return fg_indices, bg_indices +def map_classes_to_indices( + label: np.ndarray, + num_classes: Optional[int] = None, + image: Optional[np.ndarray] = None, + image_threshold: float = 0.0, +) -> List[np.ndarray]: + """ + Filter out indices of every class of the input label data, return the indices after fattening. + It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for + Argmax label. + + For example: + ``label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])`` and `num_classes=3`, will return a list + which contains the indices of the 3 classes: + ``[np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])]`` + + Args: + label: use the label data to get the indices of every class. + num_classes: number of classes for argmax label, not necessary for One-Hot label. + image: if image is not None, only return the indices of every class that are within the valid + region of the image (``image > image_threshold``). + image_threshold: if enabled `image`, use ``image > image_threshold`` to + determine the valid image content area and select class indices only in this area. + + """ + img_flat: Optional[np.ndarray] = None + if image is not None: + img_flat = np.any(image > image_threshold, axis=0).ravel() + + indices: List[np.ndarray] = [] + # assuming the first dimension is channel + channels = len(label) + + num_classes_: int = channels + if channels == 1: + if num_classes is None: + raise ValueError("if not One-Hot format label, must provide the num_classes.") + num_classes_ = num_classes + + for c in range(num_classes_): + label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() + label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat + indices.append(np.nonzero(label_flat)[0]) + + return indices + + def weighted_patch_samples( spatial_size: Union[int, Sequence[int]], w: np.ndarray, @@ -317,6 +367,44 @@ def weighted_patch_samples( return [np.unravel_index(i, v_size) + diff for i in np.asarray(idx, dtype=int)] +def correct_crop_centers( + centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int] +) -> List[np.ndarray]: + """ + Utility to correct the crop center if the crop size is bigger than the image size. + + Args: + ceters: pre-computed crop centers, will correct based on the valid region. + spatial_size: spatial size of the ROIs to be sampled. + label_spatial_shape: spatial shape of the original label data to compare with ROI. + + """ + spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) + if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all(): + raise ValueError("The size of the proposed random crop ROI is larger than the image size.") + + # Select subregion to assure valid roi + valid_start = np.floor_divide(spatial_size, 2) + # add 1 for random + valid_end = np.subtract(label_spatial_shape + np.array(1), spatial_size / np.array(2)).astype(np.uint16) + # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range + # from being too high + for i, valid_s in enumerate(valid_start): + # need this because np.random.randint does not work with same start and end + if valid_s == valid_end[i]: + valid_end[i] += 1 + + for i, c in enumerate(centers): + center_i = c + if c < valid_start[i]: + center_i = valid_start[i] + if c >= valid_end[i]: + center_i = valid_end[i] - 1 + centers[i] = center_i + + return centers + + def generate_pos_neg_label_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, @@ -346,33 +434,6 @@ def generate_pos_neg_label_crop_centers( """ if rand_state is None: rand_state = np.random.random.__self__ # type: ignore - spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape) - if not (np.subtract(label_spatial_shape, spatial_size) >= 0).all(): - raise ValueError("The size of the proposed random crop ROI is larger than the image size.") - - # Select subregion to assure valid roi - valid_start = np.floor_divide(spatial_size, 2) - # add 1 for random - valid_end = np.subtract(label_spatial_shape + np.array(1), spatial_size / np.array(2)).astype(np.uint16) - # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range - # from being too high - for i, valid_s in enumerate( - valid_start - ): # need this because np.random.randint does not work with same start and end - if valid_s == valid_end[i]: - valid_end[i] += 1 - - def _correct_centers( - center_ori: List[np.ndarray], valid_start: np.ndarray, valid_end: np.ndarray - ) -> List[np.ndarray]: - for i, c in enumerate(center_ori): - center_i = c - if c < valid_start[i]: - center_i = valid_start[i] - if c >= valid_end[i]: - center_i = valid_end[i] - 1 - center_ori[i] = center_i - return center_ori centers = [] fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) @@ -392,7 +453,61 @@ def _correct_centers( center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) # shift center to range of valid centers center_ori = list(center) - centers.append(_correct_centers(center_ori, valid_start, valid_end)) + centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) + + return centers + + +def generate_label_classes_crop_centers( + spatial_size: Union[Sequence[int], int], + num_samples: int, + label_spatial_shape: Sequence[int], + indices: List[np.ndarray], + ratios: Optional[List[Union[float, int]]] = None, + rand_state: Optional[np.random.RandomState] = None, +) -> List[List[np.ndarray]]: + """ + Generate valid sample locations based on the specified ratios of label classes. + Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] + + Args: + spatial_size: spatial size of the ROIs to be sampled. + num_samples: total sample centers to be generated. + label_spatial_shape: spatial shape of the original label data to unravel selected centers. + indices: sequence of pre-computed foreground indices of every class in 1 dimension. + ratios: ratios of every class in the label to generate crop centers, including background class. + if None, every class will have the same ratio to generate crop centers. + rand_state: numpy randomState object to align with other modules. + + """ + if rand_state is None: + rand_state = np.random.random.__self__ # type: ignore + + if num_samples < 1: + raise ValueError("num_samples must be an int number and greater than 0.") + ratios_: List[Union[float, int]] = ([1] * len(indices)) if ratios is None else ratios + if len(ratios_) != len(indices): + raise ValueError("random crop radios must match the number of indices of classes.") + if any([i < 0 for i in ratios_]): + raise ValueError("ratios should not contain negative number.") + + # ensure indices are numpy array + indices = [np.asarray(i) for i in indices] + for i, array in enumerate(indices): + if len(array) == 0: + warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") + ratios_[i] = 0 + + centers = [] + classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_)) + for i in classes: + # randomly select the indices of a class based on the ratios + indices_to_use = indices[i] + random_int = rand_state.randint(len(indices_to_use)) + center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + # shift center to range of valid centers + center_ori = list(center) + centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) return centers diff --git a/tests/test_classes_to_indices.py b/tests/test_classes_to_indices.py new file mode 100644 index 0000000000..0ba3dd094a --- /dev/null +++ b/tests/test_classes_to_indices.py @@ -0,0 +1,79 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndices + +TEST_CASE_1 = [ + # test Argmax data + {"num_classes": 3, "image_threshold": 0.0}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + {"num_classes": 3, "image_threshold": 60}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + {"image_threshold": 0.0}, + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + None, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + {"num_classes": None, "image_threshold": 60}, + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test output_shape + {"num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + None, + [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], +] + + +class TestClassesToIndices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_value(self, input_args, label, image, expected_indices): + indices = ClassesToIndices(**input_args)(label, image) + for i, e in zip(indices, expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_classes_to_indicesd.py b/tests/test_classes_to_indicesd.py new file mode 100644 index 0000000000..67fac95c8c --- /dev/null +++ b/tests/test_classes_to_indicesd.py @@ -0,0 +1,84 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndicesd + +TEST_CASE_1 = [ + # test Argmax data + {"keys": "label", "num_classes": 3, "image_threshold": 0.0}, + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + {"keys": "label", "image_key": "image", "num_classes": 3, "image_threshold": 60}, + { + "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + {"keys": "label", "image_threshold": 0.0}, + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + {"keys": "label", "image_key": "image", "num_classes": None, "image_threshold": 60}, + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test output_shape + {"keys": "label", "indices_postfix": "cls", "num_classes": 3, "image_threshold": 0.0, "output_shape": [3, 3]}, + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])}, + [np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 1], [1, 2], [2, 0]]), np.array([[0, 2], [1, 0], [2, 1]])], +] + + +class TestClassesToIndicesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_value(self, input_args, input_data, expected_indices): + result = ClassesToIndicesd(**input_args)(input_data) + key_postfix = input_args.get("indices_postfix") + key_postfix = "_cls_indices" if key_postfix is None else key_postfix + for i, e in zip(result["label" + key_postfix], expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py new file mode 100644 index 0000000000..38f2a3e0d1 --- /dev/null +++ b/tests/test_generate_label_classes_crop_centers.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import generate_label_classes_crop_centers + +TEST_CASE_1 = [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "ratios": [1, 2], + "label_spatial_shape": [3, 3, 3], + "indices": [[3, 12, 21], [1, 9, 18]], + "rand_state": np.random.RandomState(), + }, + list, + 2, + 3, +] + +TEST_CASE_2 = [ + { + "spatial_size": [2, 2, 2], + "num_samples": 1, + "ratios": None, + "label_spatial_shape": [3, 3, 3], + "indices": [[3, 12, 21], [1, 9, 18]], + "rand_state": np.random.RandomState(), + }, + list, + 1, + 3, +] + + +class TestGenerateLabelClassesCropCenters(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): + result = generate_label_classes_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0f1d94487d..31ef971078 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -39,6 +39,7 @@ Orientationd, RandAffined, RandAxisFlipd, + RandCropByLabelClassesd, RandCropByPosNegLabeld, RandFlipd, Randomizable, @@ -445,6 +446,15 @@ ) ) +TESTS.append( + ( + "RandCropByLabelClassesd 2d", + "2D", + 1e-7, + RandCropByLabelClassesd(KEYS, "label", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10), + ) +) + TESTS.append( ( "RandCropByPosNegLabeld 2d", @@ -478,6 +488,7 @@ NUM_SAMPLES = 5 N_SAMPLES_TESTS = [ + [RandCropByLabelClassesd(KEYS, "label", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)], [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)], [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)], [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)], diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py new file mode 100644 index 0000000000..2320954520 --- /dev/null +++ b/tests/test_map_classes_to_indices.py @@ -0,0 +1,101 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import map_classes_to_indices + +TEST_CASE_1 = [ + # test Argmax data + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_2 = [ + { + "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), + "num_classes": 3, + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_3 = [ + # test One-Hot data + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], +] + +TEST_CASE_4 = [ + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ), + "num_classes": None, + "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), + "image_threshold": 60, + }, + [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], +] + +TEST_CASE_5 = [ + # test empty class + {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], +] + +TEST_CASE_6 = [ + # test empty class + { + "label": np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ), + "image": None, + "image_threshold": 0.0, + }, + [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], +] + + +class TestMapClassesToIndices(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + def test_value(self, input_data, expected_indices): + indices = map_classes_to_indices(**input_data) + for i, e in zip(indices, expected_indices): + np.testing.assert_allclose(i, e) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classes.py b/tests/test_rand_crop_by_label_classes.py new file mode 100644 index 0000000000..b21f971042 --- /dev/null +++ b/tests/test_rand_crop_by_label_classes.py @@ -0,0 +1,93 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndices, RandCropByLabelClasses + +TEST_CASE_0 = [ + # One-Hot label + { + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + list, + (3, 2, 2, 3), +] + +TEST_CASE_1 = [ + # Argmax label + { + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + {"img": np.random.randint(0, 2, size=[3, 3, 3, 3])}, + list, + (3, 2, 2, 2), +] + +TEST_CASE_2 = [ + # provide label at runtime + { + "label": None, + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + list, + (3, 2, 2, 2), +] + + +class TestRandCropByLabelClasses(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]) + def test_type_shape(self, input_param, input_data, expected_type, expected_shape): + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + def test_indices(self, input_param, input_data, expected_type, expected_shape): + input_param["indices"] = ClassesToIndices(num_classes=input_param["num_classes"])(input_param["label"]) + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + # test set indices at runtime + input_data["indices"] = input_param["indices"] + result = RandCropByLabelClasses(**input_param)(**input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_crop_by_label_classesd.py b/tests/test_rand_crop_by_label_classesd.py new file mode 100644 index 0000000000..829096953b --- /dev/null +++ b/tests/test_rand_crop_by_label_classesd.py @@ -0,0 +1,77 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd + +TEST_CASE_0 = [ + # One-Hot label + { + "keys": "img", + "label_key": "label", + "num_classes": None, + "spatial_size": [2, 2, -1], + "ratios": [1, 1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[3, 3, 3, 3]), + }, + list, + (3, 2, 2, 3), +] + +TEST_CASE_1 = [ + # Argmax label + { + "keys": "img", + "label_key": "label", + "num_classes": 2, + "spatial_size": [2, 2, 2], + "ratios": [1, 1], + "num_samples": 2, + "image_key": "image", + "image_threshold": 0, + }, + { + "img": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "image": np.random.randint(0, 2, size=[3, 3, 3, 3]), + "label": np.random.randint(0, 2, size=[1, 3, 3, 3]), + }, + list, + (3, 2, 2, 2), +] + + +class TestRandCropByLabelClassesd(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) + def test_type_shape(self, input_param, input_data, expected_type, expected_shape): + result = RandCropByLabelClassesd(**input_param)(input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0]["img"].shape, expected_shape) + # test with pre-computed indices + input_data = ClassesToIndicesd(keys="label", num_classes=input_param["num_classes"])(input_data) + input_param["indices_key"] = "label_cls_indices" + result = RandCropByLabelClassesd(**input_param)(input_data) + self.assertIsInstance(result, expected_type) + self.assertTupleEqual(result[0]["img"].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main()