Skip to content
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,11 @@
from .utils_pytorch_numpy_unification import (
any_np_pt,
clip,
concatenate,
cumsum,
floor_divide,
in1d,
isfinite,
maximum,
moveaxis,
nonzero,
Expand Down
19 changes: 10 additions & 9 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,20 +792,25 @@ class RandWeightedCrop(Randomizable, Transform):
It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`.
"""

backend = SpatialCrop.backend

def __init__(
self, spatial_size: Union[Sequence[int], int], num_samples: int = 1, weight_map: Optional[np.ndarray] = None
self,
spatial_size: Union[Sequence[int], int],
num_samples: int = 1,
weight_map: Optional[NdarrayOrTensor] = None,
):
self.spatial_size = ensure_tuple(spatial_size)
self.num_samples = int(num_samples)
self.weight_map = weight_map
self.centers: List[np.ndarray] = []

def randomize(self, weight_map: np.ndarray) -> None:
def randomize(self, weight_map: NdarrayOrTensor) -> None:
self.centers = weighted_patch_samples(
spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R
) # using only the first channel as weight map

def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) -> List[np.ndarray]:
def __call__(self, img: NdarrayOrTensor, weight_map: Optional[NdarrayOrTensor] = None) -> List[NdarrayOrTensor]:
"""
Args:
img: input image to sample patches from. assuming `img` is a channel-first array.
Expand All @@ -816,23 +821,19 @@ def __call__(self, img: np.ndarray, weight_map: Optional[np.ndarray] = None) ->
Returns:
A list of image patches
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
if weight_map is None:
weight_map = self.weight_map
if weight_map is None:
raise ValueError("weight map must be provided for weighted patch sampling.")
if img.shape[1:] != weight_map.shape[1:]:
raise ValueError(f"image and weight map spatial shape mismatch: {img.shape[1:]} vs {weight_map.shape[1:]}.")

weight_map, *_ = convert_data_type(weight_map, np.ndarray) # type: ignore

self.randomize(weight_map)
_spatial_size = fall_back_tuple(self.spatial_size, weight_map.shape[1:])
results = []
results: List[NdarrayOrTensor] = []
for center in self.centers:
cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size)
cropped: np.ndarray = cropper(img) # type: ignore
results.append(cropped)
results.append(cropper(img))
return results


Expand Down
15 changes: 8 additions & 7 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,8 @@ class RandWeightedCropd(Randomizable, MapTransform, InvertibleTransform):
:py:class:`monai.transforms.RandWeightedCrop`
"""

backend = SpatialCrop.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -951,18 +953,18 @@ def __init__(
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.centers: List[np.ndarray] = []

def randomize(self, weight_map: np.ndarray) -> None:
def randomize(self, weight_map: NdarrayOrTensor) -> None:
self.centers = weighted_patch_samples(
spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R
)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
d = dict(data)
self.randomize(d[self.w_key])
_spatial_size = fall_back_tuple(self.spatial_size, d[self.w_key].shape[1:])

# initialize returned list with shallow copy to preserve key ordering
results: List[Dict[Hashable, np.ndarray]] = [dict(data) for _ in range(self.num_samples)]
results: List[Dict[Hashable, NdarrayOrTensor]] = [dict(data) for _ in range(self.num_samples)]
# fill in the extra keys with unmodified data
for i in range(self.num_samples):
for key in set(data.keys()).difference(set(self.keys)):
Expand All @@ -977,8 +979,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
for i, center in enumerate(self.centers):
cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size)
orig_size = img.shape[1:]
cropped: np.ndarray = cropper(img) # type: ignore
results[i][key] = cropped
results[i][key] = cropper(img)
self.push_transform(results[i], key, extra_info={"center": center}, orig_size=orig_size)
if self.center_coord_key:
results[i][self.center_coord_key] = center
Expand All @@ -989,7 +990,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, n
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key not in results[i]:
results[i][meta_key] = {} # type: ignore
results[i][meta_key][Key.PATCH_INDEX] = i
results[i][meta_key][Key.PATCH_INDEX] = i # type: ignore

return results

Expand All @@ -1001,7 +1002,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
center = transform[InverseKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
pad_to_end = orig_size - current_size - pad_to_start
Expand Down
9 changes: 6 additions & 3 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
create_translate,
map_spatial_axes,
)
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
Expand Down Expand Up @@ -1977,6 +1978,8 @@ class AddCoordinateChannels(Transform):
Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
spatial_channels: Sequence[int],
Expand All @@ -1989,12 +1992,11 @@ def __init__(
"""
self.spatial_channels = spatial_channels

def __call__(self, img: Union[np.ndarray, torch.Tensor]):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel first.
"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
if max(self.spatial_channels) > img.ndim - 1:
raise ValueError(
f"input has {img.ndim-1} spatial dimensions, cannot add AddCoordinateChannels channel for "
Expand All @@ -2005,7 +2007,8 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]):

spatial_dims = img.shape[1:]
coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_dims), indexing="ij"))
coord_channels, *_ = convert_to_dst_type(coord_channels, img) # type: ignore
# only keep required dimensions. need to subtract 1 since im will be 0-based
# but user input is 1-based (because channel dim is 0)
coord_channels = coord_channels[[s - 1 for s in self.spatial_channels]]
return np.concatenate((img, coord_channels), axis=0)
return concatenate((img, coord_channels), axis=0)
6 changes: 3 additions & 3 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,8 @@ class AddCoordinateChannelsd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`.
"""

backend = AddCoordinateChannels.backend

def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_missing_keys: bool = False) -> None:
"""
Args:
Expand All @@ -1772,9 +1774,7 @@ def __init__(self, keys: KeysCollection, spatial_channels: Sequence[int], allow_
super().__init__(keys, allow_missing_keys)
self.add_coordinate_channels = AddCoordinateChannels(spatial_channels)

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.add_coordinate_channels(d[key])
Expand Down
20 changes: 10 additions & 10 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
map_binary_to_indices,
map_classes_to_indices,
)
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis, unravel_indices
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
from monai.utils import (
convert_data_type,
convert_to_cupy,
Expand All @@ -45,6 +45,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -915,22 +916,24 @@ class AddExtremePointsChannel(Randomizable, Transform):
ValueError: When label image is not single channel.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, background: int = 0, pert: float = 0.0) -> None:
self._background = background
self._pert = pert
self._points: List[Tuple[int, ...]] = []

def randomize(self, label: np.ndarray) -> None:
def randomize(self, label: NdarrayOrTensor) -> None:
self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert)

