diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index c904631bed..c2cf03d6fe 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -693,6 +693,8 @@ def threshold_at_one(x): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__( self, select_fn: Callable = is_positive, @@ -728,13 +730,15 @@ def __init__( self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode) self.np_kwargs = np_kwargs - def compute_bounding_box(self, img: np.ndarray): + def compute_bounding_box(self, img: NdarrayOrTensor) -> Tuple[np.ndarray, np.ndarray]: """ Compute the start points and end points of bounding box to crop. And adjust bounding box coords to be divisible by `k`. """ box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin) + box_start = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_start] # type: ignore + box_end = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_end] # type: ignore box_start_ = np.asarray(box_start, dtype=np.int16) box_end_ = np.asarray(box_end, dtype=np.int16) orig_spatial_size = box_end_ - box_start_ @@ -747,7 +751,7 @@ def compute_bounding_box(self, img: np.ndarray): def crop_pad( self, - img: np.ndarray, + img: NdarrayOrTensor, box_start: np.ndarray, box_end: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None, @@ -762,13 +766,11 @@ def crop_pad( pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist()))) return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.np_kwargs)(cropped) - def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None): + def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None): """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't change the channel dim. """ - img, *_ = convert_data_type(img, np.ndarray) # type: ignore - box_start, box_end = self.compute_bounding_box(img) cropped = self.crop_pad(img, box_start, box_end, mode) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 86a865c237..376c2c811f 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1249,6 +1249,8 @@ class RandHistogramShift(RandomizableTransform): prob: probability of histogram shift. """ + backend = [TransformBackends.NUMPY] + def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) @@ -1273,16 +1275,17 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, img: np.ndarray) -> np.ndarray: - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + def __call__(self, img: NdarrayOrTensor) -> np.ndarray: + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore self.randomize() if not self._do_transform: - return img - img_min, img_max = img.min(), img.max() + return img_np + img_min, img_max = img_np.min(), img_np.max() reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min return np.asarray( - np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype + np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype ) @@ -1902,12 +1905,14 @@ class HistogramNormalize(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, dtype: DtypeLike = np.float32, ) -> None: self.num_bins = num_bins @@ -1916,8 +1921,7 @@ def __init__( self.mask = mask self.dtype = dtype - def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray: - img, *_ = convert_data_type(img, np.ndarray) # type: ignore + def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> np.ndarray: return equalize_hist( img=img, mask=mask if mask is not None else self.mask, diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index dbbfe3cba0..8681093168 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -49,6 +49,7 @@ from monai.transforms.utils import is_positive from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size from monai.utils.deprecated import deprecated_arg +from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type __all__ = [ @@ -1108,6 +1109,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = [TransformBackends.NUMPY] + def __init__( self, keys: KeysCollection, @@ -1138,17 +1141,19 @@ def randomize(self, data: Optional[Any] = None) -> None: self.floating_control_points[i - 1], self.floating_control_points[i + 1] ) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) self.randomize() - if not self._do_transform: - return d for key in self.key_iterator(d): - img_min, img_max = d[key].min(), d[key].max() - reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min - floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - dtype = d[key].dtype - d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype) + d[key] = convert_data_type(d[key], np.ndarray)[0] + if self._do_transform: + img_min, img_max = d[key].min(), d[key].max() + reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min + floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min + dtype = d[key].dtype + d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype( + dtype + ) return d @@ -1594,13 +1599,15 @@ class HistogramNormalized(MapTransform): """ + backend = HistogramNormalize.backend + def __init__( self, keys: KeysCollection, num_bins: int = 256, min: int = 0, max: int = 255, - mask: Optional[np.ndarray] = None, + mask: Optional[NdarrayOrTensor] = None, mask_key: Optional[str] = None, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -1609,7 +1616,7 @@ def __init__( self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype) self.mask_key = mask_key if mask is None else None - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key]) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 67fdb70a8c..15543e91ef 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform, apply_transform -from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index +from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index, where from monai.utils import ( GridSampleMode, InterpolateMode, @@ -839,7 +839,7 @@ def _create_translate( def generate_spatial_bounding_box( - img: np.ndarray, + img: NdarrayOrTensor, select_fn: Callable = is_positive, channel_indices: Optional[IndexSelection] = None, margin: Union[Sequence[int], int] = 0, @@ -864,7 +864,7 @@ def generate_spatial_bounding_box( margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims. """ data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img - data = np.any(select_fn(data), axis=0) + data = select_fn(data).any(0) ndim = len(data.shape) margin = ensure_tuple_rep(margin, ndim) for m in margin: @@ -875,13 +875,18 @@ def generate_spatial_bounding_box( box_end = [0] * ndim for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)): - dt = data.any(axis=ax) - if not np.any(dt): + dt = data + if len(ax) != 0: + dt = any_np_pt(dt, ax) + + if not dt.any(): # if no foreground, return all zero bounding box coords return [0] * ndim, [0] * ndim - min_d = max(np.argmax(dt) - margin[di], 0) - max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1) + arg_max = where(dt == dt.max())[0] + min_d = max(arg_max[0] - margin[di], 0) + max_d = arg_max[-1] + margin[di] + 1 + box_start[di], box_end[di] = min_d, max_d return box_start, box_end @@ -1203,8 +1208,8 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen def equalize_hist( - img: np.ndarray, - mask: Optional[np.ndarray] = None, + img: NdarrayOrTensor, + mask: Optional[NdarrayOrTensor] = None, num_bins: int = 256, min: int = 0, max: int = 255, @@ -1226,8 +1231,14 @@ def equalize_hist( dtype: data type of the output, default to `float32`. """ - orig_shape = img.shape - hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + mask_np: Optional[np.ndarray] = None + if mask is not None: + mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore + + orig_shape = img_np.shape + hist_img = img_np[np.array(mask_np, dtype=bool)] if mask_np is not None else img_np if has_skimage: hist, bins = exposure.histogram(hist_img.flatten(), num_bins) else: @@ -1239,9 +1250,9 @@ def equalize_hist( cum = rescale_array(arr=cum, minv=min, maxv=max) # apply linear interpolation - img = np.interp(img.flatten(), bins, cum) + img_np = np.interp(img_np.flatten(), bins, cum) - return img.reshape(orig_shape).astype(dtype) + return img_np.reshape(orig_shape).astype(dtype) class Fourier: diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2c370a4707..52c5248f96 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Sequence, Union import numpy as np import torch @@ -114,17 +114,23 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]: return result -def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor: +def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor: """ Note that `torch.where` may convert y.dtype to x.dtype. """ result: NdarrayOrTensor if isinstance(condition, np.ndarray): - result = np.where(condition, x, y) + if x is not None: + result = np.where(condition, x, y) + else: + result = np.where(condition) else: - x = torch.as_tensor(x, device=condition.device) - y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) - result = torch.where(condition, x, y) + if x is not None: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + else: + result = torch.where(condition) # type: ignore return result @@ -211,7 +217,7 @@ def ravel(x: NdarrayOrTensor): return np.ravel(x) -def any_np_pt(x: NdarrayOrTensor, axis: int): +def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]): """`np.any` with equivalent implementation for torch. For pytorch, convert to boolean for compatibility with older versions. @@ -223,13 +229,18 @@ def any_np_pt(x: NdarrayOrTensor, axis: int): Returns: Return a contiguous flattened array/tensor. """ - if isinstance(x, torch.Tensor): + if isinstance(x, np.ndarray): + return np.any(x, axis) + + # pytorch can't handle multiple dimensions to `any` so loop across them + axis = [axis] if not isinstance(axis, Sequence) else axis + for ax in axis: try: - return torch.any(x, axis) + x = torch.any(x, ax) except RuntimeError: # older versions of pytorch require the input to be cast to boolean - return torch.any(x.bool(), axis) - return np.any(x, axis) + x = torch.any(x.bool(), ax) + return x def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor: diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py index 71e488cac8..0bae1f90f3 100644 --- a/tests/test_crop_foreground.py +++ b/tests/test_crop_foreground.py @@ -12,60 +12,79 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForeground +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_2 = [ - {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - np.array([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), -] - -TEST_CASE_7 = [ - {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10, "constant_values": 2}, - np.array([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - np.zeros((1, 0, 0)), -] +TEST_COORDS, TESTS = [], [] + +for p in TEST_NDARRAYS: + TEST_COORDS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[3]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 4}, + p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), + p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]), + ] + ) + + TESTS.append( + [ + {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": 0, "k_divisible": 10}, + p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), + p(np.zeros((1, 0, 0), dtype=np.int64)), + ] + ) class TestCropForeground(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand(TEST_COORDS + TESTS) def test_value(self, argments, image, expected_data): result = CropForeground(**argments)(image) - np.testing.assert_allclose(result, expected_data) + torch.testing.assert_allclose(result, expected_data, rtol=1e-7, atol=0) - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TEST_COORDS) def test_return_coords(self, argments, image, _): argments["return_coords"] = True _, start_coord, end_coord = CropForeground(**argments)(image) diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py index efe6b65b4b..5fa474d6ac 100644 --- a/tests/test_crop_foregroundd.py +++ b/tests/test_crop_foregroundd.py @@ -12,85 +12,128 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import CropForegroundd -from monai.utils import NumpyPadMode +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - { - "keys": ["img", "label"], - "source_key": "label", - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - "mode": "constant", - "constant_values": 2, - }, - { - "img": np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]), - "label": np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]), - }, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] +TEST_POSITION, TESTS = [], [] +for p in TEST_NDARRAYS: -TEST_CASE_2 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[3]]]), -] - -TEST_CASE_3 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]), -] - -TEST_CASE_4 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_5 = [ - {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]}, - {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])}, - np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), -] - -TEST_CASE_6 = [ - { - "keys": ["img", "seg"], - "source_key": "img", - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - "k_divisible": [4, 6], - "mode": ["edge", NumpyPadMode.CONSTANT], - }, - { - "img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - "seg": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]), - }, - np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]), -] + TEST_POSITION.append( + [ + { + "keys": ["img", "label"], + "source_key": "label", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + { + "img": p( + np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]) + ), + "label": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]) + ), + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 1, "channel_indices": None, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[3]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": 0, "margin": 0}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])), + ] + ) + TESTS.append( + [ + {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": 1}, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ) + }, + p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])), + ] + ) + TESTS.append( + [ + { + "keys": ["img"], + "source_key": "img", + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + "k_divisible": [4, 6], + "mode": "edge", + }, + { + "img": p( + np.array( + [[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]], + dtype=np.float32, + ) + ) + }, + p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]])), + ] + ) class TestCropForegroundd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_value(self, argments, image, expected_data): - result = CropForegroundd(**argments)(image) - np.testing.assert_allclose(result["img"], expected_data) + @parameterized.expand(TEST_POSITION + TESTS) + def test_value(self, argments, input_data, expected_data): + result = CropForegroundd(**argments)(input_data) + r, i = result["img"], input_data["img"] + self.assertEqual(type(r), type(i)) + if isinstance(r, torch.Tensor): + self.assertEqual(r.device, i.device) + assert_allclose(r, expected_data) - @parameterized.expand([TEST_CASE_1]) - def test_foreground_position(self, argments, image, _): - result = CropForegroundd(**argments)(image) + @parameterized.expand(TEST_POSITION) + def test_foreground_position(self, argments, input_data, _): + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["foreground_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["foreground_end_coord"], np.array([4, 4])) argments["start_coord_key"] = "test_start_coord" argments["end_coord_key"] = "test_end_coord" - result = CropForegroundd(**argments)(image) + result = CropForegroundd(**argments)(input_data) np.testing.assert_allclose(result["test_start_coord"], np.array([1, 1])) np.testing.assert_allclose(result["test_end_coord"], np.array([4, 4])) diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py index 32a45d8d1c..d73b9fafcc 100644 --- a/tests/test_generate_spatial_bounding_box.py +++ b/tests/test_generate_spatial_bounding_box.py @@ -15,60 +15,79 @@ from parameterized import parameterized from monai.transforms import generate_spatial_bounding_box +from tests.utils import TEST_NDARRAYS -TEST_CASE_1 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_2 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 1, - "channel_indices": None, - "margin": 0, - }, - ([2, 2], [3, 3]), -] - -TEST_CASE_3 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": 0, - "margin": 0, - }, - ([1, 1], [4, 4]), -] - -TEST_CASE_4 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": 1, - }, - ([0, 0], [4, 5]), -] - -TEST_CASE_5 = [ - { - "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]), - "select_fn": lambda x: x > 0, - "channel_indices": None, - "margin": [2, 1], - }, - ([0, 0], [5, 5]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 1, + "channel_indices": None, + "margin": 0, + }, + ([2, 2], [3, 3]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": 0, + "margin": 0, + }, + ([1, 1], [4, 4]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": 1, + }, + ([0, 0], [4, 5]), + ] + ) + TESTS.append( + [ + { + "img": p( + np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]) + ), + "select_fn": lambda x: x > 0, + "channel_indices": None, + "margin": [2, 1], + }, + ([0, 0], [5, 5]), + ] + ) class TestGenerateSpatialBoundingBox(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_box): result = generate_spatial_bounding_box(**input_data) self.assertTupleEqual(result, expected_box) diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index b69fb1d927..e0178166d9 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -15,28 +15,37 @@ from parameterized import parameterized from monai.transforms import HistogramNormalize - -TEST_CASE_1 = [ - {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"num_bins": 4, "max": 4, "dtype": np.uint8}, - np.array([0.0, 1.0, 2.0, 3.0, 4.0]), - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"num_bins": 256, "max": 255, "dtype": np.uint8}, - np.array([[[100.0, 200.0], [150.0, 250.0]]]), - np.array([[[0, 170], [70, 255]]]), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), + ] + ) + + TESTS.append( + [ + {"num_bins": 4, "max": 4, "dtype": np.uint8}, + p(np.array([0.0, 1.0, 2.0, 3.0, 4.0])), + np.array([0, 0, 1, 3, 4]), + ] + ) + + TESTS.append( + [ + {"num_bins": 256, "max": 255, "dtype": np.uint8}, + p(np.array([[[100.0, 200.0], [150.0, 250.0]]])), + np.array([[[0, 170], [70, 255]]]), + ] + ) class TestHistogramNormalize(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalize(**argments)(image) np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 68647e82fb..314c7bd75b 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -15,28 +15,37 @@ from parameterized import parameterized from monai.transforms import HistogramNormalized - -TEST_CASE_1 = [ - {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), "mask": np.array([1, 1, 1, 1, 1, 0])}, - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), -] - -TEST_CASE_2 = [ - {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, - {"img": np.array([0.0, 1.0, 2.0, 3.0, 4.0])}, - np.array([0, 0, 1, 3, 4]), -] - -TEST_CASE_3 = [ - {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, - {"img": np.array([[[100.0, 200.0], [150.0, 250.0]]])}, - np.array([[[0, 170], [70, 255]]]), -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), "mask": p(np.array([1, 1, 1, 1, 1, 0]))}, + np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, + {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0]))}, + np.array([0, 0, 1, 3, 4]), + ] + ) + + TESTS.append( + [ + {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, + {"img": p(np.array([[[100.0, 200.0], [150.0, 250.0]]]))}, + np.array([[[0, 170], [70, 255]]]), + ] + ) class TestHistogramNormalized(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalized(**argments)(image)["img"] np.testing.assert_allclose(result, expected_data) diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index b258cc5a7e..fa51dacefa 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -15,28 +15,35 @@ from parameterized import parameterized from monai.transforms import RandHistogramShift - -TEST_CASES = [ - [ - {"num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2))}, - np.arange(8).reshape((1, 2, 2, 2)), - ], - [ - {"num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), - ], - [ - {"num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)}, - np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), - ], -] +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)))}, + np.arange(8).reshape((1, 2, 2, 2)), + ] + ) + TESTS.append( + [ + {"num_control_points": 5, "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]), + ] + ) + TESTS.append( + [ + {"num_control_points": (5, 20), "prob": 0.9}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))}, + np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]), + ] + ) class TestRandHistogramShift(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shift(self, input_param, input_data, expected_val): g = RandHistogramShift(**input_param) g.set_random_state(123) diff --git a/tests/test_rand_histogram_shiftd.py b/tests/test_rand_histogram_shiftd.py index 806e4f5cf2..2191e99518 100644 --- a/tests/test_rand_histogram_shiftd.py +++ b/tests/test_rand_histogram_shiftd.py @@ -14,47 +14,60 @@ import numpy as np from parameterized import parameterized -from monai.transforms import RandHistogramShiftD - -TEST_CASES = [ - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - ], - [ - {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], - [ - {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, - {"img": np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32), "seg": np.ones(8).reshape((1, 2, 2, 2))}, - { - "img": np.array( - [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] - ), - "seg": np.ones(8).reshape((1, 2, 2, 2)), - }, - ], -] +from monai.transforms.intensity.dictionary import RandHistogramShiftd +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.0}, + {"img": p(np.arange(8).reshape((1, 2, 2, 2))), "seg": p(np.ones(8).reshape((1, 2, 2, 2)))}, + {"img": np.arange(8).reshape((1, 2, 2, 2)), "seg": np.ones(8).reshape((1, 2, 2, 2))}, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": 5, "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) + TESTS.append( + [ + {"keys": ("img",), "num_control_points": (5, 20), "prob": 0.9}, + { + "img": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)), + "seg": p(np.ones(8).reshape((1, 2, 2, 2))), + }, + { + "img": np.array( + [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]] + ), + "seg": np.ones(8).reshape((1, 2, 2, 2)), + }, + ] + ) class TestRandHistogramShiftD(unittest.TestCase): - @parameterized.expand(TEST_CASES) + @parameterized.expand(TESTS) def test_rand_histogram_shiftd(self, input_param, input_data, expected_val): - g = RandHistogramShiftD(**input_param) + g = RandHistogramShiftd(**input_param) g.set_random_state(123) res = g(input_data) for key in res: result = res[key] expected = expected_val[key] if isinstance(expected_val, dict) else expected_val - np.testing.assert_allclose(result, expected, rtol=1e-4, atol=1e-4) + assert_allclose(result, expected, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__":