Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,7 @@ class RandHistogramShift(RandomizableTransform):
prob: probability of histogram shift.
"""

backend = [TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None:
RandomizableTransform.__init__(self, prob)
Expand All @@ -1368,8 +1368,25 @@ def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: f
if min(num_control_points) <= 2:
raise ValueError("num_control_points should be greater than or equal to 3")
self.num_control_points = (min(num_control_points), max(num_control_points))
self.reference_control_points: np.ndarray
self.floating_control_points: np.ndarray
self.reference_control_points: NdarrayOrTensor
self.floating_control_points: NdarrayOrTensor

def interp(self, x: NdarrayOrTensor, xp: NdarrayOrTensor, fp: NdarrayOrTensor) -> NdarrayOrTensor:
ns = torch if isinstance(x, torch.Tensor) else np
if isinstance(x, np.ndarray):
# approx 2x faster than code below for ndarray
return np.interp(x, xp, fp)

m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
b = fp[:-1] - (m * xp[:-1])

indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
indices = ns.clip(indices, 0, len(m) - 1)

f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)
f[x < xp[0]] = fp[0] # type: ignore
f[x > xp[-1]] = fp[-1] # type: ignore
return f

def randomize(self, data: Optional[Any] = None) -> None:
super().randomize(None)
Expand All @@ -1392,14 +1409,13 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen

if self.reference_control_points is None or self.floating_control_points is None:
raise RuntimeError("please call the `randomize()` function first.")
img_np, *_ = convert_data_type(img, np.ndarray)
img_min, img_max = img_np.min(), img_np.max()
reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min
floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min
img_np = np.asarray( # type: ignore
np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype
)
img, *_ = convert_to_dst_type(img_np, dst=img)

xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img)
yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img)
img_min, img_max = img.min(), img.max()
reference_control_points_scaled = xp * (img_max - img_min) + img_min
floating_control_points_scaled = yp * (img_max - img_min) + img_min
img = self.interp(img, reference_control_points_scaled, floating_control_points_scaled)
return img


Expand Down
19 changes: 19 additions & 0 deletions tests/test_rand_histogram_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import RandHistogramShift
Expand Down Expand Up @@ -50,6 +51,24 @@ def test_rand_histogram_shift(self, input_param, input_data, expected_val):
result = g(**input_data)
assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False)

def test_interp(self):
tr = RandHistogramShift()
for array_type in (torch.tensor, np.array):
x = array_type([0.0, 4.0, 6.0, 10.0])
y = array_type([1.0, -1.0, 3.0, 5.0])

yi = tr.interp(array_type([0, 2, 4, 8, 10]), x, y)
assert yi.shape == (5,)
assert_allclose(yi, array_type([1.0, 0.0, -1.0, 4.0, 5.0]))

yi = tr.interp(array_type([-1, 11, 10.001, -0.001]), x, y)
assert yi.shape == (4,)
assert_allclose(yi, array_type([1.0, 5.0, 5.0, 1.0]))

yi = tr.interp(array_type([[-2, 11], [1, 3], [8, 10]]), x, y)
assert yi.shape == (3, 2)
assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]]))


if __name__ == "__main__":
unittest.main()