Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from monai.transforms import Resize, SpatialCrop
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box
from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import
from monai.utils import (
InterpolateMode,
deprecated_arg,
ensure_tuple,
ensure_tuple_rep,
first,
min_version,
optional_import,
)

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
Expand Down Expand Up @@ -645,7 +653,7 @@ def bounding_box(self, points, img_shape):
def __call__(self, data):
d: Dict = dict(data)
guidance = d[self.guidance]
original_spatial_shape = d[self.keys[0]].shape[1:]
original_spatial_shape = d[first(self.key_iterator(d))].shape[1:]
box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape)
center = list(np.mean([box_start, box_end], axis=0).astype(int))
spatial_size = self.spatial_size
Expand Down
8 changes: 4 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
weighted_patch_samples,
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first
from monai.utils.enums import InverseKeys

__all__ = [
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
# use the spatial size of first image to scale, expect all images have the same spatial size
img_size = data[self.keys[0]].shape[1:]
img_size = d[first(self.key_iterator(d))].shape[1:]
ndim = len(img_size)
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
cropper = CenterSpatialCrop(roi_size)
Expand Down Expand Up @@ -575,7 +575,7 @@ def randomize(self, img_size: Sequence[int]) -> None:

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize(d[self.keys[0]].shape[1:]) # image shape from the first data key
self.randomize(d[first(self.key_iterator(d))].shape[1:]) # image shape from the first data key
if self._size is None:
raise RuntimeError("self._size not specified.")
for key in self.key_iterator(d):
Expand Down Expand Up @@ -669,7 +669,7 @@ def __init__(
self.max_roi_scale = max_roi_scale

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
img_size = data[self.keys[0]].shape[1:]
img_size = data[first(self.key_iterator(data))].shape[1:] # type: ignore
ndim = len(img_size)
self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
if self.max_roi_scale is not None:
Expand Down
10 changes: 5 additions & 5 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import ensure_tuple, ensure_tuple_rep
from monai.utils import ensure_tuple, ensure_tuple_rep, first
from monai.utils.deprecate_utils import deprecated_arg

__all__ = [
Expand Down Expand Up @@ -187,7 +187,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d

# all the keys share the same random noise
self.rand_gaussian_noise.randomize(d[self.keys[0]])
self.rand_gaussian_noise.randomize(d[first(self.key_iterator(d))])
for key in self.key_iterator(d):
d[key] = self.rand_gaussian_noise(img=d[key], randomize=False)
return d
Expand Down Expand Up @@ -621,7 +621,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d

# all the keys share the same random bias factor
self.rand_bias_field.randomize(img_size=d[self.keys[0]].shape[1:])
self.rand_bias_field.randomize(img_size=d[first(self.key_iterator(d))].shape[1:])
for key in self.key_iterator(d):
d[key] = self.rand_bias_field(d[key], randomize=False)
return d
Expand Down Expand Up @@ -1466,7 +1466,7 @@ def __call__(self, data):
return d

# expect all the specified keys have same spatial shape and share same random holes
self.dropper.randomize(d[self.keys[0]].shape[1:])
self.dropper.randomize(d[first(self.key_iterator(d))].shape[1:])
for key in self.key_iterator(d):
d[key] = self.dropper(img=d[key], randomize=False)

Expand Down Expand Up @@ -1531,7 +1531,7 @@ def __call__(self, data):
return d

# expect all the specified keys have same spatial shape and share same random holes
self.shuffle.randomize(d[self.keys[0]].shape[1:])
self.shuffle.randomize(d[first(self.key_iterator(d))].shape[1:])
for key in self.key_iterator(d):
d[key] = self.shuffle(img=d[key], randomize=False)

Expand Down
16 changes: 9 additions & 7 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
first,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import InverseKeys
Expand Down Expand Up @@ -817,9 +818,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

device = self.rand_affine.resampler.device

sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:])
spatial_size = d[first(self.key_iterator(d))].shape[1:]
sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size)
# change image size or do random transform
do_resampling = self._do_transform or (sp_size != ensure_tuple(data[self.keys[0]].shape[1:]))
do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size))
affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device)
# converting affine to tensor because the resampler currently only support torch backend
grid = None
Expand Down Expand Up @@ -977,7 +979,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
d = dict(data)
self.randomize(None)

sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, data[self.keys[0]].shape[1:])
sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:])
# all the keys share the same random elastic factor
self.rand_2d_elastic.randomize(sp_size)

Expand Down Expand Up @@ -1109,7 +1111,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
d = dict(data)
self.randomize(None)

sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, data[self.keys[0]].shape[1:])
sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:])
# all the keys share the same random elastic factor
self.rand_3d_elastic.randomize(sp_size)

Expand Down Expand Up @@ -1260,7 +1262,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.randomize(None)

# all the keys share the same random selected axis
self.flipper.randomize(d[self.keys[0]])
self.flipper.randomize(d[first(self.key_iterator(d))])
for key in self.key_iterator(d):
if self._do_transform:
d[key] = self.flipper(d[key], randomize=False)
Expand Down Expand Up @@ -1684,7 +1686,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
self.randomize(None)

# all the keys share the same random zoom factor
self.rand_zoom.randomize(d[self.keys[0]])
self.rand_zoom.randomize(d[first(self.key_iterator(d))])
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
Expand Down Expand Up @@ -1866,7 +1868,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
if not self._do_transform:
return d

self.rand_grid_distortion.randomize(d[self.keys[0]].shape[1:])
self.rand_grid_distortion.randomize(d[first(self.key_iterator(d))].shape[1:])
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False)
return d
Expand Down
3 changes: 2 additions & 1 deletion tests/test_rand_scale_cropd.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
]

TEST_CASE_2 = [
{"keys": "img", "roi_scale": [1.0, 1.0, 1.0], "random_center": False},
# test `allow_missing_keys` with key "label"
{"keys": ["label", "img"], "roi_scale": [1.0, 1.0, 1.0], "random_center": False, "allow_missing_keys": True},
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3])},
(3, 3, 3, 3),
]
Expand Down