Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ Crop and Pad
:members:
:special-members: __call__

`RandCropByLabelClasses`
""""""""""""""""""""""""
.. autoclass:: RandCropByLabelClasses
:members:
:special-members: __call__

`ResizeWithPadOrCrop`
"""""""""""""""""""""
.. autoclass:: ResizeWithPadOrCrop
Expand Down Expand Up @@ -604,6 +610,12 @@ Utility
:members:
:special-members: __call__

`ClassesToIndices`
""""""""""""""""""
.. autoclass:: ClassesToIndices
:members:
:special-members: __call__

`ConvertToMultiChannelBasedOnBratsClasses`
""""""""""""""""""""""""""""""""""""""""""
.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses
Expand Down Expand Up @@ -700,6 +712,12 @@ Crop and Pad (Dict)
:members:
:special-members: __call__

`RandCropByLabelClassesd`
"""""""""""""""""""""""""
.. autoclass:: RandCropByLabelClassesd
:members:
:special-members: __call__

`ResizeWithPadOrCropd`
""""""""""""""""""""""
.. autoclass:: ResizeWithPadOrCropd
Expand Down Expand Up @@ -1183,6 +1201,12 @@ Utility (Dict)
:members:
:special-members: __call__

`ClassesToIndicesd`
"""""""""""""""""""
.. autoclass:: ClassesToIndicesd
:members:
:special-members: __call__