def __call__(
self,
img: np.ndarray,
label: Optional[np.ndarray] = None,
img: NdarrayOrTensor,
label: Optional[NdarrayOrTensor] = None,
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 3.0,
rescale_min: float = -1.0,
rescale_max: float = 1.0,
):
) -> NdarrayOrTensor:
"""
Args:
img: the image that we want to add new channel to.
Expand All @@ -947,17 +950,14 @@ def __call__(
if label.shape[0] != 1:
raise ValueError("Only supports single channel labels!")

img, *_ = convert_data_type(img, np.ndarray) # type: ignore
label, *_ = convert_data_type(label, np.ndarray) # type: ignore

# Generate extreme points
self.randomize(label[0, :])

points_image = extreme_points_to_image(
points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max
)

return np.concatenate([img, points_image], axis=0)
points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore
return concatenate((img, points_image), axis=0)


class TorchVision:
Expand Down
12 changes: 9 additions & 3 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utility.array import (
AddChannel,
AddExtremePointsChannel,
AsChannelFirst,
AsChannelLast,
CastToType,
Expand Down Expand Up @@ -59,8 +60,10 @@
Transpose,
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import InverseKeys, TransformBackends
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
"AddChannelD",
Expand Down Expand Up @@ -1231,6 +1234,8 @@ class AddExtremePointsChanneld(Randomizable, MapTransform):

"""

backend = AddExtremePointsChannel.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1251,10 +1256,10 @@ def __init__(
self.rescale_min = rescale_min
self.rescale_max = rescale_max

def randomize(self, label: np.ndarray) -> None:
def randomize(self, label: NdarrayOrTensor) -> None:
self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert)

def __call__(self, data):
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
label = d[self.label_key]
if label.shape[0] != 1:
Expand All @@ -1272,7 +1277,8 @@ def __call__(self, data):
rescale_min=self.rescale_min,
rescale_max=self.rescale_max,
)
d[key] = np.concatenate([img, points_image], axis=0)
points_image, *_ = convert_to_dst_type(points_image, img) # type: ignore
d[key] = concatenate([img, points_image], axis=0)
return d


Expand Down
Loading