Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,8 @@ def threshold_at_one(x):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
select_fn: Callable = is_positive,
Expand Down Expand Up @@ -728,13 +730,15 @@ def __init__(
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.np_kwargs = np_kwargs

def compute_bounding_box(self, img: np.ndarray):
def compute_bounding_box(self, img: NdarrayOrTensor) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute the start points and end points of bounding box to crop.
And adjust bounding box coords to be divisible by `k`.

"""
box_start, box_end = generate_spatial_bounding_box(img, self.select_fn, self.channel_indices, self.margin)
box_start = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_start] # type: ignore
box_end = [i.cpu() if isinstance(i, torch.Tensor) else i for i in box_end] # type: ignore
box_start_ = np.asarray(box_start, dtype=np.int16)
box_end_ = np.asarray(box_end, dtype=np.int16)
orig_spatial_size = box_end_ - box_start_
Expand All @@ -747,7 +751,7 @@ def compute_bounding_box(self, img: np.ndarray):

def crop_pad(
self,
img: np.ndarray,
img: NdarrayOrTensor,
box_start: np.ndarray,
box_end: np.ndarray,
mode: Optional[Union[NumpyPadMode, str]] = None,
Expand All @@ -762,13 +766,11 @@ def crop_pad(
pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.np_kwargs)(cropped)

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None):
def __call__(self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, str]] = None):
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't change the channel dim.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore

box_start, box_end = self.compute_bounding_box(img)
cropped = self.crop_pad(img, box_start, box_end, mode)

Expand Down
20 changes: 12 additions & 8 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,8 @@ class RandHistogramShift(RandomizableTransform):
prob: probability of histogram shift.
"""

backend = [TransformBackends.NUMPY]

def __init__(self, num_control_points: Union[Tuple[int, int], int] = 10, prob: float = 0.1) -> None:
RandomizableTransform.__init__(self, prob)

Expand All @@ -1273,16 +1275,17 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.floating_control_points[i - 1], self.floating_control_points[i + 1]
)

def __call__(self, img: np.ndarray) -> np.ndarray:
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
img_np: np.ndarray
img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore
self.randomize()
if not self._do_transform:
return img
img_min, img_max = img.min(), img.max()
return img_np
img_min, img_max = img_np.min(), img_np.max()
reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min
floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min
return np.asarray(
np.interp(img, reference_control_points_scaled, floating_control_points_scaled), dtype=img.dtype
np.interp(img_np, reference_control_points_scaled, floating_control_points_scaled), dtype=img_np.dtype
)


Expand Down Expand Up @@ -1902,12 +1905,14 @@ class HistogramNormalize(Transform):

"""

backend = [TransformBackends.NUMPY]

def __init__(
self,
num_bins: int = 256,
min: int = 0,
max: int = 255,
mask: Optional[np.ndarray] = None,
mask: Optional[NdarrayOrTensor] = None,
dtype: DtypeLike = np.float32,
) -> None:
self.num_bins = num_bins
Expand All @@ -1916,8 +1921,7 @@ def __init__(
self.mask = mask
self.dtype = dtype

def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarray:
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> np.ndarray:
return equalize_hist(
img=img,
mask=mask if mask is not None else self.mask,
Expand Down
27 changes: 17 additions & 10 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from monai.transforms.utils import is_positive
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size
from monai.utils.deprecated import deprecated_arg
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type

__all__ = [
Expand Down Expand Up @@ -1108,6 +1109,8 @@ class RandHistogramShiftd(RandomizableTransform, MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = [TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -1138,17 +1141,19 @@ def randomize(self, data: Optional[Any] = None) -> None:
self.floating_control_points[i - 1], self.floating_control_points[i + 1]
)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize()
if not self._do_transform:
return d
for key in self.key_iterator(d):
img_min, img_max = d[key].min(), d[key].max()
reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min
floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min
dtype = d[key].dtype
d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(dtype)
d[key] = convert_data_type(d[key], np.ndarray)[0]
if self._do_transform:
img_min, img_max = d[key].min(), d[key].max()
reference_control_points_scaled = self.reference_control_points * (img_max - img_min) + img_min
floating_control_points_scaled = self.floating_control_points * (img_max - img_min) + img_min
dtype = d[key].dtype
d[key] = np.interp(d[key], reference_control_points_scaled, floating_control_points_scaled).astype(
dtype
)
return d


Expand Down Expand Up @@ -1594,13 +1599,15 @@ class HistogramNormalized(MapTransform):

"""

backend = HistogramNormalize.backend

