Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import warnings
from copy import deepcopy
from enum import Enum
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand All @@ -24,7 +25,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
Expand Down Expand Up @@ -437,6 +438,8 @@ def __init__(
dtype: DtypeLike = np.float64,
scale_extent: bool = False,
recompute_affine: bool = False,
min_pixdim: Union[Sequence[float], float, np.ndarray, None] = None,
max_pixdim: Union[Sequence[float], float, np.ndarray, None] = None,
image_only: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -483,13 +486,25 @@ def __init__(
recompute_affine: whether to recompute affine based on the output shape. The affine computed
analytically does not reflect the potential quantization errors in terms of the output shape.
Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.
min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this
value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
value of `pixdim`. Default to `None`.
max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
value of `pixdim`. Default to `None`.

"""
self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)
self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64)
Comment thread
wyli marked this conversation as resolved.
self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64)
self.diagonal = diagonal
self.scale_extent = scale_extent
self.recompute_affine = recompute_affine

for mn, mx in zip(self.min_pixdim, self.max_pixdim):
if (not np.isnan(mn)) and (not np.isnan(mx)) and ((mx < mn) or (mn < 0)):
raise ValueError(f"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.")

self.sp_resample = SpatialResample(
mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype
)
Expand Down Expand Up @@ -560,6 +575,16 @@ def __call__(
out_d = self.pixdim[:sr]
if out_d.size < sr:
out_d = np.append(out_d, [1.0] * (sr - out_d.size))
orig_d = affine_to_spacing(affine_, sr, out_d.dtype)
for idx, (_d, mn, mx) in enumerate(
zip_longest(orig_d, self.min_pixdim[:sr], self.max_pixdim[:sr], fillvalue=np.nan)
):
target = out_d[idx]
mn = target if np.isnan(mn) else min(mn, target)
mx = target if np.isnan(mx) else max(mx, target)
if mn > mx:
raise ValueError(f"min_pixdim is larger than max_pixdim at dim {idx}: min {mn} max {mx} out {target}.")
out_d[idx] = _d if (mn - AFFINE_TOL) <= _d <= (mx + AFFINE_TOL) else target

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. if we define min and max pixdim, why do we need additional affine_tol here?
  2. if the resolution is withing the min / max, we simply want to skip the resampling completely. In this code, it seems all it does is records the original resolution and still does resampling in the next step (which will incur additional compute load and possible artifacts)


if not align_corners and scale_extent:
warnings.warn("align_corners=False is not compatible with scale_extent=True.")
Expand Down
12 changes: 11 additions & 1 deletion monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ def __init__(
recompute_affine: bool = False,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
min_pixdim: Union[Sequence[float], float, None] = None,
max_pixdim: Union[Sequence[float], float, None] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -386,11 +388,19 @@ def __init__(
recompute_affine: whether to recompute affine based on the output shape. The affine computed
analytically does not reflect the potential quantization errors in terms of the output shape.
Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.
min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this
value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
value of `pixdim`. Default to `None`.
max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
value of `pixdim`. Default to `None`.
allow_missing_keys: don't raise exception if key is missing.

"""
super().__init__(keys, allow_missing_keys)
self.spacing_transform = Spacing(pixdim, diagonal=diagonal, recompute_affine=recompute_affine)
self.spacing_transform = Spacing(
pixdim, diagonal=diagonal, recompute_affine=recompute_affine, min_pixdim=min_pixdim, max_pixdim=max_pixdim
)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
Expand Down
30 changes: 28 additions & 2 deletions tests/test_spacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import affine_to_spacing
from monai.transforms import Spacing
from monai.utils import ensure_tuple, fall_back_tuple
from monai.utils import fall_back_tuple
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose

TESTS = []
Expand Down Expand Up @@ -245,7 +245,6 @@ def test_spacing(self, init_param, img, affine, data_param, expected_output, dev
sr = min(len(res.shape) - 1, 3)
if isinstance(init_param["pixdim"], float):
init_param["pixdim"] = [init_param["pixdim"]] * sr
init_pixdim = ensure_tuple(init_param["pixdim"])
init_pixdim = init_param["pixdim"][:sr]
norm = affine_to_spacing(res.affine, sr).cpu().numpy()
assert_allclose(fall_back_tuple(init_pixdim, norm), norm, type_test=False)
Expand Down Expand Up @@ -287,6 +286,33 @@ def test_inverse(self, device, recompute, align, scale_extent):
l2_norm_affine = ((affine - img.affine) ** 2).sum() ** 0.5
self.assertLess(l2_norm_affine, 5e-2)

@parameterized.expand(TEST_INVERSE)
def test_inverse_mn_mx(self, device, recompute, align, scale_extent):
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
affine = torch.tensor(
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
)
img = MetaTensor(img_t, affine=affine, meta={"fname": "somewhere"})
choices = [(None, None), [1.2, None], [None, 0.7], [0.7, 0.9]]
idx = np.random.choice(range(len(choices)), size=1)[0]
tr = Spacing(
pixdim=[1.1, 1.2, 0.9],
recompute_affine=recompute,
align_corners=align,
scale_extent=scale_extent,
min_pixdim=[0.9, None, choices[idx][0]],
max_pixdim=[1.1, 1.1, choices[idx][1]],
)
img_out = tr(img)
if isinstance(img_out, MetaTensor):
assert_allclose(
img_out.pixdim, [1.0, 1.125, 0.888889] if recompute else [1.0, 1.2, 0.9], type_test=False, rtol=1e-4
)
img_out = tr.inverse(img_out)
self.assertEqual(img_out.applied_operations, [])
self.assertEqual(img_out.shape, img_t.shape)
self.assertLess(((affine - img_out.affine) ** 2).sum() ** 0.5, 5e-2)


if __name__ == "__main__":
unittest.main()