`ConvertToMultiChannelBasedOnBratsClassesd`
"""""""""""""""""""""""""""""""""""""""""""
.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd
Expand Down
10 changes: 10 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CenterSpatialCrop,
CropForeground,
DivisiblePad,
RandCropByLabelClasses,
RandCropByPosNegLabel,
RandScaleCrop,
RandSpatialCrop,
Expand Down Expand Up @@ -48,6 +49,9 @@
DivisiblePadD,
DivisiblePadDict,
NumpyPadModeSequence,
RandCropByLabelClassesd,
RandCropByLabelClassesD,
RandCropByLabelClassesDict,
RandCropByPosNegLabeld,
RandCropByPosNegLabelD,
RandCropByPosNegLabelDict,
Expand Down Expand Up @@ -305,6 +309,7 @@
AsChannelFirst,
AsChannelLast,
CastToType,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
DataStats,
EnsureChannelFirst,
Expand Down Expand Up @@ -342,6 +347,9 @@
CastToTyped,
CastToTypeD,
CastToTypeDict,
ClassesToIndicesd,
ClassesToIndicesD,
ClassesToIndicesDict,
ConcatItemsd,
ConcatItemsD,
ConcatItemsDict,
Expand Down Expand Up @@ -435,6 +443,7 @@
create_shear,
create_translate,
extreme_points_to_image,
generate_label_classes_crop_centers,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
get_extreme_points,
Expand All @@ -444,6 +453,7 @@
is_empty,
is_positive,
map_binary_to_indices,
map_classes_to_indices,
map_spatial_axes,
rand_choice,
rescale_array,
Expand Down
149 changes: 142 additions & 7 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import (
compute_divisible_spatial_size,
generate_label_classes_crop_centers,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
is_positive,
map_binary_to_indices,
map_classes_to_indices,
weighted_patch_samples,
)
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
Expand All @@ -46,6 +48,7 @@
"CropForeground",
"RandWeightedCrop",
"RandCropByPosNegLabel",
"RandCropByLabelClasses",
"ResizeWithPadOrCrop",
"BoundingRect",
]
Expand Down Expand Up @@ -766,7 +769,11 @@ def randomize(
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
if fg_indices is None or bg_indices is None:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
if self.fg_indices is not None and self.bg_indices is not None:
fg_indices_ = self.fg_indices
bg_indices_ = self.bg_indices
else:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
Expand Down Expand Up @@ -802,12 +809,7 @@ def __call__(
raise ValueError("label should be provided.")
if image is None:
image = self.image
if fg_indices is None or bg_indices is None:
if self.fg_indices is not None and self.bg_indices is not None:
fg_indices = self.fg_indices
bg_indices = self.bg_indices
else:
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)

self.randomize(label, fg_indices, bg_indices, image)
results: List[np.ndarray] = []
if self.centers is not None:
Expand All @@ -818,6 +820,139 @@ def __call__(
return results


class RandCropByLabelClasses(Randomizable, Transform):
"""
Crop random fixed sized regions with the center being a class based on the specified ratios of every class.
The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the
cropped images. For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::

image = np.array([
[[0.0, 0.3, 0.4, 0.2, 0.0],
[0.0, 0.1, 0.2, 0.1, 0.4],
[0.0, 0.3, 0.5, 0.2, 0.0],
[0.1, 0.2, 0.1, 0.1, 0.0],
[0.0, 0.1, 0.2, 0.1, 0.0]]
])
label = np.array([
[[0, 0, 0, 0, 0],
[0, 1, 2, 1, 0],
[0, 1, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]
])
cropper = RandCropByLabelClasses(
spatial_size=[3, 3],
ratios=[1, 2, 3, 1],
num_classes=4,
num_samples=2,
)
label_samples = cropper(img=label, label=label, image=image)

The 2 randomly cropped samples of `label` can be:
[[0, 1, 2], [[0, 0, 0],
[0, 1, 3], [1, 2, 1],
[0, 0, 0]] [1, 3, 0]]

If a dimension of the expected spatial size is bigger than the input image size,
will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped
results of several images may not have exactly same shape.

Args:
spatial_size: the spatial size of the crop region e.g. [224, 224, 128].
if a dimension of ROI size is bigger than image size, will not crop that dimension of the image.
if its components have non-positive values, the corresponding size of `label` will be used.
for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,
the spatial size of output data will be [32, 40, 40].
ratios: specified ratios of every class in the label to generate crop centers, including background class.
if None, every class will have the same ratio to generate crop centers.
label: the label image that is used for finding every classes, if None, must set at `self.__call__`.
num_classes: number of classes for argmax label, not necessary for One-Hot label.
num_samples: number of samples (crop regions) to take in each list.
image: if image is not None, only return the indices of every class that are within the valid
region of the image (``image > image_threshold``).
image_threshold: if enabled `image`, use ``image > image_threshold`` to
determine the valid image content area and select class indices only in this area.
indices: if provided pre-computed indices of every class, will ignore above `image` and
`image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array
of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first
and cache the results for better performance.

"""

def __init__(
self,
spatial_size: Union[Sequence[int], int],
ratios: Optional[List[Union[float, int]]] = None,
label: Optional[np.ndarray] = None,
num_classes: Optional[int] = None,
num_samples: int = 1,
image: Optional[np.ndarray] = None,
image_threshold: float = 0.0,
indices: Optional[List[np.ndarray]] = None,
) -> None:
self.spatial_size = ensure_tuple(spatial_size)
self.ratios = ratios
self.label = label
self.num_classes = num_classes
self.num_samples = num_samples
self.image = image
self.image_threshold = image_threshold
self.centers: Optional[List[List[np.ndarray]]] = None
self.indices = indices

def randomize(
self,
label: np.ndarray,
indices: Optional[List[np.ndarray]] = None,
image: Optional[np.ndarray] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
if indices is None:
if self.indices is not None:
indices_ = self.indices
else:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

def __call__(
self,
img: np.ndarray,
label: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
indices: Optional[List[np.ndarray]] = None,
) -> List[np.ndarray]:
"""
Args:
img: input data to crop samples from based on the ratios of every class, assumes `img` is a
channel-first array.
label: the label image that is used for finding indices of every class, if None, use `self.label`.
image: optional image data to help select valid area, can be same as `img` or another image array.
use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`.
indices: list of indices for every class in the image, used to randomly select crop centers.

"""
if label is None:
label = self.label
if label is None:
raise ValueError("label should be provided.")
if image is None:
image = self.image

self.randomize(label, indices, image)
results: List[np.ndarray] = []
if self.centers is not None:
for center in self.centers:
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
results.append(cropper(img))

return results


class ResizeWithPadOrCrop(Transform):
"""
Resize an image to a target spatial size by either centrally cropping the image or
Expand Down
Loading