diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index ebb248387f..e659c7ebc0 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1353,7 +1353,7 @@ class RandHistogramShift(RandomizableTransform): prob: probability of histogram shift. """ - backend = [TransformBackends.NUMPY] + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) @@ -1368,8 +1368,25 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f if min(num_control_points) <= 2: raise ValueError("num_control_points should be greater than or equal to 3") self.num_control_points = (min(num_control_points), max(num_control_points)) - self.reference_control_points: np.ndarray - self.floating_control_points: np.ndarray + self.reference_control_points: NdarrayOrTensor + self.floating_control_points: NdarrayOrTensor + + def interp(self, x: NdarrayOrTensor, xp: NdarrayOrTensor, fp: NdarrayOrTensor) -> NdarrayOrTensor: + ns = torch if isinstance(x, torch.Tensor) else np + if isinstance(x, np.ndarray): + # approx 2x faster than code below for ndarray + return np.interp(x, xp, fp) + + m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) + b = fp[:-1] - (m * xp[:-1]) + + indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1 + indices = ns.clip(indices, 0, len(m) - 1) + + f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape) + f[x < xp[0]] = fp[0] # type: ignore + f[x > xp[-1]] = fp[-1] # type: ignore + return f def randomize(self, data: Optional[Any] = None) -> None: super().randomize(None) @@ -1392,14 +1409,13 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.reference_control_points is None or self.floating_control_points is None: raise RuntimeError("please call the `randomize()` function first.") - img_np, *_ = convert_data_type(img, np.ndarray) - img_min, img_max = img_np.min(), img_np.max() - reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min - floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min - img_np = np.asarray( # type: ignore - np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype - ) - img, *_ = convert_to_dst_type(img_np, dst=img) + + xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img) + yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img) + img_min, img_max = img.min(), img.max() + reference_control_points_scaled = xp * (img_max - img_min) + img_min + floating_control_points_scaled = yp * (img_max - img_min) + img_min + img = self.interp(img, reference_control_points_scaled, floating_control_points_scaled) return img diff --git a/tests/test_rand_histogram_shift.py b/tests/test_rand_histogram_shift.py index c66f7859c6..0682306bb6 100644 --- a/tests/test_rand_histogram_shift.py +++ b/tests/test_rand_histogram_shift.py @@ -12,6 +12,7 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandHistogramShift @@ -50,6 +51,24 @@ def test_rand_histogram_shift(self, input_param, input_data, expected_val): result = g(**input_data) assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False) + def test_interp(self): + tr = RandHistogramShift() + for array_type in (torch.tensor, np.array): + x = array_type([0.0, 4.0, 6.0, 10.0]) + y = array_type([1.0, -1.0, 3.0, 5.0]) + + yi = tr.interp(array_type([0, 2, 4, 8, 10]), x, y) + assert yi.shape == (5,) + assert_allclose(yi, array_type([1.0, 0.0, -1.0, 4.0, 5.0])) + + yi = tr.interp(array_type([-1, 11, 10.001, -0.001]), x, y) + assert yi.shape == (4,) + assert_allclose(yi, array_type([1.0, 5.0, 5.0, 1.0])) + + yi = tr.interp(array_type([[-2, 11], [1, 3], [8, 10]]), x, y) + assert yi.shape == (3, 2) + assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]])) + if __name__ == "__main__": unittest.main()