From dad9365e282dbfb2bf092ea73a3b1f7cd964ba4f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 20 Jan 2022 15:39:38 +0000 Subject: [PATCH 01/33] temp spatial_resample Signed-off-by: Wenqi Li --- monai/data/utils.py | 60 +++++--- monai/transforms/spatial/array.py | 188 ++++++++++++++++++++----- monai/transforms/spatial/dictionary.py | 9 ++ monai/utils/module.py | 4 +- tests/test_spacing.py | 4 +- 5 files changed, 207 insertions(+), 58 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 79ef9bd7fb..2bf54cc6ab 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -32,7 +32,9 @@ from monai.utils import ( MAX_SEED, BlendMode, + Method, NumpyPadMode, + convert_data_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -42,7 +44,6 @@ look_up_option, optional_import, ) -from monai.utils.enums import Method pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -78,6 +79,7 @@ "no_collation", "convert_tables_to_dicts", "SUPPORTED_PICKLE_MOD", + "reorient_spatial_axes", ] # module to be used by `torch.save` @@ -679,56 +681,54 @@ def compute_shape_offset( corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) corners = in_affine @ corners - corners_out = np.linalg.inv(out_affine) @ corners + inv_mat = np.linalg.inv(out_affine) + corners_out = inv_mat @ corners corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1) + 1.0) - if np.allclose(nib.io_orientation(in_affine), nib.io_orientation(out_affine)): - # same orientation, get translate from the origin - offset = in_affine @ ([0] * sr + [1]) - offset = offset[:-1] / offset[-1] - else: - # different orientation, the min is the origin - corners = corners[:-1] / corners[-1] - offset = np.min(corners, 1) + mat = inv_mat[:-1, :-1] + k = 0 + for i in range(corners.shape[1]): + min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) + if np.allclose(min_corner, 0.0): + k = i + break + offset = corners[:-1, k] return out_shape.astype(int, copy=False), offset -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: +def to_affine_nd(r: Union[np.ndarray, int], affine, dtype=np.float64) -> np.ndarray: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. - when ``r`` is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(r, len(affine) - 1)`. - - when ``r`` is an affine matrix, the output has the same as ``r``, - the top left kxk elements are copied from ``affine``, + when ``r`` is an affine matrix, the output has the same shape as ``r``, + and the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(len(r) - 1, len(affine) - 1)`. - Args: r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix - + dtype: data type of the output array. Raises: ValueError: When ``affine`` dimensions is not 2. ValueError: When ``r`` is nonpositive. - Returns: an (r+1) x (r+1) matrix - """ - affine_np = np.array(affine, dtype=np.float64) + affine_np: np.ndarray + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] # type: ignore + affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + new_affine = np.array(r, dtype=dtype, copy=True) if new_affine.ndim == 0: sr: int = int(new_affine.astype(np.uint)) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") - new_affine = np.eye(sr + 1, dtype=np.float64) + new_affine = np.eye(sr + 1, dtype=dtype) d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) new_affine[:d, :d] = affine_np[:d, :d] if d > 1: @@ -736,6 +736,22 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: return new_affine +def reorient_spatial_axes( + data_shape: Sequence[int], init_affine: np.ndarray, target_affine: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given the input ``data_array`` and its corresponding coordinate ``init_affine``, + convert the array to ``target_affine`` by rearranging/flipping the axes. + Returns the transformed array and the updated affine. + Note that this function requires external module ``nibabel.orientations``. + """ + start_ornt = nib.orientations.io_orientation(init_affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + new_affine = init_affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + return ornt_transform, new_affine + + def create_file_basename( postfix: str, input_file_name: PathLike, diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ee5d179b00..afd8fdc249 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine +from monai.data.utils import compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij from monai.transforms.croppad.array import CenterSpatialCrop, Pad @@ -52,9 +52,12 @@ from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -nib, _ = optional_import("nibabel") +AFFINE_TOL = 1e-3 + +nib, has_nib = optional_import("nibabel") __all__ = [ + "SpatialResample", "Spacing", "Orientation", "Flip", @@ -82,6 +85,140 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] +class SpatialResample(Transform): + """ + Resample input image from the orientation/spacing defined by ``src`` affine into the ones specified by `dst`. + """ + + def __init__( + self, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Args: + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + """ + self.mode = mode + self.padding_mode = padding_mode + self.align_corners = align_corners + self.dtype = dtype + + def __call__( + self, + img: NdarrayOrTensor, + src: Optional[NdarrayOrTensor] = None, + dst: Optional[NdarrayOrTensor] = None, + spatial_size: Optional[Union[Sequence[int], int]] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Args: + img: input image to be resampled. It currently supports channel-first arrays with + at most three spatial dimensions. + src: source affine matrix. Defaults to ``None``, which means the identity matrix. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + dst: destination affine matrix. Defaults to ``None``, which means the same as `src`. + the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. + when `dst` is None, the input will be returned without resampling, but the data type + will be `float32`. + spatial_size: output image spatial size. + if `spatial_size` and `self.spatial_size` are not defined, + the transform will compute a spatial size automatically containing the previous field of view. + if `spatial_size` is ``-1`` are the transform will use the corresponding input img size. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. + """ + spatial_rank = min(len(img.shape) - 1, 3) + if src is None: + src = np.eye(4, dtype=np.float64) + src = to_affine_nd(spatial_rank, src) + dst = to_affine_nd(spatial_rank, dst) if dst is not None else src + dst, *_ = convert_to_dst_type(dst, dst, dtype=torch.float32) + + if np.allclose(src, dst, atol=AFFINE_TOL): + # no significant change, return original image + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst + + if has_nib and isinstance(img, np.ndarray): + spatial_ornt, dst_r = reorient_spatial_axes(img.shape[1 : spatial_rank + 1], src, dst) + if np.allclose(dst_r, dst, atol=AFFINE_TOL): + # simple reorientation achieves the desired affine + spatial_ornt[:, 0] += 1 + spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) + img_ = nib.orientations.apply_orientation(img, spatial_ornt) + output_data, *_ = convert_to_dst_type(img_, img, dtype=torch.float32) + return output_data, dst + + transform = np.linalg.inv(src) @ dst + transform = to_affine_nd(spatial_rank, transform) + # no resampling if it's identity transform + if np.allclose(transform, np.diag(np.ones(len(transform))), atol=AFFINE_TOL): + output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) + return output_data, dst + + _dtype = dtype or self.dtype or img.dtype + if ensure_tuple(spatial_size)[0] == -1: # if the spatial_size == -1 + spatial_size = img.shape[1 : spatial_rank + 1] + elif spatial_size is None: + spatial_size, _ = compute_shape_offset(img.shape[1 : spatial_rank + 1], src, dst) # type: ignore + # resample + img_, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore + _align_corners = self.align_corners if align_corners is None else align_corners + if USE_COMPILED and _align_corners: + affine_xform = Affine( + affine=convert_data_type(transform, torch.Tensor, img_.device, dtype=_dtype)[0], + spatial_size=spatial_size, + image_only=True, + ) + output_data = affine_xform( + img_, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), + ) + else: + affine_xform = AffineTransform( + normalized=False, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), + align_corners=_align_corners, + reverse_indexing=True, + ) + output_data = affine_xform( + img_.unsqueeze(0), + theta=convert_data_type(transform, torch.Tensor, img_.device, dtype=_dtype)[0], + spatial_size=spatial_size, + ).squeeze(0) + # output dtype float + output_data, *_ = convert_to_dst_type(output_data, img, dtype=torch.float32) + return output_data, dst + + class Spacing(Transform): """ Resample input image into the specified `pixdim`. @@ -135,11 +272,14 @@ def __init__( """ self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.diagonal = diagonal - self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) - self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) - self.align_corners = align_corners - self.dtype = dtype self.image_only = image_only + self.dtype = dtype + + self.sp_resample = SpatialResample( + mode=look_up_option(mode, GridSampleMode), + padding_mode=look_up_option(padding_mode, GridSamplePadMode), + align_corners=align_corners, + ) def __call__( self, @@ -198,32 +338,16 @@ def __call__( new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal) output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine) new_affine[:sr, -1] = offset[:sr] - transform = np.linalg.inv(affine_) @ new_affine - # adapt to the actual rank - transform = to_affine_nd(sr, transform) - - # no resampling if it's identity transform - if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3): - output_data = data_array - else: - # resample - affine_xform = AffineTransform( - normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - reverse_indexing=True, - ) - data_array_t: torch.Tensor - data_array_t, *_ = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore - output_data = affine_xform( - # AffineTransform requires a batch dim - data_array_t.unsqueeze(0), - convert_data_type(transform, torch.Tensor, data_array_t.device, dtype=_dtype)[0], - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, - ).squeeze(0) - - output_data, *_ = convert_to_dst_type(output_data, data_array, dtype=torch.float32) + output_data, new_affine = self.sp_resample( + data_array, + src=affine, + dst=new_affine, + spatial_size=output_shape, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=_dtype, + ) new_affine = to_affine_nd(affine_np, new_affine) # type: ignore new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0ab9210aaf..4a8a4b663f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -46,6 +46,7 @@ Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from monai.transforms.transform import MapTransform, RandomizableTransform @@ -68,6 +69,7 @@ nib, _ = optional_import("nibabel") __all__ = [ + "SpatialResampled", "Spacingd", "Orientationd", "Rotate90d", @@ -86,6 +88,8 @@ "RandRotated", "Zoomd", "RandZoomd", + "SpatialResampleD", + "SpatialResampleDict", "SpacingD", "SpacingDict", "OrientationD", @@ -131,6 +135,10 @@ DEFAULT_POST_FIX = PostFix.meta() +class SpatialResampled(MapTransform, InvertibleTransform): + pass + + class Spacingd(MapTransform, InvertibleTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -1870,6 +1878,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d +SpatialResampleD = SpatialResampleDict = SpatialResampled SpacingD = SpacingDict = Spacingd OrientationD = OrientationDict = Orientationd Rotate90D = Rotate90Dict = Rotate90d diff --git a/monai/utils/module.py b/monai/utils/module.py index c0fc10a7c0..8813301d8e 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -19,7 +19,7 @@ from pkgutil import walk_packages from re import match from types import FunctionType -from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, cast +from typing import Any, Callable, Collection, Hashable, Iterable, List, Mapping, Tuple, Union, cast import torch @@ -44,7 +44,7 @@ ] -def look_up_option(opt_str, supported: Collection, default="no_default"): +def look_up_option(opt_str, supported: Union[Collection, enum.EnumMeta], default="no_default"): """ Look up the option in the supported collection and return the matched item. Raise a value error possibly with a guess of the closest match. diff --git a/tests/test_spacing.py b/tests/test_spacing.py index ebff25712d..6a6bf4431f 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -78,7 +78,7 @@ TESTS.append( [ p, - {"pixdim": (1.0, 1.0)}, + {"pixdim": (1.0, 1.0), "align_corners": True}, np.arange(24).reshape((2, 3, 4)), # data {}, np.array( @@ -203,7 +203,7 @@ def test_spacing(self, in_type, init_param, img, data_param, expected_output): self.assertEqual(_img.device, output_data.device) output_data = output_data.cpu() - np.testing.assert_allclose(output_data, expected_output, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(output_data, expected_output, atol=1e-1, rtol=1e-1) sr = len(output_data.shape) - 1 if isinstance(init_param["pixdim"], float): init_param["pixdim"] = [init_param["pixdim"]] * sr From 0f07fc922813caab82ec59d13bb02f68dd1b5e38 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 22 Jan 2022 15:30:08 +0000 Subject: [PATCH 02/33] fixes resampling Signed-off-by: Wenqi Li --- monai/csrc/resample/pushpull_cpu.cpp | 97 ++++++------------------- monai/csrc/resample/pushpull_cuda.cu | 104 ++++++--------------------- 2 files changed, 45 insertions(+), 156 deletions(-) diff --git a/monai/csrc/resample/pushpull_cpu.cpp b/monai/csrc/resample/pushpull_cpu.cpp index d83557c6c3..4c9f6b9f9c 100644 --- a/monai/csrc/resample/pushpull_cpu.cpp +++ b/monai/csrc/resample/pushpull_cpu.cpp @@ -1527,10 +1527,10 @@ MONAI_NAMESPACE_DEVICE { // cpu iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1539,18 +1539,20 @@ MONAI_NAMESPACE_DEVICE { // cpu o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1659,14 +1661,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1678,14 +1672,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1758,16 +1744,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1822,21 +1798,23 @@ MONAI_NAMESPACE_DEVICE { // cpu ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1895,10 +1873,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1908,10 +1882,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1926,11 +1896,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1960,12 +1925,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1996,20 +1955,21 @@ MONAI_NAMESPACE_DEVICE { // cpu ix1 = bound::index(bound0, ix0 + 1, src_X); ix0 = bound::index(bound0, ix0, src_X); - // Offsets into source volume offset_t o0, o1; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2037,8 +1997,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2047,8 +2005,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2058,9 +2014,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2081,10 +2034,6 @@ MONAI_NAMESPACE_DEVICE { // cpu } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); diff --git a/monai/csrc/resample/pushpull_cuda.cu b/monai/csrc/resample/pushpull_cuda.cu index 4a2d6c27ef..fe50154670 100644 --- a/monai/csrc/resample/pushpull_cuda.cu +++ b/monai/csrc/resample/pushpull_cuda.cu @@ -1491,10 +1491,10 @@ MONAI_NAMESPACE_DEVICE { // cuda iy0 = bound::index(bound1, iy0, src_Y); iz0 = bound::index(bound2, iz0, src_Z); - // Offsets into source volume offset_t o000, o100, o010, o001, o110, o011, o101, o111; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; @@ -1503,18 +1503,20 @@ MONAI_NAMESPACE_DEVICE { // cuda o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; + } else { + // Offsets into 'push' volume + o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; + o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; + o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; + o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; + o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t gz = static_cast(0); @@ -1623,14 +1625,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) { @@ -1642,14 +1636,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~ else if (do_sgrad) { - o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ; - o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ; - o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ; - o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ; - o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ; scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1672,15 +1658,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1722,16 +1699,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ; - o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ; - o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ; - o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ; - o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o000, w000, s000); bound::add(out_ptr_N, o100, w100, s100); @@ -1786,21 +1753,23 @@ MONAI_NAMESPACE_DEVICE { // cuda ix0 = bound::index(bound0, ix0, src_X); iy0 = bound::index(bound1, iy0, src_Y); - // Offsets into source volume offset_t o00, o10, o01, o11; if (do_pull || do_grad || do_sgrad) { + // Offsets into source volume o00 = ix0 * src_sX + iy0 * src_sY; o10 = ix1 * src_sX + iy0 * src_sY; o01 = ix0 * src_sX + iy1 * src_sY; o11 = ix1 * src_sX + iy1 * src_sY; + } else { + // Offsets into 'push' volume + o00 = ix0 * out_sX + iy0 * out_sY; + o10 = ix1 * out_sX + iy0 * out_sY; + o01 = ix0 * out_sX + iy1 * out_sY; + o11 = ix1 * out_sX + iy1 * out_sY; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t gx = static_cast(0); scalar_t gy = static_cast(0); scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; @@ -1859,10 +1828,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) { @@ -1872,10 +1837,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o00 = ix0 * src_sX + iy0 * src_sY; - o10 = ix1 * src_sX + iy0 * src_sY; - o01 = ix0 * src_sX + iy1 * src_sY; - o11 = ix1 * src_sX + iy1 * src_sY; scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -1890,11 +1851,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -1924,12 +1880,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o00 = ix0 * out_sX + iy0 * out_sY; - o10 = ix1 * out_sX + iy0 * out_sY; - o01 = ix0 * out_sX + iy1 * out_sY; - o11 = ix1 * out_sX + iy1 * out_sY; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o00, w00, s00); bound::add(out_ptr_N, o10, w10, s10); @@ -1965,15 +1915,16 @@ MONAI_NAMESPACE_DEVICE { // cuda if (do_pull || do_grad || do_sgrad) { o0 = ix0 * src_sX; o1 = ix1 * src_sX; + } else { + // Offsets into 'push' volume + o0 = ix0 * out_sX; + o1 = ix1 * out_sX; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_grad) { if (trgt_K == 0) { // backward w.r.t. push/pull - - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t gx = static_cast(0); scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2001,8 +1952,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if (do_pull) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) { @@ -2011,8 +1960,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_sgrad) { - o0 = ix0 * src_sX; - o1 = ix1 * src_sX; scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX; scalar_t* src_ptr_NC = src_ptr + n * src_sN; @@ -2022,9 +1969,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_push) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX; scalar_t* out_ptr_NC = out_ptr + n * out_sN; if (trgt_K == 0) { @@ -2045,10 +1989,6 @@ MONAI_NAMESPACE_DEVICE { // cuda } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ else if (do_count) { - // Offsets into 'push' volume - o0 = ix0 * out_sX; - o1 = ix1 * out_sX; - scalar_t* out_ptr_N = out_ptr + n * out_sN; bound::add(out_ptr_N, o0, w0, s0); bound::add(out_ptr_N, o1, w1, s1); From b09e56825c03758633285ae80074f23d51ba7082 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 00:35:08 +0000 Subject: [PATCH 03/33] fixes precisions Signed-off-by: Wenqi Li --- monai/data/utils.py | 7 +- monai/transforms/spatial/array.py | 123 +++++++++++++++++-------- monai/transforms/spatial/dictionary.py | 11 ++- monai/transforms/utils.py | 11 ++- tests/test_spacing.py | 11 ++- 5 files changed, 117 insertions(+), 46 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 2bf54cc6ab..94fe440f9d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -700,14 +700,17 @@ def to_affine_nd(r: Union[np.ndarray, int], affine, dtype=np.float64) -> np.ndar """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. - when ``r`` is an integer, output is an (r+1)x(r+1) matrix, + + When ``r`` is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(r, len(affine) - 1)`. - when ``r`` is an affine matrix, the output has the same shape as ``r``, + + When ``r`` is an affine matrix, the output has the same shape as ``r``, and the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(len(r) - 1, len(affine) - 1)`. + Args: r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index afd8fdc249..3a545a6677 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -102,6 +102,9 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -144,6 +147,9 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -152,6 +158,9 @@ def __call__( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + + When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, + MONAI's resampling implementation will be used. """ spatial_rank = min(len(img.shape) - 1, 3) if src is None: @@ -183,37 +192,37 @@ def __call__( return output_data, dst _dtype = dtype or self.dtype or img.dtype - if ensure_tuple(spatial_size)[0] == -1: # if the spatial_size == -1 + spatial_size = ensure_tuple(spatial_size) + if spatial_size[0] == -1: # if the spatial_size == -1 spatial_size = img.shape[1 : spatial_rank + 1] - elif spatial_size is None: + elif spatial_size[0] is None: spatial_size, _ = compute_shape_offset(img.shape[1 : spatial_rank + 1], src, dst) # type: ignore + spatial_size = spatial_size[:spatial_rank] + chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims # resample - img_, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) # type: ignore - _align_corners = self.align_corners if align_corners is None else align_corners - if USE_COMPILED and _align_corners: - affine_xform = Affine( - affine=convert_data_type(transform, torch.Tensor, img_.device, dtype=_dtype)[0], - spatial_size=spatial_size, - image_only=True, - ) - output_data = affine_xform( - img_, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - ) + img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] # type: ignore + transform = convert_to_dst_type(transform, img_)[0] # type: ignore + align_corners = self.align_corners if align_corners is None else align_corners + mode = look_up_option(mode or self.mode, GridSampleMode) + padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) + if additional_dims: + xform_shape = [-1] + list(img.shape[1 : spatial_rank + 1]) + img_ = img_.reshape(xform_shape) + if USE_COMPILED and align_corners: + affine_xform = Affine(affine=transform, spatial_size=spatial_size, image_only=True, dtype=_dtype) + output_data = affine_xform(img_, mode=mode, padding_mode=padding_mode) else: affine_xform = AffineTransform( normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=_align_corners, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, reverse_indexing=True, ) - output_data = affine_xform( - img_.unsqueeze(0), - theta=convert_data_type(transform, torch.Tensor, img_.device, dtype=_dtype)[0], - spatial_size=spatial_size, - ).squeeze(0) + output_data = affine_xform(img_.unsqueeze(0), theta=transform, spatial_size=spatial_size).squeeze(0) + if additional_dims: + full_shape = tuple([chns, *spatial_size, *additional_dims]) + output_data = output_data.reshape(full_shape) # output dtype float output_data, *_ = convert_to_dst_type(output_data, img, dtype=torch.float32) return output_data, dst @@ -259,6 +268,9 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -273,12 +285,12 @@ def __init__( self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) self.diagonal = diagonal self.image_only = image_only - self.dtype = dtype self.sp_resample = SpatialResample( mode=look_up_option(mode, GridSampleMode), padding_mode=look_up_option(padding_mode, GridSamplePadMode), align_corners=align_corners, + dtype=dtype, ) def __call__( @@ -298,6 +310,9 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample @@ -318,7 +333,6 @@ def __call__( data_array (resampled into `self.pixdim`), original affine, current affine. """ - _dtype = dtype or self.dtype or data_array.dtype sr = int(data_array.ndim - 1) if sr <= 0: raise ValueError("data_array must have at least one spatial dimension.") @@ -342,11 +356,11 @@ def __call__( data_array, src=affine, dst=new_affine, - spatial_size=output_shape, + spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, mode=mode, padding_mode=padding_mode, align_corners=align_corners, - dtype=_dtype, + dtype=dtype, ) new_affine = to_affine_nd(affine_np, new_affine) # type: ignore new_affine, *_ = convert_to_dst_type(src=new_affine, dst=affine, dtype=torch.float32) @@ -1209,6 +1223,9 @@ class AffineGrid(Transform): pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). + device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. @@ -1229,6 +1246,7 @@ def __init__( scale_params: Optional[Union[Sequence[float], float]] = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, affine: Optional[NdarrayOrTensor] = None, ) -> None: self.rotate_params = rotate_params @@ -1236,6 +1254,7 @@ def __init__( self.translate_params = translate_params self.scale_params = scale_params self.device = device + self.dtype = dtype self.affine = affine def __call__( @@ -1256,7 +1275,7 @@ def __call__( """ if grid is None: if spatial_size is not None: - grid = create_grid(spatial_size, device=self.device, backend="torch") + grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) else: raise ValueError("Incompatible values: grid=None and spatial_size=None.") @@ -1281,7 +1300,7 @@ def __call__( else: affine = self.affine - grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=float) + grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=self.dtype) affine, *_ = convert_to_dst_type(affine, grid) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) @@ -1464,6 +1483,7 @@ def __init__( padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, ) -> None: """ computes output image using values from `img`, locations from `grid` using pytorch. @@ -1476,7 +1496,13 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. .. deprecated:: 0.6.0 ``as_tensor_output`` is deprecated. @@ -1485,6 +1511,7 @@ def __init__( self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.device = device + self.dtype = dtype def __call__( self, @@ -1492,6 +1519,7 @@ def __call__( grid: Optional[NdarrayOrTensor] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, + dtype: DtypeLike = np.float64, ) -> NdarrayOrTensor: """ Args: @@ -1500,16 +1528,22 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. """ if grid is None: raise ValueError("Unknown grid.") _device = img.device if isinstance(img, torch.Tensor) else self.device img_t: torch.Tensor grid_t: torch.Tensor - img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=torch.float32) # type: ignore + img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=dtype) # type: ignore grid_t, *_ = convert_to_dst_type(grid, img_t) # type: ignore if USE_COMPILED: @@ -1525,14 +1559,16 @@ def __call__( elif _padding_mode == "border": bound = 0 else: - bound = 1 + bound = 2 # dct2 is reflection with align_corners=True _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value + if _interp_mode == "nearest": + _interp = 0 + elif _interp_mode == "bicubic": + _interp = 3 + else: + _interp = 1 # "bilinear" out = grid_pull( - img_t.unsqueeze(0), - grid_t.unsqueeze(0), - bound=bound, - extrapolate=True, - interpolation=1 if _interp_mode == "bilinear" else _interp_mode, + img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=_interp )[0] else: for i, dim in enumerate(img_t.shape[1:]): @@ -1549,7 +1585,7 @@ def __call__( align_corners=True, )[0] out_val: NdarrayOrTensor - out_val, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype) + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32) return out_val @@ -1575,6 +1611,7 @@ def __init__( padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: DtypeLike = np.float32, image_only: bool = False, ) -> None: """ @@ -1609,10 +1646,16 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. image_only: if True return only the image volume, otherwise return (image, affine). .. deprecated:: 0.6.0 @@ -1625,10 +1668,11 @@ def __init__( translate_params=translate_params, scale_params=scale_params, affine=affine, + dtype=dtype, device=device, ) self.image_only = image_only - self.resampler = Resample(device=device) + self.resampler = Resample(device=device, dtype=dtype) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) @@ -1651,6 +1695,9 @@ def __call__( mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 4a8a4b663f..e9c6178183 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -136,7 +136,7 @@ class SpatialResampled(MapTransform, InvertibleTransform): - pass + backend = SpatialResample.backend class Spacingd(MapTransform, InvertibleTransform): @@ -615,6 +615,7 @@ def __init__( padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, as_tensor_output: bool = True, device: Optional[torch.device] = None, + dtype: Union[DtypeLike, torch.dtype] = np.float32, allow_missing_keys: bool = False, ) -> None: """ @@ -654,6 +655,9 @@ def __init__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. + dtype: data type for resampling computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data. To be compatible with other modules, + the output data type is always `float32`. allow_missing_keys: don't raise exception if key is missing. See also: @@ -673,6 +677,7 @@ def __init__( affine=affine, spatial_size=spatial_size, device=device, + dtype=dtype, ) self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) @@ -1345,7 +1350,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float64, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) @@ -1459,7 +1464,7 @@ def __init__( mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, align_corners: Union[Sequence[bool], bool] = False, - dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], Union[DtypeLike, torch.dtype]] = np.float64, + dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float64, allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 8265cd2a72..5690c2e515 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -48,6 +48,7 @@ ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, + get_equivalent_dtype, issequenceiterable, look_up_option, min_version, @@ -604,7 +605,7 @@ def _create_grid_numpy( """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)] - coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=dtype) + coords = np.asarray(np.meshgrid(*ranges, indexing="ij"), dtype=get_equivalent_dtype(dtype, np.ndarray)) if not homogeneous: return coords return np.concatenate([coords, np.ones_like(coords[:1])]) @@ -622,7 +623,13 @@ def _create_grid_torch( """ spacing = spacing or tuple(1.0 for _ in spatial_size) ranges = [ - torch.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d), device=device, dtype=dtype) + torch.linspace( + -(d - 1.0) / 2.0 * s, + (d - 1.0) / 2.0 * s, + int(d), + device=device, + dtype=get_equivalent_dtype(dtype, torch.Tensor), + ) for d, s in zip(spatial_size, spacing) ] coords = meshgrid_ij(*ranges) diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 6a6bf4431f..e4b5dbb258 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -192,6 +192,15 @@ np.array([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), ] ) + TESTS.append( # 5D input + [ + p, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + np.ones((1, 2, 1, 2, 1)), # data + {"affine": np.eye(4)}, + np.array([[[[[1.0], [1.0], [1.0]]], [[[1.0], [1.0], [1.0]]]]]), + ] + ) class TestSpacingCase(unittest.TestCase): @@ -204,7 +213,7 @@ def test_spacing(self, in_type, init_param, img, data_param, expected_output): output_data = output_data.cpu() np.testing.assert_allclose(output_data, expected_output, atol=1e-1, rtol=1e-1) - sr = len(output_data.shape) - 1 + sr = min(len(output_data.shape) - 1, 3) if isinstance(init_param["pixdim"], float): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) From 0b4bc52a502de034966111950eadb7a0332f2916 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 17:41:05 +0000 Subject: [PATCH 04/33] update dict version Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +- monai/transforms/spatial/dictionary.py | 143 +++++++++++++++++++++++++ tests/test_spacing.py | 2 +- 3 files changed, 148 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 3a545a6677..5e3e0e53b2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -90,6 +90,8 @@ class SpatialResample(Transform): Resample input image from the orientation/spacing defined by ``src`` affine into the ones specified by `dst`. """ + backend = [TransformBackends.TORCH] + def __init__( self, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, @@ -221,7 +223,7 @@ def __call__( ) output_data = affine_xform(img_.unsqueeze(0), theta=transform, spatial_size=spatial_size).squeeze(0) if additional_dims: - full_shape = tuple([chns, *spatial_size, *additional_dims]) + full_shape = (chns, *spatial_size, *additional_dims) output_data = output_data.reshape(full_shape) # output dtype float output_data, *_ = convert_to_dst_type(output_data, img, dtype=torch.float32) @@ -233,7 +235,7 @@ class Spacing(Transform): Resample input image into the specified `pixdim`. """ - backend = [TransformBackends.TORCH] + backend = SpatialResample.backend def __init__( self, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e9c6178183..0ce2031bc6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -136,8 +136,151 @@ class SpatialResampled(MapTransform, InvertibleTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. + + This transform assumes the ``data`` dictionary has a key for the input + data's metadata and contains ``src`` and ``dst`` affine required by `SpatialResample`. + The key is formed by ``key_{meta_key_postfix}``. + + see also: + :py:class:`monai.transforms.SpatialResample` + """ + backend = SpatialResample.backend + def __init__( + self, + keys: KeysCollection, + mode: GridSampleModeSequence = GridSampleMode.BILINEAR, + padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, + align_corners: Union[Sequence[bool], bool] = False, + dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64, + meta_keys: Optional[KeysCollection] = None, + meta_key_postfix: str = DEFAULT_POST_FIX, + meta_src_keys: Optional[KeysCollection] = "src_affine", + meta_dst_keys: Optional[KeysCollection] = "dst_affine", + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of string, each element corresponds to a key in ``keys``. + align_corners: Geometrically, we consider the pixels of the input as squares rather than points. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + It also can be a sequence of bool, each element corresponds to a key in ``keys``. + dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + It also can be a sequence of dtypes, each element corresponds to a key in ``keys``. + meta_keys: explicitly indicate the key of the corresponding meta data dictionary. + for example, for data with key `image`, the metadata by default is in `image_meta_dict`. + the meta data is a dictionary object which contains: filename, affine, original_shape, etc. + it can be a sequence of string, map to the `keys`. + if None, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if meta_keys=None, use `key_{postfix}` to to fetch the meta data according + to the key data, default is `meta_dict`, the meta data is a dictionary object. + For example, to handle key `image`, read/write affine matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + meta_src_keys: the key of the corresponding ``src`` affine in the meta data dictionary. + meta_dst_keys: the key of the corresponding ``dst`` affine in the meta data dictionary. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.sp_transform = SpatialResample() + self.mode = ensure_tuple_rep(mode, len(self.keys)) + self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) + if len(self.keys) != len(self.meta_keys): + raise ValueError("meta_keys should have the same length as keys.") + self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) + self.meta_src_keys = ensure_tuple_rep(meta_src_keys, len(self.keys)) + self.meta_dst_keys = ensure_tuple_rep(meta_dst_keys, len(self.keys)) + + def __call__( + self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]] + ) -> Dict[Hashable, NdarrayOrTensor]: + d: Dict = dict(data) + for (key, mode, padding_mode, align_corners, dtype, *metakeyinfo) in self.key_iterator( + d, + self.mode, + self.padding_mode, + self.align_corners, + self.dtype, + self.meta_keys, + self.meta_key_postfix, + self.meta_src_keys, + self.meta_dst_keys, + ): + meta_key, meta_key_postfix, meta_src_key, meta_dst_key = metakeyinfo + meta_key = meta_key or f"{key}_{meta_key_postfix}" + # create metadata if necessary + if meta_key not in d: + d[meta_key] = {meta_src_key: None, meta_dst_key: None} + meta_data = d[meta_key] + original_spatial_shape = d[key].shape[1:] + d[key], meta_data[meta_dst_key] = self.sp_transform( # write dst affine because the dtype might change + img=d[key], + src=meta_data[meta_src_key], + dst=meta_data[meta_dst_key], + spatial_size=None, # None means shape auto inferred + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + dtype=dtype, + ) + self.push_transform( + d, + key, + extra_info={ + "meta_key": meta_key, + "meta_src_key": meta_src_key, + "meta_dst_key": meta_dst_key, + "mode": mode.value if isinstance(mode, Enum) else mode, + "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + }, + orig_size=original_spatial_shape, + ) + + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = deepcopy(dict(data)) + for key, dtype in self.key_iterator(d, self.dtype): + transform = self.get_most_recent_transform(d, key) + # Create inverse transform + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + src_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_src_key"]]] + dst_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_dst_key"]]] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + inverse_transform = SpatialResample() + # Apply inverse + d[key], _ = inverse_transform( + img=d[key], + src=dst_affine, + dst=src_affine, + mode=mode, + padding_mode=padding_mode, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, + dtype=dtype, + spatial_size=orig_size, + ) + # Remove the applied transform + self.pop_transform(d, key) + + return d + class Spacingd(MapTransform, InvertibleTransform): """ diff --git a/tests/test_spacing.py b/tests/test_spacing.py index e4b5dbb258..6d5babaa68 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -195,7 +195,7 @@ TESTS.append( # 5D input [ p, - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, "align_corners": True}, np.ones((1, 2, 1, 2, 1)), # data {"affine": np.eye(4)}, np.array([[[[[1.0], [1.0], [1.0]]], [[[1.0], [1.0], [1.0]]]]]), From 996dacfc3986b568e044ac5fd1dcd9ad78c3abf8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 19:45:20 +0000 Subject: [PATCH 05/33] fixes unit tests Signed-off-by: Wenqi Li --- monai/data/utils.py | 12 +++++++----- monai/transforms/spatial/array.py | 16 ++++++++-------- monai/transforms/spatial/dictionary.py | 6 +++--- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 94fe440f9d..4899695d03 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -27,7 +27,7 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai.config.type_definitions import PathLike +from monai.config.type_definitions import NdarrayOrTensor, PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -740,7 +740,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine, dtype=np.float64) -> np.ndar def reorient_spatial_axes( - data_shape: Sequence[int], init_affine: np.ndarray, target_affine: np.ndarray + data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor ) -> Tuple[np.ndarray, np.ndarray]: """ Given the input ``data_array`` and its corresponding coordinate ``init_affine``, @@ -748,10 +748,12 @@ def reorient_spatial_axes( Returns the transformed array and the updated affine. Note that this function requires external module ``nibabel.orientations``. """ - start_ornt = nib.orientations.io_orientation(init_affine) - target_ornt = nib.orientations.io_orientation(target_affine) + init_affine_, *_ = convert_data_type(init_affine, np.ndarray) + target_affine_, *_ = convert_data_type(target_affine, np.ndarray) + start_ornt = nib.orientations.io_orientation(init_affine_) + target_ornt = nib.orientations.io_orientation(target_affine_) ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - new_affine = init_affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) return ornt_transform, new_affine diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 5e3e0e53b2..10cb7ae685 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -127,11 +127,11 @@ def __call__( src: Optional[NdarrayOrTensor] = None, dst: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, - mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, - padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, - align_corners: bool = False, + mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, + align_corners: Optional[bool] = False, dtype: DtypeLike = np.float64, - ): + ) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Args: img: input image to be resampled. It currently supports channel-first arrays with @@ -216,8 +216,8 @@ def __call__( else: affine_xform = AffineTransform( normalized=False, - mode=mode, - padding_mode=padding_mode, + mode=mode, # type: ignore + padding_mode=padding_mode, # type: ignore align_corners=align_corners, reverse_indexing=True, ) @@ -303,7 +303,7 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - output_spatial_shape: Optional[np.ndarray] = None, + output_spatial_shape: Optional[Union[Sequence[int], int]] = None, ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ Args: @@ -358,7 +358,7 @@ def __call__( data_array, src=affine, dst=new_affine, - spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape, + spatial_size=list(output_shape) if output_spatial_shape is None else output_spatial_shape, mode=mode, padding_mode=padding_mode, align_corners=align_corners, diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0ce2031bc6..43c6c852d6 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -251,6 +251,7 @@ def __call__( }, orig_size=original_spatial_shape, ) + return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) @@ -258,8 +259,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd transform = self.get_most_recent_transform(d, key) # Create inverse transform meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] - src_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_src_key"]]] - dst_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_dst_key"]]] + src_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_src_key"]]] # type: ignore + dst_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_dst_key"]]] # type: ignore mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] @@ -278,7 +279,6 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd ) # Remove the applied transform self.pop_transform(d, key) - return d From 2790ca4f0eab601d40b092859ddb5caa3e08dfc4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 20:02:40 +0000 Subject: [PATCH 06/33] adds docs Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 12 ++++++++++++ monai/data/utils.py | 3 +++ monai/transforms/__init__.py | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 49ed4c9e6c..a8d04efaa5 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -557,6 +557,12 @@ Post-processing Spatial ^^^^^^^ +`SpatialResample` +""""""""""""""""" +.. autoclass:: SpatialResample + :members: + :special-members: __call__ + `Spacing` """"""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacing.png @@ -1398,6 +1404,12 @@ Post-processing (Dict) Spatial (Dict) ^^^^^^^^^^^^^^ +`SpatialResampled` +"""""""""""""""""" +.. autoclass:: SpatialResampled + :members: + :special-members: __call__ + `Spacingd` """""""""" .. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/Spacingd.png diff --git a/monai/data/utils.py b/monai/data/utils.py index 4899695d03..045eeaa79f 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -715,11 +715,14 @@ def to_affine_nd(r: Union[np.ndarray, int], affine, dtype=np.float64) -> np.ndar r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix dtype: data type of the output array. + Raises: ValueError: When ``affine`` dimensions is not 2. ValueError: When ``r`` is nonpositive. + Returns: an (r+1) x (r+1) matrix + """ affine_np: np.ndarray affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] # type: ignore diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9b779ed18e..e13fe138c1 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -306,6 +306,7 @@ Rotate, Rotate90, Spacing, + SpatialResample, Zoom, ) from .spatial.dictionary import ( @@ -360,6 +361,9 @@ Spacingd, SpacingD, SpacingDict, + SpatialResampled, + SpatialResampleD, + SpatialResampleDict, Zoomd, ZoomD, ZoomDict, From 4ef11bf2a5b717386555003bfb53f6a7dfa7da8d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 20:32:51 +0000 Subject: [PATCH 07/33] copy grid for resampling Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 10cb7ae685..9ca8d050c4 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1546,7 +1546,9 @@ def __call__( img_t: torch.Tensor grid_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=dtype) # type: ignore - grid_t, *_ = convert_to_dst_type(grid, img_t) # type: ignore + grid_t = convert_to_dst_type(grid, img_t)[0] # type: ignore + if grid_t is grid: # copy if needed + grid_t = grid_t.clone() if USE_COMPILED: for i, dim in enumerate(img_t.shape[1:]): From f65605ec400253886b41ca713a5b3867ff81fcfe Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 23 Jan 2022 21:14:47 +0000 Subject: [PATCH 08/33] fixes unit tests Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 9ca8d050c4..fcbc1e6896 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1563,7 +1563,7 @@ def __call__( elif _padding_mode == "border": bound = 0 else: - bound = 2 # dct2 is reflection with align_corners=True + bound = 1 # "relection" _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value if _interp_mode == "nearest": _interp = 0 From b3011a7c508127fa065475ce5ac4a94a7b9fe711 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jan 2022 10:11:17 +0000 Subject: [PATCH 09/33] remove normalize coordinates Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 48 ++++++++++++++++++++++++------- tests/test_spacing.py | 4 +-- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index fcbc1e6896..1b603320f2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -22,7 +22,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull -from monai.networks.utils import meshgrid_ij +from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform from monai.transforms.utils import ( @@ -210,8 +210,17 @@ def __call__( if additional_dims: xform_shape = [-1] + list(img.shape[1 : spatial_rank + 1]) img_ = img_.reshape(xform_shape) - if USE_COMPILED and align_corners: - affine_xform = Affine(affine=transform, spatial_size=spatial_size, image_only=True, dtype=_dtype) + if align_corners: + _t_r = torch.diag(torch.ones(len(transform), dtype=transform.dtype, device=transform.device)) + for idx, d_dst in enumerate(spatial_size[:spatial_rank]): + _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 + transform = transform @ _t_r + if not USE_COMPILED: + _t_l = normalize_transform(img.shape[1 : spatial_rank + 1], device=transform.device, dtype=transform.dtype, align_corners=True) + transform = _t_l @ transform + affine_xform = Affine( + affine=transform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype + ) output_data = affine_xform(img_, mode=mode, padding_mode=padding_mode) else: affine_xform = AffineTransform( @@ -1484,6 +1493,7 @@ def __init__( mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, as_tensor_output: bool = True, + norm_coords: bool = True, device: Optional[torch.device] = None, dtype: DtypeLike = np.float32, ) -> None: @@ -1501,6 +1511,10 @@ def __init__( When `USE_COMPILED` is `True`, this argument uses ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. See also: https://docs.monai.io/en/stable/networks.html#grid-pull + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` (for ``monai/csrc`` implementation) or + `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying + resampling API. device: device on which the tensor will be allocated. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If ``None``, use the data type of input data. To be compatible with other modules, @@ -1512,6 +1526,7 @@ def __init__( """ self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) + self.norm_coords = norm_coords self.device = device self.dtype = dtype @@ -1549,11 +1564,15 @@ def __call__( grid_t = convert_to_dst_type(grid, img_t)[0] # type: ignore if grid_t is grid: # copy if needed grid_t = grid_t.clone() + sr = max(len(img_t.shape[1:]), 3) if USE_COMPILED: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] += (dim - 1.0) / 2.0 - grid_t = grid_t[:-1] / grid_t[-1:] + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1:]): + grid_t[i] += (max(dim, 2) - 1.0) / 2.0 + grid_t = grid_t[:sr] / grid_t[-1:] + else: + grid_t = grid_t[:sr] grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) _padding_mode = look_up_option( self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode @@ -1575,9 +1594,12 @@ def __call__( img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=_interp )[0] else: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] = 2.0 * grid_t[i] / (dim - 1.0) - grid_t = grid_t[:-1] / grid_t[-1:] + if self.norm_coords: + for i, dim in enumerate(img_t.shape[1:]): + grid_t[i] = 2.0 * grid_t[i] / (max(2, dim) - 1.0) + grid_t = grid_t[:sr] / grid_t[-1:] + else: + grid_t = grid_t[:sr] index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) grid_t = grid_t[index_ordering] grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) @@ -1613,6 +1635,7 @@ def __init__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION, + norm_coords: bool = True, as_tensor_output: bool = True, device: Optional[torch.device] = None, dtype: DtypeLike = np.float32, @@ -1656,6 +1679,11 @@ def __init__( padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to + `[0, size - 1]` or `[-1, 1]` to be compatible with the underlying resampling API. + If the coordinates are generated by ``monai.transforms.utils.create_grid`` + and the ``affine`` doesn't include the normalization, this argument should be set to ``True``. + If the output `self.affine_grid` is already normalized, this argument should be set to ``False``. device: device on which the tensor will be allocated. dtype: data type for resampling computation. Defaults to ``np.float32``. If ``None``, use the data type of input data. To be compatible with other modules, @@ -1676,7 +1704,7 @@ def __init__( device=device, ) self.image_only = image_only - self.resampler = Resample(device=device, dtype=dtype) + self.resampler = Resample(norm_coords=norm_coords, device=device, dtype=dtype) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 6d5babaa68..0dd10a54f3 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -196,9 +196,9 @@ [ p, {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, "align_corners": True}, - np.ones((1, 2, 1, 2, 1)), # data + np.ones((1, 2, 2, 2, 1)), # data {"affine": np.eye(4)}, - np.array([[[[[1.0], [1.0], [1.0]]], [[[1.0], [1.0], [1.0]]]]]), + np.ones((1, 2, 2, 3, 1)), ] ) From a1200dc27cad4ff970672be59383fdc282df52b1 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 25 Jan 2022 11:48:29 +0000 Subject: [PATCH 10/33] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/spatial/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1b603320f2..c8aea0ec54 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -216,7 +216,9 @@ def __call__( _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 transform = transform @ _t_r if not USE_COMPILED: - _t_l = normalize_transform(img.shape[1 : spatial_rank + 1], device=transform.device, dtype=transform.dtype, align_corners=True) + _t_l = normalize_transform( + img.shape[1 : spatial_rank + 1], device=transform.device, dtype=transform.dtype, align_corners=True + ) transform = _t_l @ transform affine_xform = Affine( affine=transform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype From 7b91be404a9fa0f8a283b48a552e2d1b998bb5bb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jan 2022 10:54:52 +0000 Subject: [PATCH 11/33] try to fix #3621 (#3673) Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c8aea0ec54..7bfae1d6e4 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1566,7 +1566,7 @@ def __call__( grid_t = convert_to_dst_type(grid, img_t)[0] # type: ignore if grid_t is grid: # copy if needed grid_t = grid_t.clone() - sr = max(len(img_t.shape[1:]), 3) + sr = min(len(img_t.shape[1:]), 3) if USE_COMPILED: if self.norm_coords: From ec9aaf37f36fc2f718ff4a8565f60478744b6700 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jan 2022 13:38:21 +0000 Subject: [PATCH 12/33] fixes typing Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7bfae1d6e4..57befa8f38 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -186,42 +186,41 @@ def __call__( output_data, *_ = convert_to_dst_type(img_, img, dtype=torch.float32) return output_data, dst - transform = np.linalg.inv(src) @ dst - transform = to_affine_nd(spatial_rank, transform) + xform = np.linalg.inv(src) @ dst + xform = to_affine_nd(spatial_rank, xform) # no resampling if it's identity transform - if np.allclose(transform, np.diag(np.ones(len(transform))), atol=AFFINE_TOL): + if np.allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL): output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst _dtype = dtype or self.dtype or img.dtype spatial_size = ensure_tuple(spatial_size) + in_spatial_size = list(img.shape[1 : spatial_rank + 1]) if spatial_size[0] == -1: # if the spatial_size == -1 - spatial_size = img.shape[1 : spatial_rank + 1] + spatial_size = in_spatial_size elif spatial_size[0] is None: - spatial_size, _ = compute_shape_offset(img.shape[1 : spatial_rank + 1], src, dst) # type: ignore + spatial_size, _ = compute_shape_offset(in_spatial_size, src, dst) # type: ignore spatial_size = spatial_size[:spatial_rank] chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims # resample img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] # type: ignore - transform = convert_to_dst_type(transform, img_)[0] # type: ignore + xform = convert_to_dst_type(xform, img_)[0] # type: ignore align_corners = self.align_corners if align_corners is None else align_corners mode = look_up_option(mode or self.mode, GridSampleMode) padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) if additional_dims: - xform_shape = [-1] + list(img.shape[1 : spatial_rank + 1]) + xform_shape = [-1] + in_spatial_size img_ = img_.reshape(xform_shape) if align_corners: - _t_r = torch.diag(torch.ones(len(transform), dtype=transform.dtype, device=transform.device)) + _t_r = torch.diag(torch.ones(len(xform), dtype=xform.dtype, device=xform.device)) # type: ignore for idx, d_dst in enumerate(spatial_size[:spatial_rank]): _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 - transform = transform @ _t_r + xform = xform @ _t_r if not USE_COMPILED: - _t_l = normalize_transform( - img.shape[1 : spatial_rank + 1], device=transform.device, dtype=transform.dtype, align_corners=True - ) - transform = _t_l @ transform + _t_l = normalize_transform(in_spatial_size, xform.device, xform.dtype, align_corners=True) # type: ignore + xform = _t_l @ xform affine_xform = Affine( - affine=transform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype + affine=xform, spatial_size=spatial_size, norm_coords=False, image_only=True, dtype=_dtype ) output_data = affine_xform(img_, mode=mode, padding_mode=padding_mode) else: @@ -232,7 +231,7 @@ def __call__( align_corners=align_corners, reverse_indexing=True, ) - output_data = affine_xform(img_.unsqueeze(0), theta=transform, spatial_size=spatial_size).squeeze(0) + output_data = affine_xform(img_.unsqueeze(0), theta=xform, spatial_size=spatial_size).squeeze(0) if additional_dims: full_shape = (chns, *spatial_size, *additional_dims) output_data = output_data.reshape(full_shape) From 0ac197f69bf035d0c68115ffde29bc71d22428c2 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jan 2022 22:03:00 +0000 Subject: [PATCH 13/33] fixes grid_sample, interpolate URLs Signed-off-by: Wenqi Li --- monai/apps/deepgrow/transforms.py | 2 +- monai/data/nifti_saver.py | 6 +- monai/data/nifti_writer.py | 6 +- monai/data/png_saver.py | 2 +- monai/data/png_writer.py | 2 +- monai/handlers/segmentation_saver.py | 6 +- monai/networks/blocks/warp.py | 2 +- monai/networks/layers/spatial_transforms.py | 6 +- monai/transforms/io/array.py | 6 +- monai/transforms/io/dictionary.py | 6 +- monai/transforms/spatial/array.py | 120 ++++++++++---------- monai/transforms/spatial/dictionary.py | 54 ++++----- monai/utils/enums.py | 6 +- 13 files changed, 112 insertions(+), 112 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index a4daef70f6..310931d236 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -780,7 +780,7 @@ class RestoreLabeld(MapTransform): One of the listed string values or a user supplied function for padding. Defaults to ``"constant"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. meta_keys: explicitly indicate the key of the corresponding meta data dictionary. for example, for data with key `image`, the metadata by default is in `image_meta_dict`. diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index f31926cb6c..a5acdd032e 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -60,13 +60,13 @@ def __init__( mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} This option is used when ``resample = True``. Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index 35044977e0..4e7a99f557 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -87,13 +87,13 @@ def write_nifti( mode: {``"bilinear"``, ``"nearest"``} This option is used when ``resample = True``. Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} This option is used when ``resample = True``. Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 2e31597837..a83a560e9f 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -52,7 +52,7 @@ def __init__( resample: whether to resample and resize if providing spatial_shape in the metadata. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"nearest"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. data_root_dir: if not empty, it specifies the beginning parts of the input file's diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 6f3b2ef86e..f1aa5fc5c8 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -39,7 +39,7 @@ def write_png( output_spatial_shape: spatial shape of the output image. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"bicubic"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling to [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index 79ebfd3a22..40bb5f8bed 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -72,16 +72,16 @@ def __init__( - NIfTI files {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files This option is ignored. diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 79fed14834..9fdaab0a48 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -38,7 +38,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``. - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"`` - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html For MONAI C++/CUDA extensions, the possible values are: diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 01e45b2e67..7aa3e110fc 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -446,11 +446,11 @@ def __init__( coordinates. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"zeros"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample - align_corners: see also https://pytorch.org/docs/stable/nn.functional.html#grid-sample. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html. reverse_indexing: whether to reverse the spatial indexing of image and coordinates. set to `False` if `theta` follows pytorch's default "D, H, W" convention. set to `True` if `theta` follows `scipy.ndimage` default "i, j, k" convention. diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 19fafbcbf4..f8aa838439 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -248,16 +248,16 @@ class SaveImage(Transform): - NIfTI files {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files This option is ignored. diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 96850ac0cd..cc6a67593f 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -170,16 +170,16 @@ class SaveImaged(MapTransform): - NIfTI files {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - PNG files This option is ignored. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index ee5d179b00..60e4f564c8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -121,12 +121,12 @@ def __init__( of the original data. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -157,12 +157,12 @@ def __call__( affine (matrix): (N+1)x(N+1) original affine matrix for spatially ND `data_array`. Defaults to identity. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -389,10 +389,10 @@ class Resize(Transform): #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ backend = [TransformBackends.TORCH] @@ -420,10 +420,10 @@ def __call__( img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html Raises: ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. @@ -469,12 +469,12 @@ class Rotate(Transform, ThreadUnsafe): input array is contained completely in the output. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -512,14 +512,14 @@ def __call__( img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -578,7 +578,7 @@ def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]: class Zoom(Transform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. - For details, please see https://pytorch.org/docs/stable/nn.functional.html#interpolate. + For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors as input, and provides an option of preserving the input spatial size. @@ -589,7 +589,7 @@ class Zoom(Transform): If a sequence, zoom should contain one value for each spatial axis. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. @@ -599,7 +599,7 @@ class Zoom(Transform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html keep_size: Should keep original size (padding/slicing if needed), default is True. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -636,7 +636,7 @@ def __call__( img: channel first array, must have shape: (num_channels, H[, W, ..., ]). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. @@ -646,7 +646,7 @@ def __call__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ img_t: torch.Tensor @@ -778,12 +778,12 @@ class RandRotate(RandomizableTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -847,12 +847,12 @@ def __call__( img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html dtype: data type for resampling computation. Defaults to ``self.dtype``. If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. @@ -965,7 +965,7 @@ class RandZoom(RandomizableTransform): If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. @@ -975,7 +975,7 @@ class RandZoom(RandomizableTransform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html keep_size: Should keep original size (pad if needed), default is True. kwargs: other arguments for the `np.pad` or `torch.pad` function. note that `np.pad` treats channel dimension as the first dimension. @@ -1033,7 +1033,7 @@ def __call__( img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D). mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}. @@ -1043,7 +1043,7 @@ def __call__( https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html randomize: whether to execute `randomize()` function first, default to True. """ @@ -1348,10 +1348,10 @@ def __init__( Args: mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. .. deprecated:: 0.6.0 @@ -1375,10 +1375,10 @@ def __call__( grid: shape must be (3, H, W) for 2D or (4, H, W, D) for 3D. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ if grid is None: raise ValueError("Unknown grid.") @@ -1484,10 +1484,10 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. image_only: if True return only the image volume, otherwise return (image, affine). @@ -1526,10 +1526,10 @@ def __call__( if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) grid, affine = self.affine_grid(spatial_size=sp_size) @@ -1596,10 +1596,10 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could accelerate the transform. @@ -1701,10 +1701,10 @@ def __call__( if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d]. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html randomize: whether to execute `randomize()` function first, default to True. """ @@ -1786,10 +1786,10 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. See also: @@ -1849,10 +1849,10 @@ def __call__( the transform will use the spatial size of `img`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) @@ -1942,10 +1942,10 @@ def __init__( to `(32, 32, 64)` if the third spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. See also: @@ -2009,10 +2009,10 @@ def __call__( the transform will use the spatial size of `img`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html randomize: whether to execute `randomize()` function first, default to True. """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) @@ -2057,10 +2057,10 @@ def __init__( Each value in the tuple represents the distort step of the related cell. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. """ @@ -2084,10 +2084,10 @@ def __call__( Each value in the tuple represents the distort step of the related cell. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ distort_steps = self.distort_steps if distort_steps is None else distort_steps @@ -2145,10 +2145,10 @@ def __init__( Defaults to (-0.03, 0.03). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html device: device on which the tensor will be allocated. """ @@ -2184,10 +2184,10 @@ def __call__( img: shape must be (num_channels, H, W[, D]). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html randomize: whether to shuffle the random factors using `randomize()`, default to True. """ if randomize: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 0ab9210aaf..963015f420 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -183,14 +183,14 @@ def __init__( axes against the original ones. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Geometrically, we consider the pixels of the input as squares rather than points. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, @@ -526,11 +526,11 @@ class Resized(MapTransform, InvertibleTransform): #albumentations.augmentations.geometric.resize.LongestMaxSize. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. allow_missing_keys: don't raise exception if key is missing. """ @@ -639,11 +639,11 @@ def __init__( to `(32, 64)` if the second spatial dimension size of img is `64`. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -769,11 +769,11 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. cache_grid: whether to cache the identity sampling grid. If the spatial size is not dynamically defined by input image, enabling this option could @@ -939,11 +939,11 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1075,11 +1075,11 @@ def __init__( This allows 0 to correspond to no change (i.e., a scaling of 1.0). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1311,14 +1311,14 @@ class Rotated(MapTransform, InvertibleTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, @@ -1422,14 +1422,14 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): If it is True, the output shape is the same as the input. Default is True. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"border"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. align_corners: Defaults to False. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool, each element corresponds to a key in ``keys``. dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. To be compatible with other modules, @@ -1548,7 +1548,7 @@ class Zoomd(MapTransform, InvertibleTransform): If a sequence, zoom should contain one value for each spatial axis. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} @@ -1559,7 +1559,7 @@ class Zoomd(MapTransform, InvertibleTransform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. @@ -1648,7 +1648,7 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio. mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} The interpolation mode. Defaults to ``"area"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} @@ -1659,7 +1659,7 @@ class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html align_corners: This only has an effect when mode is 'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None. - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html It also can be a sequence of bool or None, each element corresponds to a key in ``keys``. keep_size: Should keep original size (pad if needed), default is True. allow_missing_keys: don't raise exception if key is missing. @@ -1779,11 +1779,11 @@ def __init__( Each value in the tuple represents the distort step of the related cell. mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. @@ -1829,11 +1829,11 @@ def __init__( Defaults to (-0.03, 0.03). mode: {``"bilinear"``, ``"nearest"``} Interpolation mode to calculate output values. Defaults to ``"bilinear"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} Padding mode for outside grid values. Defaults to ``"reflection"``. - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html It also can be a sequence of string, each element corresponds to a key in ``keys``. device: device on which the tensor will be allocated. allow_missing_keys: don't raise exception if key is missing. diff --git a/monai/utils/enums.py b/monai/utils/enums.py index e1fff184fc..93a6d8f49e 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -58,7 +58,7 @@ class NumpyPadMode(Enum): class GridSampleMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html interpolation mode of `torch.nn.functional.grid_sample` @@ -76,7 +76,7 @@ class GridSampleMode(Enum): class InterpolateMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ NEAREST = "nearest" @@ -119,7 +119,7 @@ class PytorchPadMode(Enum): class GridSamplePadMode(Enum): """ - See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ ZEROS = "zeros" From aee76ee7ad625434ea601cb4acc3dd350bb57c16 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 26 Jan 2022 14:24:17 +0000 Subject: [PATCH 14/33] simplify norm_coords Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b45888df8d..4ed5f4dd60 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1563,18 +1563,17 @@ def __call__( grid_t: torch.Tensor img_t, *_ = convert_data_type(img, torch.Tensor, device=_device, dtype=dtype) # type: ignore grid_t = convert_to_dst_type(grid, img_t)[0] # type: ignore - if grid_t is grid: # copy if needed - grid_t = grid_t.clone() + if grid_t is grid: # copy if needed (convert_data_type converts to contiguous) + grid_t = grid_t.clone(memory_format=torch.contiguous_format) sr = min(len(img_t.shape[1:]), 3) if USE_COMPILED: if self.norm_coords: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] += (max(dim, 2) - 1.0) / 2.0 - grid_t = grid_t[:sr] / grid_t[-1:] + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] else: grid_t = grid_t[:sr] - grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) + grid_t = torch.movedim(grid_t, 0, -1) _padding_mode = look_up_option( self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode ).value @@ -1596,14 +1595,10 @@ def __call__( )[0] else: if self.norm_coords: - for i, dim in enumerate(img_t.shape[1:]): - grid_t[i] = 2.0 * grid_t[i] / (max(2, dim) - 1.0) - grid_t = grid_t[:sr] / grid_t[-1:] - else: - grid_t = grid_t[:sr] - index_ordering: List[int] = list(range(img_t.ndimension() - 2, -1, -1)) - grid_t = grid_t[index_ordering] - grid_t = grid_t.permute(list(range(grid_t.ndimension()))[1:] + [0]) + for i, dim in enumerate(img_t.shape[1 : 1 + sr]): + grid_t[i] = 2.0 / (max(2, dim) - 1.0) * grid_t[i] / grid_t[-1:] + index_ordering: List[int] = list(range(sr - 1, -1, -1)) + grid_t = torch.moveaxis(grid_t[index_ordering], 0, -1) out = torch.nn.functional.grid_sample( img_t.unsqueeze(0), grid_t.unsqueeze(0), From 0dc88af0b19112eb0e3e0f4e145422323aba23eb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 26 Jan 2022 14:29:00 +0000 Subject: [PATCH 15/33] update docstring Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 4ed5f4dd60..2d278d139a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -87,7 +87,11 @@ class SpatialResample(Transform): """ - Resample input image from the orientation/spacing defined by ``src`` affine into the ones specified by `dst`. + Resample input image from the orientation/spacing defined by ``src`` affine matrix into + the ones specified by ``dst`` affine matrix. + + Internally this transform computes the affine transform matrix from ``src`` to ``dst``, + by ``xform = np.linalg.inv(src) @ dst``, and call ``monai.transforms.Affine`` with ``xform``. """ backend = [TransformBackends.TORCH] From d044188262379a7e419f3fd6d77fc18cef4ebe2a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 26 Jan 2022 20:47:21 +0000 Subject: [PATCH 16/33] update moveaxis Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 2d278d139a..7ffbcde0bd 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -34,6 +34,7 @@ create_translate, map_spatial_axes, ) +from monai.transforms.utils_pytorch_numpy_unification import moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -1575,9 +1576,7 @@ def __call__( if self.norm_coords: for i, dim in enumerate(img_t.shape[1 : 1 + sr]): grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] - else: - grid_t = grid_t[:sr] - grid_t = torch.movedim(grid_t, 0, -1) + grid_t = moveaxis(grid_t[:sr], 0, -1) _padding_mode = look_up_option( self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode ).value @@ -1602,7 +1601,7 @@ def __call__( for i, dim in enumerate(img_t.shape[1 : 1 + sr]): grid_t[i] = 2.0 / (max(2, dim) - 1.0) * grid_t[i] / grid_t[-1:] index_ordering: List[int] = list(range(sr - 1, -1, -1)) - grid_t = torch.moveaxis(grid_t[index_ordering], 0, -1) + grid_t = moveaxis(grid_t[index_ordering], 0, -1) # type: ignore out = torch.nn.functional.grid_sample( img_t.unsqueeze(0), grid_t.unsqueeze(0), From 7a424529774f606c75f9458053651a2cb0e3d41d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 26 Jan 2022 22:53:17 +0000 Subject: [PATCH 17/33] spatial sample tests Signed-off-by: Wenqi Li --- monai/csrc/ext.cpp | 2 + monai/networks/layers/spatial_transforms.py | 4 +- monai/transforms/spatial/array.py | 35 +++---- tests/test_spatial_resample.py | 104 ++++++++++++++++++++ 4 files changed, 123 insertions(+), 22 deletions(-) create mode 100644 tests/test_spatial_resample.py diff --git a/monai/csrc/ext.cpp b/monai/csrc/ext.cpp index a2fa8bfc56..ac43e6fd3e 100644 --- a/monai/csrc/ext.cpp +++ b/monai/csrc/ext.cpp @@ -31,6 +31,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::enum_(m, "BoundType") .value("replicate", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("nearest", monai::BoundType::Replicate, "a a a | a b c d | d d d") + .value("border", monai::BoundType::Replicate, "a a a | a b c d | d d d") .value("dct1", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("mirror", monai::BoundType::DCT1, "d c b | a b c d | c b a") .value("dct2", monai::BoundType::DCT2, "c b a | a b c d | d c b") @@ -43,6 +44,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("wrap", monai::BoundType::DFT, "b c d | a b c d | a b c") // .value("sliding", monai::BoundType::Sliding) .value("zero", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") + .value("zeros", monai::BoundType::Zero, "0 0 0 | a b c d | 0 0 0") .export_values(); // resample interpolation mode diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 7aa3e110fc..56f13736b4 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -70,13 +70,13 @@ def grid_pull( `bound` can be an int, a string or a BoundType. Possible values are:: - - 0 or 'replicate' or 'nearest' or BoundType.replicate + - 0 or 'replicate' or 'nearest' or BoundType.replicate or 'border' - 1 or 'dct1' or 'mirror' or BoundType.dct1 - 2 or 'dct2' or 'reflect' or BoundType.dct2 - 3 or 'dst1' or 'antimirror' or BoundType.dst1 - 4 or 'dst2' or 'antireflect' or BoundType.dst2 - 5 or 'dft' or 'wrap' or BoundType.dft - - 7 or 'zero' or BoundType.zero + - 7 or 'zero' or 'zeros' or BoundType.zero A list of values can be provided, in the order [W, H, D], to specify dimension-specific boundary conditions. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7ffbcde0bd..a29b347cb8 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -169,9 +169,9 @@ def __call__( When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, MONAI's resampling implementation will be used. """ - spatial_rank = min(len(img.shape) - 1, 3) if src is None: src = np.eye(4, dtype=np.float64) + spatial_rank = min(len(img.shape) - 1, src.shape[0] - 1, 3) src = to_affine_nd(spatial_rank, src) dst = to_affine_nd(spatial_rank, dst) if dst is not None else src dst, *_ = convert_to_dst_type(dst, dst, dtype=torch.float32) @@ -211,8 +211,8 @@ def __call__( img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] # type: ignore xform = convert_to_dst_type(xform, img_)[0] # type: ignore align_corners = self.align_corners if align_corners is None else align_corners - mode = look_up_option(mode or self.mode, GridSampleMode) - padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) + mode = mode or self.mode + padding_mode = padding_mode or self.padding_mode if additional_dims: xform_shape = [-1] + in_spatial_size img_ = img_.reshape(xform_shape) @@ -1576,25 +1576,20 @@ def __call__( if self.norm_coords: for i, dim in enumerate(img_t.shape[1 : 1 + sr]): grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:] - grid_t = moveaxis(grid_t[:sr], 0, -1) - _padding_mode = look_up_option( - self.padding_mode if padding_mode is None else padding_mode, GridSamplePadMode - ).value - if _padding_mode == "zeros": - bound = 7 - elif _padding_mode == "border": - bound = 0 + grid_t = moveaxis(grid_t[:sr], 0, -1) # type: ignore + _padding_mode = self.padding_mode if padding_mode is None else padding_mode + _padding_mode = _padding_mode.value if isinstance(_padding_mode, GridSamplePadMode) else _padding_mode + bound = 1 if _padding_mode == "reflection" else _padding_mode + _interp_mode = self.mode if mode is None else mode + _interp_mode = _interp_mode.value if isinstance(_interp_mode, GridSampleMode) else _interp_mode + if _interp_mode == "bicubic": + interp = 3 + elif _interp_mode == "bilinear": + interp = 1 else: - bound = 1 # "relection" - _interp_mode = look_up_option(self.mode if mode is None else mode, GridSampleMode).value - if _interp_mode == "nearest": - _interp = 0 - elif _interp_mode == "bicubic": - _interp = 3 - else: - _interp = 1 # "bilinear" + interp = _interp_mode # type: ignore out = grid_pull( - img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=_interp + img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=interp )[0] else: if self.norm_coords: diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py new file mode 100644 index 0000000000..65abfee2e5 --- /dev/null +++ b/tests/test_spatial_resample.py @@ -0,0 +1,104 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResample +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for dtype in (np.float32, np.float64): + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src": p_src(np.eye(3)), + "dst": p(dst) if dst is not None else None, + "dtype": dtype, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) + if ind == 0 + else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for dtype in (np.float32, np.float64): + for align in (True, False): + interp = ("nearest", "bilinear") + if align and USE_COMPILED: + interp = interp + (0, 1) # type: ignore + for interp_mode in interp: + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src": p_src(np.eye(4)), + "dst": p(dst) if dst is not None else None, + "dtype": dtype, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array( + [[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]] + ) + if ind == 0 + else np.array( + [[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(TESTS) + def test_flips(self, init_param, img, data_param, expected_output): + for p in TEST_NDARRAYS: + _img = p(img) + _expected_output = p(expected_output) + output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) + assert_allclose(output_data, _expected_output) + expected_dst = data_param.get("dst") if data_param.get("dst") is not None else data_param.get("src") + assert_allclose(output_dst, expected_dst, type_test=False) + + +if __name__ == "__main__": + unittest.main() From 5ec96a459c9b61b7a18f184be18c4ed546fc2def Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 13:13:26 +0000 Subject: [PATCH 18/33] additional tests spatial resample Signed-off-by: Wenqi Li --- monai/data/utils.py | 10 ++++- monai/transforms/spatial/array.py | 5 ++- tests/test_spatial_resample.py | 67 ++++++++++++++++++++++++++----- 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 045eeaa79f..a42e53e99b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -681,7 +681,10 @@ def compute_shape_offset( corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) corners = in_affine @ corners - inv_mat = np.linalg.inv(out_affine) + try: + inv_mat = np.linalg.inv(out_affine) + except np.linalg.LinAlgError as e: + raise ValueError(f"Affine {out_affine} is not invertible") from e corners_out = inv_mat @ corners corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1) + 1.0) @@ -755,7 +758,10 @@ def reorient_spatial_axes( target_affine_, *_ = convert_data_type(target_affine, np.ndarray) start_ornt = nib.orientations.io_orientation(init_affine_) target_ornt = nib.orientations.io_orientation(target_affine_) - ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + try: + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + except ValueError as e: + raise ValueError(f"The input affine {init_affine} and target affine {target_affine} are not compatible.") from e new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) return ornt_transform, new_affine diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a29b347cb8..84eae529d5 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -191,7 +191,10 @@ def __call__( output_data, *_ = convert_to_dst_type(img_, img, dtype=torch.float32) return output_data, dst - xform = np.linalg.inv(src) @ dst + try: + xform = np.linalg.inv(src) @ dst + except np.linalg.LinAlgError as e: + raise ValueError(f"src affine is not invertible: {src}") from e xform = to_affine_nd(spatial_rank, xform) # no resampling if it's identity transform if np.allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL): diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py index 65abfee2e5..54393af8d5 100644 --- a/tests/test_spatial_resample.py +++ b/tests/test_spatial_resample.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import unittest import numpy as np @@ -38,7 +39,7 @@ np.arange(4).reshape((1, 2, 2)) + 1.0, # data { "src": p_src(np.eye(3)), - "dst": p(dst) if dst is not None else None, + "dst": p(dst), "dtype": dtype, "align_corners": align, "mode": interp_mode, @@ -71,7 +72,7 @@ np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data { "src": p_src(np.eye(4)), - "dst": p(dst) if dst is not None else None, + "dst": p(dst), "dtype": dtype, "align_corners": align, "mode": interp_mode, @@ -89,15 +90,59 @@ class TestSpatialResample(unittest.TestCase): - @parameterized.expand(TESTS) - def test_flips(self, init_param, img, data_param, expected_output): - for p in TEST_NDARRAYS: - _img = p(img) - _expected_output = p(expected_output) - output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) - assert_allclose(output_data, _expected_output) - expected_dst = data_param.get("dst") if data_param.get("dst") is not None else data_param.get("src") - assert_allclose(output_dst, expected_dst, type_test=False) + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips(self, p_type, args): + init_param, img, data_param, expected_output = args + _img = p(img) + _expected_output = p(expected_output) + output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) + assert_allclose(output_data, _expected_output) + expected_dst = data_param.get("dst") if data_param.get("dst") is not None else data_param.get("src") + assert_allclose(output_dst, expected_dst, type_test=False) + + @parameterized.expand(itertools.product([True, False], TEST_NDARRAYS)) + def test_4d_5d(self, is_5d, p_type): + new_shape = (1, 2, 2, 3, 1, 1) if is_5d else (1, 2, 2, 3, 1) + img = np.arange(12).reshape(new_shape) + img = np.tile(img, (1, 1, 1, 1, 2, 2) if is_5d else (1, 1, 1, 1, 2)) + _img = p_type(img) + dst = np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) + output_data, output_dst = SpatialResample(dtype=np.float32)(img=_img, src=p(np.eye(4)), dst=dst) + expected_data = ( + np.asarray( + [ + [ + [[[0.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [1.5, 1.0]], [[1.0, 2.0], [2.0, 2.0]]], + [[[3.0, 3.0], [3.0, 4.0]], [[3.5, 3.0], [4.5, 4.0]], [[4.0, 5.0], [5.0, 5.0]]], + ], + [ + [[[6.0, 6.0], [6.0, 7.0]], [[6.5, 6.0], [7.5, 7.0]], [[7.0, 8.0], [8.0, 8.0]]], + [[[9.0, 9.0], [9.0, 10.0]], [[9.5, 9.0], [10.5, 10.0]], [[10.0, 11.0], [11.0, 11.0]]], + ], + ], + dtype=np.float32, + ) + if is_5d + else np.asarray( + [ + [[[0.5, 0.0], [0.0, 2.0], [1.5, 1.0]], [[3.5, 3.0], [3.0, 5.0], [4.5, 4.0]]], + [[[6.5, 6.0], [6.0, 8.0], [7.5, 7.0]], [[9.5, 9.0], [9.0, 11.0], [10.5, 10.0]]], + ], + dtype=np.float32, + ) + ) + assert_allclose(output_data, p_type(expected_data[None])) + assert_allclose(output_dst, dst, type_test=False) + + def test_ill_affine(self): + img = np.arange(12).reshape(1, 2, 2, 3) + ill_affine = np.asarray( + [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]] + ) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src=np.eye(4), dst=ill_affine) + with self.assertRaises(ValueError): + SpatialResample()(img=img, src=ill_affine, dst=np.eye(3)) if __name__ == "__main__": From 6509070099bfd7387bc149df2468e6a0ab67f27b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 13:51:51 +0000 Subject: [PATCH 19/33] test invert saptial resample Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 + monai/transforms/spatial/dictionary.py | 23 +++-- tests/test_spatial_resampled.py | 117 +++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 8 deletions(-) create mode 100644 tests/test_spatial_resampled.py diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 84eae529d5..a5a00944ca 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -166,6 +166,8 @@ def __call__( If ``None``, use the data type of input data. To be compatible with other modules, the output data type is always `float32`. + The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src) - 1``, and ``3``. + When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, MONAI's resampling implementation will be used. """ diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d9fca9b9bd..4f73c5c915 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -140,10 +140,13 @@ class SpatialResampled(MapTransform, InvertibleTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. This transform assumes the ``data`` dictionary has a key for the input - data's metadata and contains ``src`` and ``dst`` affine required by `SpatialResample`. - The key is formed by ``key_{meta_key_postfix}``. + data's metadata and contains ``src`` and ``dst`` affine required by + `SpatialResample`. The key is formed by ``key_{meta_key_postfix}``. The + transform will swap ``src`` and ``dst`` affine (with potential data type + changes) in the dictionary so that ``src`` always refers to the current + status of affine. - see also: + See also: :py:class:`monai.transforms.SpatialResample` """ @@ -238,6 +241,7 @@ def __call__( align_corners=align_corners, dtype=dtype, ) + meta_data[meta_dst_key], meta_data[meta_src_key] = meta_data[meta_src_key], meta_data[meta_dst_key] self.push_transform( d, key, @@ -259,24 +263,27 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd transform = self.get_most_recent_transform(d, key) # Create inverse transform meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] - src_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_src_key"]]] # type: ignore - dst_affine = meta_data[d[transform[TraceKeys.EXTRA_INFO]["meta_dst_key"]]] # type: ignore + src_key = transform[TraceKeys.EXTRA_INFO]["meta_src_key"] + dst_key = transform[TraceKeys.EXTRA_INFO]["meta_dst_key"] + src_affine = meta_data[src_key] # type: ignore + dst_affine = meta_data[dst_key] # type: ignore mode = transform[TraceKeys.EXTRA_INFO]["mode"] padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] orig_size = transform[TraceKeys.ORIG_SIZE] inverse_transform = SpatialResample() # Apply inverse - d[key], _ = inverse_transform( + d[key], dst_affine = inverse_transform( img=d[key], - src=dst_affine, - dst=src_affine, + src=src_affine, + dst=dst_affine, mode=mode, padding_mode=padding_mode, align_corners=False if align_corners == TraceKeys.NONE else align_corners, dtype=dtype, spatial_size=orig_size, ) + meta_data[src_key], meta_data[dst_key] = dst_affine, meta_data[src_key] # type: ignore # Remove the applied transform self.pop_transform(d, key) return d diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py new file mode 100644 index 0000000000..2b6ca17b84 --- /dev/null +++ b/tests/test_spatial_resampled.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.config import USE_COMPILED +from monai.transforms import SpatialResampleD +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second + np.asarray([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # flip the first + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for dtype in (np.float32, np.float64): + for align in (False, True): + for interp_mode in ("nearest", "bilinear"): + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(4).reshape((1, 2, 2)) + 1.0, # data + { + "src": p_src(np.eye(3)), + "dst": p(dst), + "dtype": dtype, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array([[[2.0, 1.0], [4.0, 3.0]]]) + if ind == 0 + else np.array([[[3.0, 4.0], [1.0, 2.0]]]), + ] + ) + +for ind, dst in enumerate( + [ + np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + np.asarray([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]), + ] +): + for p in TEST_NDARRAYS: + for p_src in TEST_NDARRAYS: + for dtype in (np.float32, np.float64): + for align in (True, False): + interp = ("nearest", "bilinear") + if align and USE_COMPILED: + interp = interp + (0, 1) # type: ignore + for interp_mode in interp: + for padding_mode in ("zeros", "border", "reflection"): + TESTS.append( + [ + {}, # default no params + np.arange(12).reshape((1, 2, 2, 3)) + 1.0, # data + { + "src": p_src(np.eye(4)), + "dst": p(dst), + "dtype": dtype, + "align_corners": align, + "mode": interp_mode, + "padding_mode": padding_mode, + }, + np.array( + [[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]] + ) + if ind == 0 + else np.array( + [[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]] + ), + ] + ) + + +class TestSpatialResample(unittest.TestCase): + @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) + def test_flips_inverse(self, p_type, args): + init_param, img, data_param, expected_output = args + _img = p(img) + _expected_output = p(expected_output) + input_dict = {"img": _img, "img_meta_dict": {"src": data_param.get("src"), "dst": data_param.get("dst")}} + xform = SpatialResampleD( + keys="img", + meta_src_keys="src", + meta_dst_keys="dst", + mode=data_param.get("mode"), + padding_mode=data_param.get("padding_mode"), + ) + output_data = xform(input_dict) + assert_allclose(output_data["img"], _expected_output) + assert_allclose(output_data["img_meta_dict"]["src"], data_param.get("dst"), type_test=False) + + inverted = xform.inverse(output_data) + self.assertEqual(inverted["img_transforms"], []) # no further invert after inverting + assert_allclose(inverted["img_meta_dict"]["src"], data_param.get("src"), type_test=False) + assert_allclose(inverted["img"], _img) + + +if __name__ == "__main__": + unittest.main() From 0705aa74397208e48bfe54c3661c5bbeb6d4bbbf Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 17 Jan 2022 12:49:04 +0000 Subject: [PATCH 20/33] adds a base writer and an itk writer Signed-off-by: Wenqi Li --- docs/source/data.rst | 13 ++ monai/data/__init__.py | 1 + monai/data/image_writer.py | 399 +++++++++++++++++++++++++++++++++++++ monai/data/utils.py | 52 +++-- tests/test_itk_writer.py | 69 +++++++ 5 files changed, 512 insertions(+), 22 deletions(-) create mode 100644 monai/data/image_writer.py create mode 100644 tests/test_itk_writer.py diff --git a/docs/source/data.rst b/docs/source/data.rst index 1e6b535b12..1bb9bc8869 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -150,6 +150,19 @@ WSIReader .. autoclass:: WSIReader :members: +Image writer +------------ + +ImageWriter +~~~~~~~~~~~ +.. autoclass:: ImageWriter + :members: + +ITKWriter +~~~~~~~~~ +.. autoclass:: ITKWriter + :members: + Nifti format handling --------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index bd49f40273..78358ff030 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,6 +35,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_writer import ImageWriter, ITKWriter from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..7a5f88a7fb --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,399 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union + +import numpy as np +import torch + +from monai.apps.utils import get_logger +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data.utils import compute_shape_offset, ensure_mat44, to_affine_nd +from monai.networks.layers import AffineTransform +from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis +from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg + +AFFINE_TOL = 1e-3 +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) + +if TYPE_CHECKING: + import itk # type: ignore + import nibabel as nib + + has_itk = has_nib = has_pil = True +else: + itk, has_itk = optional_import("itk", allow_namespace_pkg=True) + nib, has_nib = optional_import("nibabel") + +__all__ = ["ImageWriter", "ITKWriter"] + + +class ImageWriter: + """ + The class is a collection of utilities to write images to disk. + + Main aspects to be considered are: + + - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions + - ``convert_to_channel_last()`` + - metadata of current affine and output affine, the data array should be converted accordingly + - ``get_meta_info()`` + - ``resample_if_needed()`` + - data type of the output image + - as part of ``resample_if_needed()`` + + Subclasses of this class should implement the backend-specific functions: + + - ``set_data_array()`` to set the data array (numpy array or torch tensor) + - this calls the base class utilities for array handling + - ``set_metadata()`` to set the metadata and output affine + - this calls the base class utilities for metadata (affine, resmaple) handling + - backend-specific data object + - ``create_backend_obj()`` + - backend-specific writing function + - ``write()`` + + The primary usage is: + + .. code-block:: python + + writer = MyWriter() # subclass of ImageWriter + writer.set_data_array(data_array) + writer.set_metadata(meta_dict).write(filename_or_obj) + writer.write(filename) + + Create an image writer object based on ``data_array`` and ``metadata``. + Spatially it supports up to three dimensions (with the resampling step + supports both 2D and 3D). + + When saving multiple time steps or multiple channels `data_array`, time + and/or modality axes should be the at the `channel_dim`. For example, + the shape of a 2D eight-class and channel_dim=0, the segmentation + probabilities to be saved could be `(8, 64, 64)`; in this case + ``data_array`` will be converted to `(64, 64, 1, 8)` (the third + dimension is reserved as a spatial dimension). + + The ``metadata`` could optionally have the following keys: + + - ``'original_affine'``: for data original affine, it will be the + affine of the output object, defaulting to an identity matrix. + - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix. + - ``'spatial_shape'``: for data output spatial shape. + + When ``metadata`` is specified and ``resample=True``, the saver will + try to resample data from the space defined by `"affine"` to the space + defined by `"original_affine"`, for more details, please refer to the + ``convert_to_target_affine`` method. + """ + + def __init__(self, **kwargs): + self.data_obj = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def set_data_array(self, data_array, **kwargs): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def set_metadata(self, meta_dict: Optional[Mapping], **options): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def write(self, filename_or_obj: PathLike, verbose: bool = True, **kwargs): + """subclass should implement this method to call the backend-specific writing APIs.""" + if verbose: + logger.info(f"writing: {filename_or_obj}") + + @classmethod + def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: + """ + Subclass should implement this method to return a backend-specific data representation object. + This method is used by ``cls.create_data_obj`` and the input ``data_array`` is 'channel-last'. + """ + return convert_data_type(data_array, np.ndarray)[0] # type: ignore + + @classmethod + def resample_if_needed( + cls, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, + target_affine: Optional[np.ndarray] = None, + output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Convert the ``data_array`` into the coordinate system specified by + ``target_affine``, from the current coordinate definition of ``affine``. + + If the transform between ``affine`` and ``target_affine`` could be + achieved by simply transposing and flipping ``data_array``, no resampling + will happen. Otherwise, this function resamples ``data_array`` using the + transformation computed from ``affine`` and ``target_affine``. + + This function assumes the NIfTI dimension notations. Spatially it + supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D + respectively. When saving multiple time steps or multiple channels, + time and/or modality axes should be appended after the first three + dimensions. For example, shape of 2D eight-class segmentation + probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in + shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a + single-channel 3D image. The ``convert_to_channel_last`` method can be + used to convert the data to the format described here. + + Note that the shape of the resampled ``data_array`` may subject to some + rounding errors. For example, resampling a 20x20 pixel image from pixel + size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel + image. However, resampling a 20x20-pixel image from pixel size (2.0, + 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where + the image shape is rounded from 13.333x13.333 pixels. In this case + ``output_spatial_shape`` could be specified so that this function + writes image data to a designated shape. + + Args: + data_array: input data array to be converted. + affine: the current affine of ``data_array``. Defaults to identity + target_affine: the designated affine of ``data_array``. + The actual output affine might be different from this value due to precision changes. + output_spatial_shape: spatial shape of the output image. + This option is used when resampling is needed. + mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + This option is used when resampling is needed. + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}. + This option is used when resampling is needed. + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: boolean option of ``grid_sample`` to handle the corner convention. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to + ``np.float64`` for best precision. If ``None``, use the data type of input data. + """ + data: np.ndarray + data, *_ = convert_data_type(data_array, np.ndarray) # type: ignore + + sr = min(data.ndim, 3) + if affine is None: + affine = np.eye(4, dtype=np.float64) + affine = to_affine_nd(sr, affine) # type: ignore + target_affine = to_affine_nd(sr, target_affine) if target_affine is not None else affine + + if np.allclose(affine, target_affine, atol=AFFINE_TOL): + # no affine changes, return (data, affine) + return data, ensure_mat44(target_affine) + + # resolve orientation + if has_nib: # this is to avoid dependency on nibabel + start_ornt = nib.orientations.io_orientation(affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + data_shape = data.shape + data = nib.orientations.apply_orientation(data, ornt_transform) + _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + if np.allclose(_affine, target_affine, atol=AFFINE_TOL): + return data, ensure_mat44(_affine) + + # need resampling + dtype = dtype or data.dtype # type: ignore + if output_spatial_shape is None: + output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] + sp_dims = min(data.ndim, 3) + output_spatial_shape_ += [1] * (sp_dims - len(output_spatial_shape_)) + output_spatial_shape_ = output_spatial_shape_[:sp_dims] + original_channels = data.shape[3:] + if original_channels: # multi channel, resampling each channel + data_np: np.ndarray = data.reshape(list(data.shape[:3]) + [-1]) # type: ignore + data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch + else: # single channel image, need to expand to have a channel + data_np = data[None] + affine_xform = AffineTransform( + normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True + ) + data_torch = affine_xform( + torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(np.linalg.inv(_affine) @ target_affine, dtype=dtype)), + spatial_size=output_spatial_shape_, + ) + data_np = data_torch[0].detach().cpu().numpy() + if original_channels: + data_np = np.moveaxis(data_np, 0, -1) # channel last + data_np = data_np.reshape(list(data_np.shape[:3]) + list(original_channels)) + else: + data_np = data_np[0] + return data_np, ensure_mat44(target_affine) + + @classmethod + def convert_to_channel_last( + cls, + data: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + squeeze_end_dims: bool = True, + spatial_ndim: Optional[int] = 3, + contiguous: bool = False, + ): + """ + Rearrange the data array axes to make the `channel_dim`-th dim the last + dimension and ensure there are three spatial dimensions. If + ``channel_dim`` is ``None``, a new axis will be appended to the last + dimension. + + When ``squeeze_end_dims`` is ``True``, a postprocessing step will be + applied to remove any trailing singleton dimensions. + + Args: + data: input data to be converted to "channel-last" format. + channel_dim: specifies the axis of the data array that is the channel dimension. + ``None`` indicates no channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. + If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. + spatial_ndim: modifying the spatial dims if needed, so that output to have at least + this number of spatial dims. If ``None``, the output will have the same number of + spatial dimensions as the input. + contiguous: if ``True``, the output will be contiguous. + """ + # change data to "channel last" format + if channel_dim is not None: + # _chns = ensure_tuple_rep(channel_dim) + # data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) + data = moveaxis(data, channel_dim, -1) + else: # adds a channel dimension + data = data[..., None] + # To ensure at least three spatial dims + if spatial_ndim: + while len(data.shape) < spatial_ndim + 1: # assuming the data has spatial + channel dims + data = data[..., None, :] + while len(data.shape) > spatial_ndim + 1: + data = data[..., 0, :] + # if desired, remove trailing singleton dimensions + while squeeze_end_dims and data.shape[-1] == 1: + data = np.squeeze(data, -1) + if contiguous: + data = ascontiguousarray(data) + return data + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + """ + Extracts relevant meta information from the metadata object (using ``.get``). + + This method returns the following fields (the default value is ``None``): + + - ``'original_affine'``: for data original affine (before any image processing), + - ``'affine'``: for the current data affine (representing the current coordinate information), + - ``'spatial_shape'``: for data original spatial shape. + """ + if not metadata: + default_dict = {"original_affine": None, "affine": None, "spatial_shape": None} + metadata = default_dict + original_affine = metadata.get("original_affine") + affine = metadata.get("affine") + spatial_shape = metadata.get("spatial_shape") + return original_affine, affine, spatial_shape + + +@require_pkg(pkg_name="itk") +class ITKWriter(ImageWriter): + """ + Write data and metadata into files on disk using ITK-python. + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): + super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) + + def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + contiguous=kwargs.pop("contiguous", True), + ) + self.channel_dim = channel_dim + return self + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + return self + + def write(self, filename_or_obj: PathLike, verbose: bool = False, **kwargs): + super().write(filename_or_obj, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore + ) + itk.imwrite( + self.data_obj, + filename_or_obj, + compression=kwargs.pop("compression", False), + imageio=kwargs.pop("imageio", None), + ) + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + affine: Optional[NdarrayOrTensor] = None, + dtype: DtypeLike = np.float32, + **kwargs, + ): + """create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.""" + + data_array = super().create_backend_obj(data_array) + _is_vec = channel_dim is not None + if _is_vec: + data_array = np.moveaxis(data_array, -1, 0) # from channel last to channel first + data_array = data_array.T.astype(dtype, copy=True, order="C") + itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop("ttype", None)) + + d = len(itk.size(itk_obj)) + # convert affine to LPS + if affine is None: + affine = np.eye(d + 1, dtype=np.float64) + _affine: np.ndarray = convert_data_type(affine, np.ndarray)[0] # type: ignore + _affine = cls.ras_to_lps(to_affine_nd(d, _affine)) + spacing = np.sqrt(np.sum(np.square(_affine[:d, :d]), 0)) + spacing[spacing == 0] = 1.0 + _direction: np.ndarray = np.diag(1 / spacing) + _direction = _affine[:d, :d] @ _direction + itk_obj.SetSpacing(spacing.tolist()) + itk_obj.SetOrigin(_affine[:d, -1].tolist()) + itk_obj.SetDirection(itk.GetMatrixFromArray(_direction)) + return itk_obj + + @staticmethod + def ras_to_lps(affine: NdarrayOrTensor): + """ + Convert the ``affine`` from `RAS` to `LPS` by flipping the first two spatial dimensions. + (This could also be used to convert from `LPS` to `RAS`.) + + Args: + affine: a 2D affine matrix. + """ + sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 + flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] + flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) + if isinstance(affine, torch.Tensor): + return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine + return np.diag(flip_diag).astype(affine.dtype) @ affine diff --git a/monai/data/utils.py b/monai/data/utils.py index 79ef9bd7fb..9fb609d2cc 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -32,7 +32,9 @@ from monai.utils import ( MAX_SEED, BlendMode, + Method, NumpyPadMode, + convert_data_type, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, @@ -42,7 +44,6 @@ look_up_option, optional_import, ) -from monai.utils.enums import Method pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") @@ -679,56 +680,54 @@ def compute_shape_offset( corners = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1)) corners = np.concatenate((corners, np.ones_like(corners[:1]))) corners = in_affine @ corners - corners_out = np.linalg.inv(out_affine) @ corners + inv_mat = np.linalg.inv(out_affine) + corners_out = inv_mat @ corners corners_out = corners_out[:-1] / corners_out[-1] out_shape = np.round(corners_out.ptp(axis=1) + 1.0) - if np.allclose(nib.io_orientation(in_affine), nib.io_orientation(out_affine)): - # same orientation, get translate from the origin - offset = in_affine @ ([0] * sr + [1]) - offset = offset[:-1] / offset[-1] - else: - # different orientation, the min is the origin - corners = corners[:-1] / corners[-1] - offset = np.min(corners, 1) + mat = inv_mat[:-1, :-1] + k = 0 + for i in range(corners.shape[1]): + min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) + if np.allclose(min_corner, 0.0): + k = i + break + offset = corners[:-1, k] return out_shape.astype(int, copy=False), offset -def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: +def to_affine_nd(r: Union[np.ndarray, int], affine, dtype=np.float64) -> np.ndarray: """ Using elements from affine, to create a new affine matrix by assigning the rotation/zoom/scaling matrix and the translation vector. - when ``r`` is an integer, output is an (r+1)x(r+1) matrix, where the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(r, len(affine) - 1)`. - - when ``r`` is an affine matrix, the output has the same as ``r``, - the top left kxk elements are copied from ``affine``, + when ``r`` is an affine matrix, the output has the same shape as ``r``, + and the top left kxk elements are copied from ``affine``, the last column of the output affine is copied from ``affine``'s last column. `k` is determined by `min(len(r) - 1, len(affine) - 1)`. - Args: r (int or matrix): number of spatial dimensions or an output affine to be filled. affine (matrix): 2D affine matrix - + dtype: data type of the output array. Raises: ValueError: When ``affine`` dimensions is not 2. ValueError: When ``r`` is nonpositive. - Returns: an (r+1) x (r+1) matrix - """ - affine_np = np.array(affine, dtype=np.float64) + affine_np: np.ndarray + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] # type: ignore + affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") - new_affine = np.array(r, dtype=np.float64, copy=True) + new_affine = np.array(r, dtype=dtype, copy=True) if new_affine.ndim == 0: sr: int = int(new_affine.astype(np.uint)) if not np.isfinite(sr) or sr < 0: raise ValueError(f"r must be positive, got {sr}.") - new_affine = np.eye(sr + 1, dtype=np.float64) + new_affine = np.eye(sr + 1, dtype=dtype) d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1) new_affine[:d, :d] = affine_np[:d, :d] if d > 1: @@ -736,6 +735,15 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: np.ndarray) -> np.ndarray: return new_affine +def ensure_mat44(affine, dtype=np.float64) -> np.ndarray: + """ + Given a matrix `affine`, ensure that it is a float64 4x4 matrix using `to_affine_nd`. + """ + if affine is None: + return np.eye(4, dtype=dtype) + return to_affine_nd(r=3, affine=affine, dtype=dtype) + + def create_file_basename( postfix: str, input_file_name: PathLike, diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..3ce696f0ac --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 +1,69 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import ITKWriter +from monai.utils import optional_import +from tests.utils import TEST_NDARRAYS, assert_allclose + +itk, has_itk = optional_import("itk") +nib, has_nibabel = optional_import("nibabel") + +TEST_CASES_AFFINE = [] +for p in TEST_NDARRAYS: + case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_1d) + case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_2d_1) + case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( + [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] + ) + TEST_CASES_AFFINE.append(case_2d_2) + case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( + [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] + ) + TEST_CASES_AFFINE.append(case_3d) + case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) + TEST_CASES_AFFINE.append(case_4d) + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +@unittest.skipUnless(has_nibabel, "Requires `nibabel` package.") +class TestITKWriter(unittest.TestCase): + def test_channel_shape(self): + with tempfile.TemporaryDirectory() as tempdir: + for c in (0, 1, 2, 3): + fname = os.path.join(tempdir, f"testing{c}.nii") + ( + ITKWriter() + .set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) + .set_metadata({}) + .write(fname) + ) + itk_obj = itk.imread(fname) + s = [1, 2, 3, 4] + s.pop(c) + np.testing.assert_allclose(itk.size(itk_obj), s) + + @parameterized.expand(TEST_CASES_AFFINE) + def test_ras_to_lps(self, param, expected): + assert_allclose(ITKWriter.ras_to_lps(param), expected) + + +if __name__ == "__main__": + unittest.main() From df2631a234529e56ebbf0631d07ea5df3e96fba5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 18 Jan 2022 15:22:30 +0000 Subject: [PATCH 21/33] update docstrings Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 7a5f88a7fb..911e45591e 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -45,7 +45,7 @@ class ImageWriter: - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions - ``convert_to_channel_last()`` - - metadata of current affine and output affine, the data array should be converted accordingly + - metadata of the current affine and output affine, the data array should be converted accordingly - ``get_meta_info()`` - ``resample_if_needed()`` - data type of the output image @@ -53,10 +53,10 @@ class ImageWriter: Subclasses of this class should implement the backend-specific functions: - - ``set_data_array()`` to set the data array (numpy array or torch tensor) - - this calls the base class utilities for array handling + - ``set_data_array()`` to set the data array (input must be numpy array or torch tensor) + - this method sets the backend object's data part - ``set_metadata()`` to set the metadata and output affine - - this calls the base class utilities for metadata (affine, resmaple) handling + - this method sets the metadata including affine handling and image resampling - backend-specific data object - ``create_backend_obj()`` - backend-specific writing function @@ -92,10 +92,11 @@ class ImageWriter: When ``metadata`` is specified and ``resample=True``, the saver will try to resample data from the space defined by `"affine"` to the space defined by `"original_affine"`, for more details, please refer to the - ``convert_to_target_affine`` method. + ``resample_if_needed`` method. """ def __init__(self, **kwargs): + """the constructor supports adding new instance members.""" self.data_obj = None for k, v in kwargs.items(): setattr(self, k, v) @@ -115,7 +116,7 @@ def write(self, filename_or_obj: PathLike, verbose: bool = True, **kwargs): def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: """ Subclass should implement this method to return a backend-specific data representation object. - This method is used by ``cls.create_data_obj`` and the input ``data_array`` is 'channel-last'. + This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'. """ return convert_data_type(data_array, np.ndarray)[0] # type: ignore @@ -313,6 +314,7 @@ def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): + """Convert ``data_array`` into 'channel-last' numpy ndarray.""" self.data_obj = self.convert_to_channel_last( data=data_array, channel_dim=channel_dim, @@ -324,6 +326,7 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end return self def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray.""" original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) self.data_obj, self.affine = self.resample_if_needed( data_array=self.data_obj, @@ -338,6 +341,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru return self def write(self, filename_or_obj: PathLike, verbose: bool = False, **kwargs): + """Create an ITK object from ``self.data_obj`` and call ``itk.imwrite``""" super().write(filename_or_obj, verbose=verbose) self.data_obj = self.create_backend_obj( self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore @@ -358,8 +362,7 @@ def create_backend_obj( dtype: DtypeLike = np.float32, **kwargs, ): - """create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.""" - + """Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.""" data_array = super().create_backend_obj(data_array) _is_vec = channel_dim is not None if _is_vec: From 4419803d3d4c4a7cb8596f5e766b5a749aad816a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 18 Jan 2022 21:35:20 +0000 Subject: [PATCH 22/33] remove return self Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 2 -- tests/test_itk_writer.py | 10 ++++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 911e45591e..297600fe53 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -323,7 +323,6 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end contiguous=kwargs.pop("contiguous", True), ) self.channel_dim = channel_dim - return self def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): """Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray.""" @@ -338,7 +337,6 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru align_corners=options.pop("align_corners", False), dtype=options.pop("dtype", np.float64), ) - return self def write(self, filename_or_obj: PathLike, verbose: bool = False, **kwargs): """Create an ITK object from ``self.data_obj`` and call ``itk.imwrite``""" diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index 3ce696f0ac..a5179d82ea 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -49,12 +49,10 @@ def test_channel_shape(self): with tempfile.TemporaryDirectory() as tempdir: for c in (0, 1, 2, 3): fname = os.path.join(tempdir, f"testing{c}.nii") - ( - ITKWriter() - .set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) - .set_metadata({}) - .write(fname) - ) + itk_writer = ITKWriter() + itk_writer.set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) + itk_writer.set_metadata({}) + itk_writer.write(fname) itk_obj = itk.imread(fname) s = [1, 2, 3, 4] s.pop(c) From 4f27d0d0c7ad4f971a653e0e269e315001d0a746 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 18 Jan 2022 22:08:28 +0000 Subject: [PATCH 23/33] adds reorient_spatial_axes Signed-off-by: Wenqi Li --- monai/data/__init__.py | 1 + monai/data/image_writer.py | 9 ++------- monai/data/utils.py | 17 +++++++++++++++++ tests/test_itk_writer.py | 1 - 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 78358ff030..f37241e937 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -66,6 +66,7 @@ partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, + reorient_spatial_axes, resample_datalist, select_cross_validation_folds, set_rnd, diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 297600fe53..33d13ce336 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -16,7 +16,7 @@ from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike -from monai.data.utils import compute_shape_offset, ensure_mat44, to_affine_nd +from monai.data.utils import compute_shape_offset, ensure_mat44, reorient_spatial_axes, to_affine_nd from monai.networks.layers import AffineTransform from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg @@ -195,12 +195,7 @@ def resample_if_needed( # resolve orientation if has_nib: # this is to avoid dependency on nibabel - start_ornt = nib.orientations.io_orientation(affine) - target_ornt = nib.orientations.io_orientation(target_affine) - ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - data_shape = data.shape - data = nib.orientations.apply_orientation(data, ornt_transform) - _affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + data, _affine = reorient_spatial_axes(data, affine, target_affine) if np.allclose(_affine, target_affine, atol=AFFINE_TOL): return data, ensure_mat44(_affine) diff --git a/monai/data/utils.py b/monai/data/utils.py index 9fb609d2cc..663e28f270 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -79,6 +79,7 @@ "no_collation", "convert_tables_to_dicts", "SUPPORTED_PICKLE_MOD", + "reorient_spatial_axes", ] # module to be used by `torch.save` @@ -744,6 +745,22 @@ def ensure_mat44(affine, dtype=np.float64) -> np.ndarray: return to_affine_nd(r=3, affine=affine, dtype=dtype) +def reorient_spatial_axes(data_array: np.ndarray, init_affine: np.ndarray, target_affine: np.ndarray): + """ + Given the input ``data_array`` and its corresponding coordinate ``init_affine``, + convert the array to ``target_affine`` by rearranging/flipping the axes. + Returns the transformed array and the updated affine. + + Note that this function requires external module ``nibabel.orientations``. + """ + start_ornt = nib.orientations.io_orientation(init_affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + new_affine = init_affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_array.shape) + data = nib.orientations.apply_orientation(data_array, ornt_transform) + return data, new_affine + + def create_file_basename( postfix: str, input_file_name: PathLike, diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index a5179d82ea..9913e78460 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -43,7 +43,6 @@ @unittest.skipUnless(has_itk, "Requires `itk` package.") -@unittest.skipUnless(has_nibabel, "Requires `nibabel` package.") class TestITKWriter(unittest.TestCase): def test_channel_shape(self): with tempfile.TemporaryDirectory() as tempdir: From 9f2d31189c40995382d4bd52890882a1072531b0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 20 Jan 2022 10:17:04 +0000 Subject: [PATCH 24/33] update based on comments Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 65 ++++++++++++++++---------------------- monai/data/utils.py | 22 +++++++++++-- tests/test_itk_writer.py | 23 -------------- tests/test_ori_ras_lps.py | 46 +++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 63 deletions(-) create mode 100644 tests/test_ori_ras_lps.py diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 33d13ce336..1f8f75dba1 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -16,7 +16,14 @@ from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike -from monai.data.utils import compute_shape_offset, ensure_mat44, reorient_spatial_axes, to_affine_nd +from monai.data.utils import ( + compute_shape_offset, + ensure_mat44, + ensure_tuple, + orientation_ras_lps, + reorient_spatial_axes, + to_affine_nd, +) from monai.networks.layers import AffineTransform from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg @@ -195,14 +202,14 @@ def resample_if_needed( # resolve orientation if has_nib: # this is to avoid dependency on nibabel - data, _affine = reorient_spatial_axes(data, affine, target_affine) - if np.allclose(_affine, target_affine, atol=AFFINE_TOL): - return data, ensure_mat44(_affine) + data, affine = reorient_spatial_axes(data, affine, target_affine) + if np.allclose(affine, target_affine, atol=AFFINE_TOL): + return data, ensure_mat44(affine) # need resampling dtype = dtype or data.dtype # type: ignore if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, _affine, target_affine) + output_spatial_shape, _ = compute_shape_offset(data.shape, affine, target_affine) output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] sp_dims = min(data.ndim, 3) output_spatial_shape_ += [1] * (sp_dims - len(output_spatial_shape_)) @@ -218,7 +225,7 @@ def resample_if_needed( ) data_torch = affine_xform( torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(np.linalg.inv(_affine) @ target_affine, dtype=dtype)), + torch.as_tensor(np.ascontiguousarray(np.linalg.inv(affine) @ target_affine, dtype=dtype)), spatial_size=output_spatial_shape_, ) data_np = data_torch[0].detach().cpu().numpy() @@ -233,24 +240,25 @@ def resample_if_needed( def convert_to_channel_last( cls, data: NdarrayOrTensor, - channel_dim: Optional[int] = 0, + channel_dim: Union[None, int, Sequence[int]] = 0, squeeze_end_dims: bool = True, spatial_ndim: Optional[int] = 3, contiguous: bool = False, ): """ Rearrange the data array axes to make the `channel_dim`-th dim the last - dimension and ensure there are three spatial dimensions. If - ``channel_dim`` is ``None``, a new axis will be appended to the last - dimension. + dimension and ensure there are ``spatial_ndim`` number of spatial + dimensions. If ``channel_dim`` is ``None``, a new axis will be appended + to the last dimension. When ``squeeze_end_dims`` is ``True``, a postprocessing step will be applied to remove any trailing singleton dimensions. Args: data: input data to be converted to "channel-last" format. - channel_dim: specifies the axis of the data array that is the channel dimension. - ``None`` indicates no channel dimension. + channel_dim: specifies the channel axes of the data array to move to the last. + ``None`` indicates no channel dimension, + a sequence of integers indicates multiple non-spatial dimensions. squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. @@ -261,12 +269,11 @@ def convert_to_channel_last( """ # change data to "channel last" format if channel_dim is not None: - # _chns = ensure_tuple_rep(channel_dim) - # data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) - data = moveaxis(data, channel_dim, -1) + _chns = ensure_tuple(channel_dim) + data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) else: # adds a channel dimension data = data[..., None] - # To ensure at least three spatial dims + # To ensure at least ``spatial_ndim`` number of spatial dims if spatial_ndim: while len(data.shape) < spatial_ndim + 1: # assuming the data has spatial + channel dims data = data[..., None, :] @@ -291,8 +298,7 @@ def get_meta_info(cls, metadata: Optional[Mapping] = None): - ``'spatial_shape'``: for data original spatial shape. """ if not metadata: - default_dict = {"original_affine": None, "affine": None, "spatial_shape": None} - metadata = default_dict + metadata = {"original_affine": None, "affine": None, "spatial_shape": None} original_affine = metadata.get("original_affine") affine = metadata.get("affine") spatial_shape = metadata.get("spatial_shape") @@ -364,32 +370,15 @@ def create_backend_obj( itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop("ttype", None)) d = len(itk.size(itk_obj)) - # convert affine to LPS if affine is None: affine = np.eye(d + 1, dtype=np.float64) - _affine: np.ndarray = convert_data_type(affine, np.ndarray)[0] # type: ignore - _affine = cls.ras_to_lps(to_affine_nd(d, _affine)) + _affine = convert_data_type(affine, np.ndarray)[0] + _affine = orientation_ras_lps(to_affine_nd(d, _affine)) spacing = np.sqrt(np.sum(np.square(_affine[:d, :d]), 0)) spacing[spacing == 0] = 1.0 _direction: np.ndarray = np.diag(1 / spacing) - _direction = _affine[:d, :d] @ _direction + _direction = _affine[:d, :d] @ _direction # type: ignore itk_obj.SetSpacing(spacing.tolist()) itk_obj.SetOrigin(_affine[:d, -1].tolist()) itk_obj.SetDirection(itk.GetMatrixFromArray(_direction)) return itk_obj - - @staticmethod - def ras_to_lps(affine: NdarrayOrTensor): - """ - Convert the ``affine`` from `RAS` to `LPS` by flipping the first two spatial dimensions. - (This could also be used to convert from `LPS` to `RAS`.) - - Args: - affine: a 2D affine matrix. - """ - sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 - flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] - flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) - if isinstance(affine, torch.Tensor): - return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine - return np.diag(flip_diag).astype(affine.dtype) @ affine diff --git a/monai/data/utils.py b/monai/data/utils.py index 663e28f270..0a16e7193f 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -27,7 +27,7 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai.config.type_definitions import PathLike +from monai.config.type_definitions import NdarrayOrTensor, PathLike from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -745,7 +745,9 @@ def ensure_mat44(affine, dtype=np.float64) -> np.ndarray: return to_affine_nd(r=3, affine=affine, dtype=dtype) -def reorient_spatial_axes(data_array: np.ndarray, init_affine: np.ndarray, target_affine: np.ndarray): +def reorient_spatial_axes( + data_array: np.ndarray, init_affine: np.ndarray, target_affine: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: """ Given the input ``data_array`` and its corresponding coordinate ``init_affine``, convert the array to ``target_affine`` by rearranging/flipping the axes. @@ -1250,3 +1252,19 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def orientation_ras_lps(affine: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Convert the ``affine`` between the `RAS` and `LPS` orientation + by flipping the first two spatial dimensions. + + Args: + affine: a 2D affine matrix. + """ + sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 + flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] + flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) + if isinstance(affine, torch.Tensor): + return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine + return np.diag(flip_diag).astype(affine.dtype) @ affine # type: ignore diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index 9913e78460..96231c9b69 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -15,32 +15,13 @@ import numpy as np import torch -from parameterized import parameterized from monai.data import ITKWriter from monai.utils import optional_import -from tests.utils import TEST_NDARRAYS, assert_allclose itk, has_itk = optional_import("itk") nib, has_nibabel = optional_import("nibabel") -TEST_CASES_AFFINE = [] -for p in TEST_NDARRAYS: - case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) - TEST_CASES_AFFINE.append(case_1d) - case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) - TEST_CASES_AFFINE.append(case_2d_1) - case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( - [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] - ) - TEST_CASES_AFFINE.append(case_2d_2) - case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( - [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] - ) - TEST_CASES_AFFINE.append(case_3d) - case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) - TEST_CASES_AFFINE.append(case_4d) - @unittest.skipUnless(has_itk, "Requires `itk` package.") class TestITKWriter(unittest.TestCase): @@ -57,10 +38,6 @@ def test_channel_shape(self): s.pop(c) np.testing.assert_allclose(itk.size(itk_obj), s) - @parameterized.expand(TEST_CASES_AFFINE) - def test_ras_to_lps(self, param, expected): - assert_allclose(ITKWriter.ras_to_lps(param), expected) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py new file mode 100644 index 0000000000..4ed223bf5b --- /dev/null +++ b/tests/test_ori_ras_lps.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.utils import orientation_ras_lps +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES_AFFINE = [] +for p in TEST_NDARRAYS: + case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_1d) + case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_2d_1) + case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( + [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] + ) + TEST_CASES_AFFINE.append(case_2d_2) + case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( + [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] + ) + TEST_CASES_AFFINE.append(case_3d) + case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) + TEST_CASES_AFFINE.append(case_4d) + + +class TestITKWriter(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE) + def test_ras_to_lps(self, param, expected): + assert_allclose(orientation_ras_lps(param), expected) + + +if __name__ == "__main__": + unittest.main() From f2a2ea9a0f4f6a22f4505a11a681abaa9737fede Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 15:24:43 +0000 Subject: [PATCH 25/33] update based on comments Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 59 +++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 1f8f75dba1..1232fd8d6c 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -103,7 +103,11 @@ class ImageWriter: """ def __init__(self, **kwargs): - """the constructor supports adding new instance members.""" + """ + The constructor supports adding new instance members. + The current member in the base class is ``self.data_obj``, the subclases can add more members, + so that necessary meta information can be stored in the object and shared among the class methods. + """ self.data_obj = None for k, v in kwargs.items(): setattr(self, k, v) @@ -311,11 +315,27 @@ class ITKWriter(ImageWriter): Write data and metadata into files on disk using ITK-python. """ - def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): - super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) + def __init__(self, output_dtype: DtypeLike = np.float32, affine=None, channel_dim=0, **kwargs): + """ + Args: + output_dtype: output data type. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype``, ``self.affine``, ``self.channel_dim`` internally. + """ + super().__init__(output_dtype=output_dtype, affine=affine, channel_dim=channel_dim, **kwargs) def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): - """Convert ``data_array`` into 'channel-last' numpy ndarray.""" + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim`` and ``contiguous``. + """ self.data_obj = self.convert_to_channel_last( data=data_array, channel_dim=channel_dim, @@ -326,7 +346,15 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end self.channel_dim = channel_dim def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): - """Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray.""" + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``. + """ original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) self.data_obj, self.affine = self.resample_if_needed( data_array=self.data_obj, @@ -340,7 +368,15 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru ) def write(self, filename_or_obj: PathLike, verbose: bool = False, **kwargs): - """Create an ITK object from ``self.data_obj`` and call ``itk.imwrite``""" + """ + Create an ITK object from ``self.data_obj`` and call ``itk.imwrite``. + + Args: + filename_or_obj: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: keyword arguments passed to ``itk.imwrite``, + currently support ``compression`` and ``imageio``. + """ super().write(filename_or_obj, verbose=verbose) self.data_obj = self.create_backend_obj( self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore @@ -361,7 +397,16 @@ def create_backend_obj( dtype: DtypeLike = np.float32, **kwargs, ): - """Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.""" + """ + Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``. + affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`. + dtype: output data type. + kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary. + """ data_array = super().create_backend_obj(data_array) _is_vec = channel_dim is not None if _is_vec: From bc9c705a0da641f910295d55814d17074d0868de Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 17:35:31 +0000 Subject: [PATCH 26/33] fixes unit tests Signed-off-by: Wenqi Li --- tests/test_spatial_resample.py | 13 +++++++------ tests/test_spatial_resampled.py | 14 ++++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py index 54393af8d5..8aa14fe242 100644 --- a/tests/test_spatial_resample.py +++ b/tests/test_spatial_resample.py @@ -61,10 +61,11 @@ for p_src in TEST_NDARRAYS: for dtype in (np.float32, np.float64): for align in (True, False): - interp = ("nearest", "bilinear") if align and USE_COMPILED: - interp = interp + (0, 1) # type: ignore - for interp_mode in interp: + interp = ("nearest", "bilinear", 0, 1) # type: ignore + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore for padding_mode in ("zeros", "border", "reflection"): TESTS.append( [ @@ -93,8 +94,8 @@ class TestSpatialResample(unittest.TestCase): @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) def test_flips(self, p_type, args): init_param, img, data_param, expected_output = args - _img = p(img) - _expected_output = p(expected_output) + _img = p_type(img) + _expected_output = p_type(expected_output) output_data, output_dst = SpatialResample(**init_param)(img=_img, **data_param) assert_allclose(output_data, _expected_output) expected_dst = data_param.get("dst") if data_param.get("dst") is not None else data_param.get("src") @@ -107,7 +108,7 @@ def test_4d_5d(self, is_5d, p_type): img = np.tile(img, (1, 1, 1, 1, 2, 2) if is_5d else (1, 1, 1, 1, 2)) _img = p_type(img) dst = np.asarray([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]]) - output_data, output_dst = SpatialResample(dtype=np.float32)(img=_img, src=p(np.eye(4)), dst=dst) + output_data, output_dst = SpatialResample(dtype=np.float32)(img=_img, src=p_type(np.eye(4)), dst=dst) expected_data = ( np.asarray( [ diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py index 2b6ca17b84..9515bfe48a 100644 --- a/tests/test_spatial_resampled.py +++ b/tests/test_spatial_resampled.py @@ -61,10 +61,11 @@ for p_src in TEST_NDARRAYS: for dtype in (np.float32, np.float64): for align in (True, False): - interp = ("nearest", "bilinear") if align and USE_COMPILED: - interp = interp + (0, 1) # type: ignore - for interp_mode in interp: + interp = ("nearest", "bilinear", 0, 1) # type: ignore + else: + interp = ("nearest", "bilinear") # type: ignore + for interp_mode in interp: # type: ignore for padding_mode in ("zeros", "border", "reflection"): TESTS.append( [ @@ -92,9 +93,9 @@ class TestSpatialResample(unittest.TestCase): @parameterized.expand(itertools.product(TEST_NDARRAYS, TESTS)) def test_flips_inverse(self, p_type, args): - init_param, img, data_param, expected_output = args - _img = p(img) - _expected_output = p(expected_output) + _, img, data_param, expected_output = args + _img = p_type(img) + _expected_output = p_type(expected_output) input_dict = {"img": _img, "img_meta_dict": {"src": data_param.get("src"), "dst": data_param.get("dst")}} xform = SpatialResampleD( keys="img", @@ -102,6 +103,7 @@ def test_flips_inverse(self, p_type, args): meta_dst_keys="dst", mode=data_param.get("mode"), padding_mode=data_param.get("padding_mode"), + align_corners=data_param.get("align_corners"), ) output_data = xform(input_dict) assert_allclose(output_data["img"], _expected_output) From f96614b275b12e74ebcbf7666300a9c687ed4370 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 27 Jan 2022 22:11:33 +0000 Subject: [PATCH 27/33] sync 3701 Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 73 ++++--------------------------- monai/transforms/spatial/array.py | 2 +- 2 files changed, 10 insertions(+), 65 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 1232fd8d6c..9da9bb4104 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -12,19 +12,11 @@ from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union import numpy as np -import torch from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike -from monai.data.utils import ( - compute_shape_offset, - ensure_mat44, - ensure_tuple, - orientation_ras_lps, - reorient_spatial_axes, - to_affine_nd, -) -from monai.networks.layers import AffineTransform +from monai.data.utils import ensure_tuple, orientation_ras_lps, to_affine_nd +from monai.transforms.spatial.array import SpatialResample from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg @@ -34,12 +26,8 @@ if TYPE_CHECKING: import itk # type: ignore - import nibabel as nib - - has_itk = has_nib = has_pil = True else: - itk, has_itk = optional_import("itk", allow_namespace_pkg=True) - nib, has_nib = optional_import("nibabel") + itk, _ = optional_import("itk", allow_namespace_pkg=True) __all__ = ["ImageWriter", "ITKWriter"] @@ -136,8 +124,8 @@ def resample_if_needed( cls, data_array: NdarrayOrTensor, affine: Optional[NdarrayOrTensor] = None, - target_affine: Optional[np.ndarray] = None, - output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, + target_affine: Optional[NdarrayOrTensor] = None, + output_spatial_shape: Union[Sequence[int], np.ndarray, int, None] = None, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, align_corners: bool = False, @@ -191,54 +179,11 @@ def resample_if_needed( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If ``None``, use the data type of input data. """ - data: np.ndarray - data, *_ = convert_data_type(data_array, np.ndarray) # type: ignore - - sr = min(data.ndim, 3) - if affine is None: - affine = np.eye(4, dtype=np.float64) - affine = to_affine_nd(sr, affine) # type: ignore - target_affine = to_affine_nd(sr, target_affine) if target_affine is not None else affine - - if np.allclose(affine, target_affine, atol=AFFINE_TOL): - # no affine changes, return (data, affine) - return data, ensure_mat44(target_affine) - - # resolve orientation - if has_nib: # this is to avoid dependency on nibabel - data, affine = reorient_spatial_axes(data, affine, target_affine) - if np.allclose(affine, target_affine, atol=AFFINE_TOL): - return data, ensure_mat44(affine) - - # need resampling - dtype = dtype or data.dtype # type: ignore - if output_spatial_shape is None: - output_spatial_shape, _ = compute_shape_offset(data.shape, affine, target_affine) - output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] - sp_dims = min(data.ndim, 3) - output_spatial_shape_ += [1] * (sp_dims - len(output_spatial_shape_)) - output_spatial_shape_ = output_spatial_shape_[:sp_dims] - original_channels = data.shape[3:] - if original_channels: # multi channel, resampling each channel - data_np: np.ndarray = data.reshape(list(data.shape[:3]) + [-1]) # type: ignore - data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch - else: # single channel image, need to expand to have a channel - data_np = data[None] - affine_xform = AffineTransform( - normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True - ) - data_torch = affine_xform( - torch.as_tensor(np.ascontiguousarray(data_np, dtype=dtype)).unsqueeze(0), - torch.as_tensor(np.ascontiguousarray(np.linalg.inv(affine) @ target_affine, dtype=dtype)), - spatial_size=output_spatial_shape_, + resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + output_array, target_affine = resampler( + data_array[None], src=affine, dst=target_affine, spatial_size=output_spatial_shape ) - data_np = data_torch[0].detach().cpu().numpy() - if original_channels: - data_np = np.moveaxis(data_np, 0, -1) # channel last - data_np = data_np.reshape(list(data_np.shape[:3]) + list(original_channels)) - else: - data_np = data_np[0] - return data_np, ensure_mat44(target_affine) + return output_array[0], target_affine @classmethod def convert_to_channel_last( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a5a00944ca..52c5e892a7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -131,7 +131,7 @@ def __call__( img: NdarrayOrTensor, src: Optional[NdarrayOrTensor] = None, dst: Optional[NdarrayOrTensor] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Optional[Union[Sequence[int], np.ndarray, int]] = None, mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, align_corners: Optional[bool] = False, From b32336e596009127c46c813149296fa835697471 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 4 Feb 2022 14:02:57 +0000 Subject: [PATCH 28/33] try to fix #3766 Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 4 ++-- monai/data/utils.py | 33 ++++++++++++++++++++++---- monai/transforms/spatial/dictionary.py | 3 ++- tests/test_spacing.py | 3 ++- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index babdfb8c23..ced6a316be 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -15,7 +15,7 @@ from monai.apps.utils import get_logger from monai.config import DtypeLike, NdarrayOrTensor, PathLike -from monai.data.utils import ensure_tuple, orientation_ras_lps, to_affine_nd +from monai.data.utils import affine_to_spacing, ensure_tuple, orientation_ras_lps, to_affine_nd from monai.transforms.spatial.array import SpatialResample from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg @@ -364,7 +364,7 @@ def create_backend_obj( affine = np.eye(d + 1, dtype=np.float64) _affine = convert_data_type(affine, np.ndarray)[0] _affine = orientation_ras_lps(to_affine_nd(d, _affine)) - spacing = np.sqrt(np.sum(np.square(_affine[:d, :d]), 0)) + spacing = affine_to_spacing(_affine, dims=d) spacing[spacing == 0] = 1.0 _direction: np.ndarray = np.diag(1 / spacing) _direction = _affine[:d, :d] @ _direction # type: ignore diff --git a/monai/data/utils.py b/monai/data/utils.py index 8b6e2f6c28..f50923653e 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -60,6 +60,7 @@ "list_data_collate", "worker_init_fn", "set_rnd", + "affine_to_spacing", "correct_nifti_header_if_necessary", "rectify_header_sform_qform", "zoom_affine", @@ -547,6 +548,28 @@ def set_rnd(obj, seed: int) -> int: return seed +def affine_to_spacing(affine, dims: int = 3, dtype=float, suppress_zeros: bool = True): + """ + Computing the current spacing from the affine matrix. + + Args: + affine: a d x d affine matrix. + dims: indexing of the spatial dimensions `affine[:dims, :dims]`. + dtype: data type of the output. + suppress_zeros: whether to surpress the zeros with ones. + + Returns: + a `dims` dimensional vector of spacing. + """ + _affine, *_ = convert_to_dst_type(affine[:dims, :dims], dst=affine, dtype=dtype) + if isinstance(_affine, torch.Tensor): + spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) + spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) + if suppress_zeros: + spacing[spacing == 0] = 1.0 + return spacing + + def correct_nifti_header_if_necessary(img_nii): """ Check nifti object header's format, update the header if needed. @@ -562,7 +585,7 @@ def correct_nifti_header_if_necessary(img_nii): return img_nii # do nothing for high-dimensional array # check that affine matches zooms pixdim = np.asarray(img_nii.header.get_zooms())[:dim] - norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0)) + norm_affine = affine_to_spacing(img_nii.affine, dims=dim) if np.allclose(pixdim, norm_affine): return img_nii if hasattr(img_nii, "get_sform"): @@ -583,8 +606,8 @@ def rectify_header_sform_qform(img_nii): d = img_nii.header["dim"][0] pixdim = np.asarray(img_nii.header.get_zooms())[:d] sform, qform = img_nii.get_sform(), img_nii.get_qform() - norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0)) - norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0)) + norm_sform = affine_to_spacing(sform, dims=d) + norm_qform = affine_to_spacing(qform, dims=d) sform_mismatch = not np.allclose(norm_sform, pixdim) qform_mismatch = not np.allclose(norm_qform, pixdim) @@ -601,7 +624,7 @@ def rectify_header_sform_qform(img_nii): img_nii.set_qform(img_nii.get_sform()) return img_nii - norm = np.sqrt(np.sum(np.square(img_nii.affine[:d, :d]), 0)) + norm = affine_to_spacing(img_nii.affine, dims=d) warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}") img_nii.header.set_zooms(norm) @@ -641,7 +664,7 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d d = len(affine) - 1 # compute original pixdim - norm = np.sqrt(np.sum(np.square(affine), 0))[:-1] + norm = affine_to_spacing(affine, dims=d) if len(scale_np) < d: # defaults based on affine scale_np = np.append(scale_np, norm[len(scale_np) :]) scale_np = scale_np[:d] diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b319bb44b1..11381dc5c1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -24,6 +24,7 @@ from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data.utils import affine_to_spacing from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad @@ -435,7 +436,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] orig_size = transform[TraceKeys.ORIG_SIZE] - orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + orig_pixdim = affine_to_spacing(old_affine, -1) inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 0dd10a54f3..80df981b73 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple from tests.utils import TEST_NDARRAYS @@ -218,7 +219,7 @@ def test_spacing(self, in_type, init_param, img, data_param, expected_output): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr] + norm = affine_to_spacing(new_affine, sr) np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm) From dbf44ae0777d5a4b8ed087bd7e4dd32694df2ea7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 4 Feb 2022 17:28:13 +0000 Subject: [PATCH 29/33] revise docstring to be concise Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 41 ++++++++++++++------------------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 65518c2e0c..88a0b17555 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -43,8 +43,7 @@ class ImageWriter: - metadata of the current affine and output affine, the data array should be converted accordingly - ``get_meta_info()`` - ``resample_if_needed()`` - - data type of the output image - - as part of ``resample_if_needed()`` + - data type handling of the output image (as part of ``resample_if_needed()``) Subclasses of this class should implement the backend-specific functions: @@ -52,12 +51,10 @@ class ImageWriter: - this method sets the backend object's data part - ``set_metadata()`` to set the metadata and output affine - this method sets the metadata including affine handling and image resampling - - backend-specific data object - - ``create_backend_obj()`` - - backend-specific writing function - - ``write()`` + - backend-specific data object ``create_backend_obj()`` + - backend-specific writing function ``write()`` - The primary usage is: + The primary usage of subclasses of ``ImageWriter`` is: .. code-block:: python @@ -66,13 +63,12 @@ class ImageWriter: writer.set_metadata(meta_dict) writer.write(filename) - Create an image writer object based on ``data_array`` and ``metadata``. - Spatially it supports up to three dimensions (with the resampling step - supports both 2D and 3D). + This creates an image writer object based on ``data_array`` and ``meta_dict`` and write to ``filename``. + It supports up to three spatial dimensions (with the resampling step supports for both 2D and 3D). When saving multiple time steps or multiple channels `data_array`, time and/or modality axes should be the at the `channel_dim`. For example, - the shape of a 2D eight-class and channel_dim=0, the segmentation + the shape of a 2D eight-class and ``channel_dim=0``, the segmentation probabilities to be saved could be `(8, 64, 64)`; in this case ``data_array`` will be converted to `(64, 64, 1, 8)` (the third dimension is reserved as a spatial dimension). @@ -197,8 +193,7 @@ def convert_to_channel_last( """ Rearrange the data array axes to make the `channel_dim`-th dim the last dimension and ensure there are ``spatial_ndim`` number of spatial - dimensions. If ``channel_dim`` is ``None``, a new axis will be appended - to the last dimension. + dimensions. When ``squeeze_end_dims`` is ``True``, a postprocessing step will be applied to remove any trailing singleton dimensions. @@ -206,7 +201,7 @@ def convert_to_channel_last( Args: data: input data to be converted to "channel-last" format. channel_dim: specifies the channel axes of the data array to move to the last. - ``None`` indicates no channel dimension, + ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension. a sequence of integers indicates multiple non-spatial dimensions. squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. @@ -239,12 +234,7 @@ def convert_to_channel_last( def get_meta_info(cls, metadata: Optional[Mapping] = None): """ Extracts relevant meta information from the metadata object (using ``.get``). - - This method returns the following fields (the default value is ``None``): - - - ``'original_affine'``: for data original affine (before any image processing), - - ``'affine'``: for the current data affine (representing the current coordinate information), - - ``'spatial_shape'``: for data original spatial shape. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. """ if not metadata: metadata = {"original_affine": None, "affine": None, "spatial_shape": None} @@ -276,7 +266,8 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end Args: data_array: input data array with the channel dimension specified by ``channel_dim``. - channel_dim: channel dimension of the data array. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. kwargs: keyword arguments passed to ``self.convert_to_channel_last``, currently support ``spatial_ndim`` and ``contiguous``. @@ -296,6 +287,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru Args: meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). options: keyword arguments passed to ``self.resample_if_needed``, currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``. @@ -314,7 +306,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru def write(self, filename: PathLike, verbose: bool = False, **kwargs): """ - Create an ITK object from ``self.data_obj`` and call ``itk.imwrite``. + Create an ITK object from ``self.create_backend_obj(self.obj, ...)`` and call ``itk.imwrite``. Args: filename: filename or PathLike object. @@ -327,10 +319,7 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore ) itk.imwrite( - self.data_obj, - filename, - compression=kwargs.pop("compression", False), - imageio=kwargs.pop("imageio", None), + self.data_obj, filename, compression=kwargs.pop("compression", False), imageio=kwargs.pop("imageio", None) ) @classmethod From 1f23c040531106c51c85b14bebfc14138dfc8a0a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 5 Feb 2022 18:25:30 +0000 Subject: [PATCH 30/33] update based on comments Signed-off-by: Wenqi Li --- monai/data/__init__.py | 1 + monai/data/image_writer.py | 24 ++++++++++----- monai/data/utils.py | 49 +++++++++++++++++-------------- monai/transforms/spatial/array.py | 4 +-- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index f37241e937..f03410b5ba 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -61,6 +61,7 @@ iter_patch_slices, json_hashing, list_data_collate, + orientation_ras_lps, pad_list_data_collate, partition_dataset, partition_dataset_classes, diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 88a0b17555..e38f790e1a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -20,7 +20,6 @@ from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg -AFFINE_TOL = 1e-3 DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) @@ -80,16 +79,15 @@ class ImageWriter: - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix. - ``'spatial_shape'``: for data output spatial shape. - When ``metadata`` is specified and ``resample=True``, the saver will - try to resample data from the space defined by `"affine"` to the space - defined by `"original_affine"`, for more details, please refer to the + When ``metadata`` is specified, the saver will may resample data from the space defined by + `"affine"` to the space defined by `"original_affine"`, for more details, please refer to the ``resample_if_needed`` method. """ def __init__(self, **kwargs): """ The constructor supports adding new instance members. - The current member in the base class is ``self.data_obj``, the subclases can add more members, + The current member in the base class is ``self.data_obj``, the subclasses can add more members, so that necessary meta information can be stored in the object and shared among the class methods. """ self.data_obj = None @@ -174,6 +172,7 @@ def resample_if_needed( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If ``None``, use the data type of input data. + The output data type of this method is always ``np.float32``. """ resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) output_array, target_affine = resampler( @@ -270,7 +269,7 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end ``None`` indicates data without any channel dimension. squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. kwargs: keyword arguments passed to ``self.convert_to_channel_last``, - currently support ``spatial_ndim`` and ``contiguous``. + currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively. """ self.data_obj = self.convert_to_channel_last( data=data_array, @@ -290,7 +289,8 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). options: keyword arguments passed to ``self.resample_if_needed``, - currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``. + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. """ original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) self.data_obj, self.affine = self.resample_if_needed( @@ -313,6 +313,10 @@ def write(self, filename: PathLike, verbose: bool = False, **kwargs): verbose: if ``True``, log the progress. kwargs: keyword arguments passed to ``itk.imwrite``, currently support ``compression`` and ``imageio``. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 """ super().write(filename, verbose=verbose) self.data_obj = self.create_backend_obj( @@ -332,7 +336,7 @@ def create_backend_obj( **kwargs, ): """ - Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. Args: data_array: input data array. @@ -340,6 +344,10 @@ def create_backend_obj( affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`. dtype: output data type. kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary. + + see also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389 """ data_array = super().create_backend_obj(data_array) _is_vec = channel_dim is not None diff --git a/monai/data/utils.py b/monai/data/utils.py index 0b5b6faf7f..f9fa54dde0 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -52,41 +52,46 @@ __all__ = [ - "get_random_patch", - "iter_patch_slices", - "dense_patch_slices", - "iter_patch", - "get_valid_patch_size", - "list_data_collate", - "worker_init_fn", - "set_rnd", + "AFFINE_TOL", + "SUPPORTED_PICKLE_MOD", "affine_to_spacing", - "correct_nifti_header_if_necessary", - "rectify_header_sform_qform", - "zoom_affine", + "compute_importance_map", "compute_shape_offset", - "to_affine_nd", + "convert_tables_to_dicts", + "correct_nifti_header_if_necessary", "create_file_basename", - "compute_importance_map", + "decollate_batch", + "dense_patch_slices", + "get_random_patch", + "get_valid_patch_size", "is_supported_format", + "iter_patch", + "iter_patch_slices", + "json_hashing", + "list_data_collate", + "no_collation", + "orientation_ras_lps", + "pad_list_data_collate", "partition_dataset", "partition_dataset_classes", + "pickle_hashing", + "rectify_header_sform_qform", + "reorient_spatial_axes", "resample_datalist", "select_cross_validation_folds", - "json_hashing", - "pickle_hashing", + "set_rnd", "sorted_dict", - "decollate_batch", - "pad_list_data_collate", - "no_collation", - "convert_tables_to_dicts", - "SUPPORTED_PICKLE_MOD", - "reorient_spatial_axes", + "to_affine_nd", + "worker_init_fn", + "zoom_affine", ] # module to be used by `torch.save` SUPPORTED_PICKLE_MOD = {"pickle": pickle} +# tolerance for affine matrix computation +AFFINE_TOL = 1e-3 + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -718,7 +723,7 @@ def compute_shape_offset( k = 0 for i in range(corners.shape[1]): min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) - if np.allclose(min_corner, 0.0, rtol=1e-3): + if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL): k = i break offset = corners[:-1, k] diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index cb2b008962..e33a4ee314 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine +from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad @@ -54,8 +54,6 @@ from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -AFFINE_TOL = 1e-3 - nib, has_nib = optional_import("nibabel") __all__ = [ From f49a0924affe432a80b817430e0f1af8a30c604e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 5 Feb 2022 07:49:46 +0800 Subject: [PATCH 31/33] 3765 Enhance `create_multigpu_supervised_XXX` for distributed (#3768) * [DLMED] add check for devices Signed-off-by: Nic Ma * [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/multi_gpu_supervised_trainer.py | 14 ++++-- tests/min_tests.py | 1 + tests/test_parallel_execution_dist.py | 45 +++++++++++++++++++ 3 files changed, 56 insertions(+), 4 deletions(-) create mode 100644 tests/test_parallel_execution_dist.py diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 7c59b670b7..0433617649 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -74,8 +74,8 @@ def create_multigpu_supervised_trainer( tuple of tensors `(batch_x, batch_y)`. output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Returns: Engine: a trainer engine with supervised update function. @@ -87,6 +87,8 @@ def create_multigpu_supervised_trainer( devices_ = get_devices_spec(devices) if distributed: + if len(devices_) > 1: + raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {devices_}.") net = DistributedDataParallel(net, device_ids=devices_) elif len(devices_) > 1: net = DataParallel(net) @@ -122,8 +124,8 @@ def create_multigpu_supervised_evaluator( output_transform: function that receives 'x', 'y', 'y_pred' and returns value to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits output expected by metrics. If you change it you should use `output_transform` in metrics. - distributed: whether convert model to `DistributedDataParallel`, if have multiple devices, use - the first device as output device. + distributed: whether convert model to `DistributedDataParallel`, if `True`, `devices` must contain + only 1 GPU or CPU for current distributed rank. Note: `engine.state.output` for this engine is defined by `output_transform` parameter and is @@ -137,6 +139,10 @@ def create_multigpu_supervised_evaluator( if distributed: net = DistributedDataParallel(net, device_ids=devices_) + if len(devices_) > 1: + raise ValueError( + f"for distributed evaluation, `devices` must contain only 1 GPU or CPU, but got {devices_}." + ) elif len(devices_) > 1: net = DataParallel(net) diff --git a/tests/min_tests.py b/tests/min_tests.py index 783ab370c1..090167c4b1 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -155,6 +155,7 @@ def run_testsuit(): "test_zoom_affine", "test_zoomd", "test_prepare_batch_default_dist", + "test_parallel_execution_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_parallel_execution_dist.py b/tests/test_parallel_execution_dist.py new file mode 100644 index 0000000000..f067b71d14 --- /dev/null +++ b/tests/test_parallel_execution_dist.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torch.distributed as dist + +from monai.engines import create_multigpu_supervised_trainer +from tests.utils import DistCall, DistTestCase, skip_if_no_cuda + + +def fake_loss(y_pred, y): + return (y_pred[0] + y).sum() + + +def fake_data_stream(): + while True: + yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64)) + + +class DistributedTestParallelExecution(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + @skip_if_no_cuda + def test_distributed(self): + device = torch.device(f"cuda:{dist.get_rank()}") + net = torch.nn.Conv2d(1, 1, 3, padding=1).to(device) + opt = torch.optim.Adam(net.parameters(), 1e-3) + + trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [device], distributed=True) + trainer.run(fake_data_stream(), 2, 2) + # assert the trainer output is loss value + self.assertTrue(isinstance(trainer.state.output, float)) + + +if __name__ == "__main__": + unittest.main() From 9a5fd2e0887945371c827f6209c71b01335eff90 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 5 Feb 2022 20:42:08 +0000 Subject: [PATCH 32/33] update to support dynamic spatial_size Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 31 ++++++++++++++++++++++++++++--- monai/transforms/spatial/array.py | 23 +++++++++++++---------- tests/test_itk_writer.py | 12 ++++++++++++ 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index e38f790e1a..c8ab8e344c 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -247,17 +247,42 @@ def get_meta_info(cls, metadata: Optional[Mapping] = None): class ITKWriter(ImageWriter): """ Write data and metadata into files on disk using ITK-python. + + .. code-block:: python + + import numpy as np + from monai.data import ITKWriter + + np_data = np.arange(48).reshape(3, 4, 4) + + # write as 3d spatial image no channel + writer = ITKWriter(output_dtype=np.float32) + writer.set_data_array(np_data, channel_dim=None) + # optionally set metadata affine + writer.set_metadata({"affine": np.eye(4), "original_affine": -1 * np.eye(4)}) + writer.write("test1.nii.gz") + + # write as 2d image, channel-first + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write("test1.png") + """ - def __init__(self, output_dtype: DtypeLike = np.float32, affine=None, channel_dim=0, **kwargs): + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): """ Args: output_dtype: output data type. kwargs: keyword arguments passed to ``ImageWriter``. - The constructor will create ``self.output_dtype``, ``self.affine``, ``self.channel_dim`` internally. + The constructor will create ``self.output_dtype`` internally. + ``affine`` and ``channel_dim`` are initialized as instance members (default ``None``, ``0``): + + - user-specified ``affine`` should be set in ``set_metadata``, + - user-specified ``channel_dim`` should be set in ``set_data_array``. """ - super().__init__(output_dtype=output_dtype, affine=affine, channel_dim=channel_dim, **kwargs) + super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): """ diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e33a4ee314..38cfeb00c9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -173,18 +173,27 @@ def __call__( if src_affine is None: src_affine = np.eye(4, dtype=np.float64) spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + if spatial_size is not -1 and spatial_size is not None: + spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine) dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine dst_affine, *_ = convert_to_dst_type(dst_affine, dst_affine, dtype=torch.float32) - if allclose(src_affine, dst_affine, atol=AFFINE_TOL): + in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) + if spatial_size is -1: # using the input spatial size + spatial_size = in_spatial_size + elif spatial_size is None: # auto spatial size + spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore + spatial_size = np.asarray(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + + if allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): # no significant change, return original image output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst_affine if has_nib and isinstance(img, np.ndarray): spatial_ornt, dst_r = reorient_spatial_axes(img.shape[1 : spatial_rank + 1], src_affine, dst_affine) - if allclose(dst_r, dst_affine, atol=AFFINE_TOL): + if allclose(dst_r, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): # simple reorientation achieves the desired affine spatial_ornt[:, 0] += 1 spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) @@ -206,18 +215,12 @@ def __call__( raise ValueError(f"src affine is not invertible: {src_affine}") from e xform = to_affine_nd(spatial_rank, xform) # no resampling if it's identity transform - if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL): + if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst_affine _dtype = dtype or self.dtype or img.dtype - spatial_size = ensure_tuple(spatial_size) - in_spatial_size = list(img.shape[1 : spatial_rank + 1]) - if spatial_size[0] == -1: # if the spatial_size == -1 - spatial_size = in_spatial_size - elif spatial_size[0] is None: - spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore - spatial_size = spatial_size[:spatial_rank] + in_spatial_size = in_spatial_size.tolist() chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims # resample img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index 96231c9b69..163fead76e 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -38,6 +38,18 @@ def test_channel_shape(self): s.pop(c) np.testing.assert_allclose(itk.size(itk_obj), s) + def test_rgb(self): + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "testing.png") + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write(fname) + + output = np.asarray(itk.imread(fname)) + np.testing.assert_allclose(output.shape, (5, 5, 3)) + np.testing.assert_allclose(output[1, 1], (5, 5, 4)) + if __name__ == "__main__": unittest.main() From 9b75f60347691c9a465ea182e8018e20487c67db Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 7 Feb 2022 08:54:46 +0000 Subject: [PATCH 33/33] update based on comments Signed-off-by: Wenqi Li --- monai/data/__init__.py | 2 +- monai/data/image_writer.py | 4 ++-- monai/data/utils.py | 18 +++++++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index f03410b5ba..58d66099be 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,7 +35,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .image_writer import ImageWriter, ITKWriter +from .image_writer import ImageWriter, ITKWriter, logger from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index c8ab8e344c..40d0dc6b7e 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -28,7 +28,7 @@ else: itk, _ = optional_import("itk", allow_namespace_pkg=True) -__all__ = ["ImageWriter", "ITKWriter"] +__all__ = ["ImageWriter", "ITKWriter", "logger"] class ImageWriter: @@ -386,7 +386,7 @@ def create_backend_obj( affine = np.eye(d + 1, dtype=np.float64) _affine = convert_data_type(affine, np.ndarray)[0] _affine = orientation_ras_lps(to_affine_nd(d, _affine)) - spacing = affine_to_spacing(_affine, dims=d) + spacing = affine_to_spacing(_affine, r=d) _direction: np.ndarray = np.diag(1 / spacing) _direction = _affine[:d, :d] @ _direction # type: ignore itk_obj.SetSpacing(spacing.tolist()) diff --git a/monai/data/utils.py b/monai/data/utils.py index f9fa54dde0..495daf15e2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -553,20 +553,20 @@ def set_rnd(obj, seed: int) -> int: return seed -def affine_to_spacing(affine: NdarrayTensor, dims: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor: +def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor: """ Computing the current spacing from the affine matrix. Args: affine: a d x d affine matrix. - dims: indexing of the spatial dimensions `affine[:dims, :dims]`. + r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`. dtype: data type of the output. suppress_zeros: whether to surpress the zeros with ones. Returns: - a `dims` dimensional vector of spacing. + an `r` dimensional vector of spacing. """ - _affine, *_ = convert_to_dst_type(affine[:dims, :dims], dst=affine, dtype=dtype) + _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) if isinstance(_affine, torch.Tensor): spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) else: @@ -592,7 +592,7 @@ def correct_nifti_header_if_necessary(img_nii): return img_nii # do nothing for high-dimensional array # check that affine matches zooms pixdim = np.asarray(img_nii.header.get_zooms())[:dim] - norm_affine = affine_to_spacing(img_nii.affine, dims=dim) + norm_affine = affine_to_spacing(img_nii.affine, r=dim) if np.allclose(pixdim, norm_affine): return img_nii if hasattr(img_nii, "get_sform"): @@ -613,8 +613,8 @@ def rectify_header_sform_qform(img_nii): d = img_nii.header["dim"][0] pixdim = np.asarray(img_nii.header.get_zooms())[:d] sform, qform = img_nii.get_sform(), img_nii.get_qform() - norm_sform = affine_to_spacing(sform, dims=d) - norm_qform = affine_to_spacing(qform, dims=d) + norm_sform = affine_to_spacing(sform, r=d) + norm_qform = affine_to_spacing(qform, r=d) sform_mismatch = not np.allclose(norm_sform, pixdim) qform_mismatch = not np.allclose(norm_qform, pixdim) @@ -631,7 +631,7 @@ def rectify_header_sform_qform(img_nii): img_nii.set_qform(img_nii.get_sform()) return img_nii - norm = affine_to_spacing(img_nii.affine, dims=d) + norm = affine_to_spacing(img_nii.affine, r=d) warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}") img_nii.header.set_zooms(norm) @@ -671,7 +671,7 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d d = len(affine) - 1 # compute original pixdim - norm = affine_to_spacing(affine, dims=d) + norm = affine_to_spacing(affine, r=d) if len(scale_np) < d: # defaults based on affine scale_np = np.append(scale_np, norm[len(scale_np) :]) scale_np = scale_np[:d]