def __init__(
self,
keys: KeysCollection,
num_bins: int = 256,
min: int = 0,
max: int = 255,
mask: Optional[np.ndarray] = None,
mask: Optional[NdarrayOrTensor] = None,
mask_key: Optional[str] = None,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
Expand All @@ -1609,7 +1616,7 @@ def __init__(
self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype)
self.mask_key = mask_key if mask is None else None

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key])
Expand Down
37 changes: 24 additions & 13 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Compose, OneOf
from monai.transforms.transform import MapTransform, Transform, apply_transform
from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index
from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index, where
from monai.utils import (
GridSampleMode,
InterpolateMode,
Expand Down Expand Up @@ -839,7 +839,7 @@ def _create_translate(


def generate_spatial_bounding_box(
img: np.ndarray,
img: NdarrayOrTensor,
select_fn: Callable = is_positive,
channel_indices: Optional[IndexSelection] = None,
margin: Union[Sequence[int], int] = 0,
Expand All @@ -864,7 +864,7 @@ def generate_spatial_bounding_box(
margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
"""
data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img
data = np.any(select_fn(data), axis=0)
data = select_fn(data).any(0)
ndim = len(data.shape)
margin = ensure_tuple_rep(margin, ndim)
for m in margin:
Expand All @@ -875,13 +875,18 @@ def generate_spatial_bounding_box(
box_end = [0] * ndim

for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)):
dt = data.any(axis=ax)
if not np.any(dt):
dt = data
if len(ax) != 0:
dt = any_np_pt(dt, ax)

if not dt.any():
# if no foreground, return all zero bounding box coords
return [0] * ndim, [0] * ndim

min_d = max(np.argmax(dt) - margin[di], 0)
max_d = max(data.shape[di] - max(np.argmax(dt[::-1]) - margin[di], 0), min_d + 1)
arg_max = where(dt == dt.max())[0]
min_d = max(arg_max[0] - margin[di], 0)
max_d = arg_max[-1] + margin[di] + 1

box_start[di], box_end[di] = min_d, max_d

return box_start, box_end
Expand Down Expand Up @@ -1203,8 +1208,8 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen


def equalize_hist(
img: np.ndarray,
mask: Optional[np.ndarray] = None,
img: NdarrayOrTensor,
mask: Optional[NdarrayOrTensor] = None,
num_bins: int = 256,
min: int = 0,
max: int = 255,
Expand All @@ -1226,8 +1231,14 @@ def equalize_hist(
dtype: data type of the output, default to `float32`.

"""
orig_shape = img.shape
hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img
img_np: np.ndarray
img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore
mask_np: Optional[np.ndarray] = None
if mask is not None:
mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore

orig_shape = img_np.shape
hist_img = img_np[np.array(mask_np, dtype=bool)] if mask_np is not None else img_np
if has_skimage:
hist, bins = exposure.histogram(hist_img.flatten(), num_bins)
else:
Expand All @@ -1239,9 +1250,9 @@ def equalize_hist(
cum = rescale_array(arr=cum, minv=min, maxv=max)

# apply linear interpolation
img = np.interp(img.flatten(), bins, cum)
img_np = np.interp(img_np.flatten(), bins, cum)

return img.reshape(orig_shape).astype(dtype)
return img_np.reshape(orig_shape).astype(dtype)


class Fourier:
Expand Down
33 changes: 22 additions & 11 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -114,17 +114,23 @@ def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
return result


def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
"""
result: NdarrayOrTensor
if isinstance(condition, np.ndarray):
result = np.where(condition, x, y)
if x is not None:
result = np.where(condition, x, y)
else:
result = np.where(condition)
else:
x = torch.as_tensor(x, device=condition.device)
y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)
result = torch.where(condition, x, y)
if x is not None:
x = torch.as_tensor(x, device=condition.device)
y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)
result = torch.where(condition, x, y)
else:
result = torch.where(condition) # type: ignore
return result


Expand Down Expand Up @@ -211,7 +217,7 @@ def ravel(x: NdarrayOrTensor):
return np.ravel(x)


def any_np_pt(x: NdarrayOrTensor, axis: int):
def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]):
"""`np.any` with equivalent implementation for torch.

For pytorch, convert to boolean for compatibility with older versions.
Expand All @@ -223,13 +229,18 @@ def any_np_pt(x: NdarrayOrTensor, axis: int):
Returns:
Return a contiguous flattened array/tensor.
"""
if isinstance(x, torch.Tensor):
if isinstance(x, np.ndarray):
return np.any(x, axis)

# pytorch can't handle multiple dimensions to `any` so loop across them
axis = [axis] if not isinstance(axis, Sequence) else axis
for ax in axis:
try:
return torch.any(x, axis)
x = torch.any(x, ax)
except RuntimeError:
# older versions of pytorch require the input to be cast to boolean
return torch.any(x.bool(), axis)
return np.any(x, axis)
x = torch.any(x.bool(), ax)
return x


def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
Expand Down
Loading