diff --git a/docs/source/data.rst b/docs/source/data.rst index 1e6b535b12..1bb9bc8869 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -150,6 +150,19 @@ WSIReader .. autoclass:: WSIReader :members: +Image writer +------------ + +ImageWriter +~~~~~~~~~~~ +.. autoclass:: ImageWriter + :members: + +ITKWriter +~~~~~~~~~ +.. autoclass:: ITKWriter + :members: + Nifti format handling --------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index bd49f40273..58d66099be 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,6 +35,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_writer import ImageWriter, ITKWriter, logger from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti @@ -60,11 +61,13 @@ iter_patch_slices, json_hashing, list_data_collate, + orientation_ras_lps, pad_list_data_collate, partition_dataset, partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, + reorient_spatial_axes, resample_datalist, select_cross_validation_folds, set_rnd, diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..40d0dc6b7e --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,395 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union + +import numpy as np + +from monai.apps.utils import get_logger +from monai.config import DtypeLike, NdarrayOrTensor, PathLike +from monai.data.utils import affine_to_spacing, ensure_tuple, orientation_ras_lps, to_affine_nd +from monai.transforms.spatial.array import SpatialResample +from monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis +from monai.utils import GridSampleMode, GridSamplePadMode, convert_data_type, optional_import, require_pkg + +DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" +logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) + +if TYPE_CHECKING: + import itk # type: ignore +else: + itk, _ = optional_import("itk", allow_namespace_pkg=True) + +__all__ = ["ImageWriter", "ITKWriter", "logger"] + + +class ImageWriter: + """ + The class is a collection of utilities to write images to disk. + + Main aspects to be considered are: + + - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions + - ``convert_to_channel_last()`` + - metadata of the current affine and output affine, the data array should be converted accordingly + - ``get_meta_info()`` + - ``resample_if_needed()`` + - data type handling of the output image (as part of ``resample_if_needed()``) + + Subclasses of this class should implement the backend-specific functions: + + - ``set_data_array()`` to set the data array (input must be numpy array or torch tensor) + - this method sets the backend object's data part + - ``set_metadata()`` to set the metadata and output affine + - this method sets the metadata including affine handling and image resampling + - backend-specific data object ``create_backend_obj()`` + - backend-specific writing function ``write()`` + + The primary usage of subclasses of ``ImageWriter`` is: + + .. code-block:: python + + writer = MyWriter() # subclass of ImageWriter + writer.set_data_array(data_array) + writer.set_metadata(meta_dict) + writer.write(filename) + + This creates an image writer object based on ``data_array`` and ``meta_dict`` and write to ``filename``. + + It supports up to three spatial dimensions (with the resampling step supports for both 2D and 3D). + When saving multiple time steps or multiple channels `data_array`, time + and/or modality axes should be the at the `channel_dim`. For example, + the shape of a 2D eight-class and ``channel_dim=0``, the segmentation + probabilities to be saved could be `(8, 64, 64)`; in this case + ``data_array`` will be converted to `(64, 64, 1, 8)` (the third + dimension is reserved as a spatial dimension). + + The ``metadata`` could optionally have the following keys: + + - ``'original_affine'``: for data original affine, it will be the + affine of the output object, defaulting to an identity matrix. + - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix. + - ``'spatial_shape'``: for data output spatial shape. + + When ``metadata`` is specified, the saver will may resample data from the space defined by + `"affine"` to the space defined by `"original_affine"`, for more details, please refer to the + ``resample_if_needed`` method. + """ + + def __init__(self, **kwargs): + """ + The constructor supports adding new instance members. + The current member in the base class is ``self.data_obj``, the subclasses can add more members, + so that necessary meta information can be stored in the object and shared among the class methods. + """ + self.data_obj = None + for k, v in kwargs.items(): + setattr(self, k, v) + + def set_data_array(self, data_array, **kwargs): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def set_metadata(self, meta_dict: Optional[Mapping], **options): + raise NotImplementedError(f"Subclasses of {self.__class__.__name__} must implement this method.") + + def write(self, filename: PathLike, verbose: bool = True, **kwargs): + """subclass should implement this method to call the backend-specific writing APIs.""" + if verbose: + logger.info(f"writing: {filename}") + + @classmethod + def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray: + """ + Subclass should implement this method to return a backend-specific data representation object. + This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'. + """ + return convert_data_type(data_array, np.ndarray)[0] # type: ignore + + @classmethod + def resample_if_needed( + cls, + data_array: NdarrayOrTensor, + affine: Optional[NdarrayOrTensor] = None, + target_affine: Optional[NdarrayOrTensor] = None, + output_spatial_shape: Union[Sequence[int], int, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + ): + """ + Convert the ``data_array`` into the coordinate system specified by + ``target_affine``, from the current coordinate definition of ``affine``. + + If the transform between ``affine`` and ``target_affine`` could be + achieved by simply transposing and flipping ``data_array``, no resampling + will happen. Otherwise, this function resamples ``data_array`` using the + transformation computed from ``affine`` and ``target_affine``. + + This function assumes the NIfTI dimension notations. Spatially it + supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D + respectively. When saving multiple time steps or multiple channels, + time and/or modality axes should be appended after the first three + dimensions. For example, shape of 2D eight-class segmentation + probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in + shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a + single-channel 3D image. The ``convert_to_channel_last`` method can be + used to convert the data to the format described here. + + Note that the shape of the resampled ``data_array`` may subject to some + rounding errors. For example, resampling a 20x20 pixel image from pixel + size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel + image. However, resampling a 20x20-pixel image from pixel size (2.0, + 2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where + the image shape is rounded from 13.333x13.333 pixels. In this case + ``output_spatial_shape`` could be specified so that this function + writes image data to a designated shape. + + Args: + data_array: input data array to be converted. + affine: the current affine of ``data_array``. Defaults to identity + target_affine: the designated affine of ``data_array``. + The actual output affine might be different from this value due to precision changes. + output_spatial_shape: spatial shape of the output image. + This option is used when resampling is needed. + mode: available options are {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + This option is used when resampling is needed. + Interpolation mode to calculate output values. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + padding_mode: available options are {``"zeros"``, ``"border"``, ``"reflection"``}. + This option is used when resampling is needed. + Padding mode for outside grid values. Defaults to ``"border"``. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + align_corners: boolean option of ``grid_sample`` to handle the corner convention. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + dtype: data type for resampling computation. Defaults to + ``np.float64`` for best precision. If ``None``, use the data type of input data. + The output data type of this method is always ``np.float32``. + """ + resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) + output_array, target_affine = resampler( + data_array[None], src_affine=affine, dst_affine=target_affine, spatial_size=output_spatial_shape + ) + return output_array[0], target_affine + + @classmethod + def convert_to_channel_last( + cls, + data: NdarrayOrTensor, + channel_dim: Union[None, int, Sequence[int]] = 0, + squeeze_end_dims: bool = True, + spatial_ndim: Optional[int] = 3, + contiguous: bool = False, + ): + """ + Rearrange the data array axes to make the `channel_dim`-th dim the last + dimension and ensure there are ``spatial_ndim`` number of spatial + dimensions. + + When ``squeeze_end_dims`` is ``True``, a postprocessing step will be + applied to remove any trailing singleton dimensions. + + Args: + data: input data to be converted to "channel-last" format. + channel_dim: specifies the channel axes of the data array to move to the last. + ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension. + a sequence of integers indicates multiple non-spatial dimensions. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel + has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`. + If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`. + spatial_ndim: modifying the spatial dims if needed, so that output to have at least + this number of spatial dims. If ``None``, the output will have the same number of + spatial dimensions as the input. + contiguous: if ``True``, the output will be contiguous. + """ + # change data to "channel last" format + if channel_dim is not None: + _chns = ensure_tuple(channel_dim) + data = moveaxis(data, _chns, tuple(range(-len(_chns), 0))) + else: # adds a channel dimension + data = data[..., None] + # To ensure at least ``spatial_ndim`` number of spatial dims + if spatial_ndim: + while len(data.shape) < spatial_ndim + 1: # assuming the data has spatial + channel dims + data = data[..., None, :] + while len(data.shape) > spatial_ndim + 1: + data = data[..., 0, :] + # if desired, remove trailing singleton dimensions + while squeeze_end_dims and data.shape[-1] == 1: + data = np.squeeze(data, -1) + if contiguous: + data = ascontiguousarray(data) + return data + + @classmethod + def get_meta_info(cls, metadata: Optional[Mapping] = None): + """ + Extracts relevant meta information from the metadata object (using ``.get``). + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + """ + if not metadata: + metadata = {"original_affine": None, "affine": None, "spatial_shape": None} + original_affine = metadata.get("original_affine") + affine = metadata.get("affine") + spatial_shape = metadata.get("spatial_shape") + return original_affine, affine, spatial_shape + + +@require_pkg(pkg_name="itk") +class ITKWriter(ImageWriter): + """ + Write data and metadata into files on disk using ITK-python. + + .. code-block:: python + + import numpy as np + from monai.data import ITKWriter + + np_data = np.arange(48).reshape(3, 4, 4) + + # write as 3d spatial image no channel + writer = ITKWriter(output_dtype=np.float32) + writer.set_data_array(np_data, channel_dim=None) + # optionally set metadata affine + writer.set_metadata({"affine": np.eye(4), "original_affine": -1 * np.eye(4)}) + writer.write("test1.nii.gz") + + # write as 2d image, channel-first + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np_data, channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write("test1.png") + + """ + + def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): + """ + Args: + output_dtype: output data type. + kwargs: keyword arguments passed to ``ImageWriter``. + + The constructor will create ``self.output_dtype`` internally. + ``affine`` and ``channel_dim`` are initialized as instance members (default ``None``, ``0``): + + - user-specified ``affine`` should be set in ``set_metadata``, + - user-specified ``channel_dim`` should be set in ``set_data_array``. + """ + super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) + + def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): + """ + Convert ``data_array`` into 'channel-last' numpy ndarray. + + Args: + data_array: input data array with the channel dimension specified by ``channel_dim``. + channel_dim: channel dimension of the data array. Defaults to 0. + ``None`` indicates data without any channel dimension. + squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed. + kwargs: keyword arguments passed to ``self.convert_to_channel_last``, + currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively. + """ + self.data_obj = self.convert_to_channel_last( + data=data_array, + channel_dim=channel_dim, + squeeze_end_dims=squeeze_end_dims, + spatial_ndim=kwargs.pop("spatial_ndim", 3), + contiguous=kwargs.pop("contiguous", True), + ) + self.channel_dim = channel_dim + + def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): + """ + Resample ``self.dataobj`` if needed. This method assumes ``self.data_obj`` is a 'channel-last' ndarray. + + Args: + meta_dict: a metadata dictionary for affine, original affine and spatial shape information. + Optional keys are ``"spatial_shape"``, ``"affine"``, ``"original_affine"``. + resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``). + options: keyword arguments passed to ``self.resample_if_needed``, + currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``, + defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively. + """ + original_affine, affine, spatial_shape = self.get_meta_info(meta_dict) + self.data_obj, self.affine = self.resample_if_needed( + data_array=self.data_obj, + affine=affine, + target_affine=original_affine if resample else None, + output_spatial_shape=spatial_shape, + mode=options.pop("mode", GridSampleMode.BILINEAR), + padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), + align_corners=options.pop("align_corners", False), + dtype=options.pop("dtype", np.float64), + ) + + def write(self, filename: PathLike, verbose: bool = False, **kwargs): + """ + Create an ITK object from ``self.create_backend_obj(self.obj, ...)`` and call ``itk.imwrite``. + + Args: + filename: filename or PathLike object. + verbose: if ``True``, log the progress. + kwargs: keyword arguments passed to ``itk.imwrite``, + currently support ``compression`` and ``imageio``. + + See also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809 + """ + super().write(filename, verbose=verbose) + self.data_obj = self.create_backend_obj( + self.data_obj, channel_dim=self.channel_dim, affine=self.affine, dtype=self.output_dtype, **kwargs # type: ignore + ) + itk.imwrite( + self.data_obj, filename, compression=kwargs.pop("compression", False), imageio=kwargs.pop("imageio", None) + ) + + @classmethod + def create_backend_obj( + cls, + data_array: NdarrayOrTensor, + channel_dim: Optional[int] = 0, + affine: Optional[NdarrayOrTensor] = None, + dtype: DtypeLike = np.float32, + **kwargs, + ): + """ + Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``. + + Args: + data_array: input data array. + channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``. + affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`. + dtype: output data type. + kwargs: keyword arguments. Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary. + + see also: + + - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389 + """ + data_array = super().create_backend_obj(data_array) + _is_vec = channel_dim is not None + if _is_vec: + data_array = np.moveaxis(data_array, -1, 0) # from channel last to channel first + data_array = data_array.T.astype(dtype, copy=True, order="C") + itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop("ttype", None)) + + d = len(itk.size(itk_obj)) + if affine is None: + affine = np.eye(d + 1, dtype=np.float64) + _affine = convert_data_type(affine, np.ndarray)[0] + _affine = orientation_ras_lps(to_affine_nd(d, _affine)) + spacing = affine_to_spacing(_affine, r=d) + _direction: np.ndarray = np.diag(1 / spacing) + _direction = _affine[:d, :d] @ _direction # type: ignore + itk_obj.SetSpacing(spacing.tolist()) + itk_obj.SetOrigin(_affine[:d, -1].tolist()) + itk_obj.SetDirection(itk.GetMatrixFromArray(_direction)) + return itk_obj diff --git a/monai/data/utils.py b/monai/data/utils.py index 779671e793..495daf15e2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -52,40 +52,46 @@ __all__ = [ - "get_random_patch", - "iter_patch_slices", - "dense_patch_slices", - "iter_patch", - "get_valid_patch_size", - "list_data_collate", - "worker_init_fn", - "set_rnd", - "correct_nifti_header_if_necessary", - "rectify_header_sform_qform", - "zoom_affine", + "AFFINE_TOL", + "SUPPORTED_PICKLE_MOD", + "affine_to_spacing", + "compute_importance_map", "compute_shape_offset", - "to_affine_nd", + "convert_tables_to_dicts", + "correct_nifti_header_if_necessary", "create_file_basename", - "compute_importance_map", + "decollate_batch", + "dense_patch_slices", + "get_random_patch", + "get_valid_patch_size", "is_supported_format", + "iter_patch", + "iter_patch_slices", + "json_hashing", + "list_data_collate", + "no_collation", + "orientation_ras_lps", + "pad_list_data_collate", "partition_dataset", "partition_dataset_classes", + "pickle_hashing", + "rectify_header_sform_qform", + "reorient_spatial_axes", "resample_datalist", "select_cross_validation_folds", - "json_hashing", - "pickle_hashing", + "set_rnd", "sorted_dict", - "decollate_batch", - "pad_list_data_collate", - "no_collation", - "convert_tables_to_dicts", - "SUPPORTED_PICKLE_MOD", - "reorient_spatial_axes", + "to_affine_nd", + "worker_init_fn", + "zoom_affine", ] # module to be used by `torch.save` SUPPORTED_PICKLE_MOD = {"pickle": pickle} +# tolerance for affine matrix computation +AFFINE_TOL = 1e-3 + def get_random_patch( dims: Sequence[int], patch_size: Sequence[int], rand_state: Optional[np.random.RandomState] = None @@ -547,6 +553,30 @@ def set_rnd(obj, seed: int) -> int: return seed +def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor: + """ + Computing the current spacing from the affine matrix. + + Args: + affine: a d x d affine matrix. + r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`. + dtype: data type of the output. + suppress_zeros: whether to surpress the zeros with ones. + + Returns: + an `r` dimensional vector of spacing. + """ + _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) + if isinstance(_affine, torch.Tensor): + spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) + else: + spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) + if suppress_zeros: + spacing[spacing == 0] = 1.0 + spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype) + return spacing_ + + def correct_nifti_header_if_necessary(img_nii): """ Check nifti object header's format, update the header if needed. @@ -562,7 +592,7 @@ def correct_nifti_header_if_necessary(img_nii): return img_nii # do nothing for high-dimensional array # check that affine matches zooms pixdim = np.asarray(img_nii.header.get_zooms())[:dim] - norm_affine = np.sqrt(np.sum(np.square(img_nii.affine[:dim, :dim]), 0)) + norm_affine = affine_to_spacing(img_nii.affine, r=dim) if np.allclose(pixdim, norm_affine): return img_nii if hasattr(img_nii, "get_sform"): @@ -583,8 +613,8 @@ def rectify_header_sform_qform(img_nii): d = img_nii.header["dim"][0] pixdim = np.asarray(img_nii.header.get_zooms())[:d] sform, qform = img_nii.get_sform(), img_nii.get_qform() - norm_sform = np.sqrt(np.sum(np.square(sform[:d, :d]), 0)) - norm_qform = np.sqrt(np.sum(np.square(qform[:d, :d]), 0)) + norm_sform = affine_to_spacing(sform, r=d) + norm_qform = affine_to_spacing(qform, r=d) sform_mismatch = not np.allclose(norm_sform, pixdim) qform_mismatch = not np.allclose(norm_qform, pixdim) @@ -601,7 +631,7 @@ def rectify_header_sform_qform(img_nii): img_nii.set_qform(img_nii.get_sform()) return img_nii - norm = np.sqrt(np.sum(np.square(img_nii.affine[:d, :d]), 0)) + norm = affine_to_spacing(img_nii.affine, r=d) warnings.warn(f"Modifying image pixdim from {pixdim} to {norm}") img_nii.header.set_zooms(norm) @@ -641,7 +671,7 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d d = len(affine) - 1 # compute original pixdim - norm = np.sqrt(np.sum(np.square(affine), 0))[:-1] + norm = affine_to_spacing(affine, r=d) if len(scale_np) < d: # defaults based on affine scale_np = np.append(scale_np, norm[len(scale_np) :]) scale_np = scale_np[:d] @@ -693,7 +723,7 @@ def compute_shape_offset( k = 0 for i in range(corners.shape[1]): min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1) - if np.allclose(min_corner, 0.0, rtol=1e-3): + if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL): k = i break offset = corners[:-1, k] @@ -1259,3 +1289,19 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def orientation_ras_lps(affine: NdarrayTensor) -> NdarrayTensor: + """ + Convert the ``affine`` between the `RAS` and `LPS` orientation + by flipping the first two spatial dimensions. + + Args: + affine: a 2D affine matrix. + """ + sr = max(affine.shape[0] - 1, 1) # spatial rank is at least 1 + flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]] + flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3) + if isinstance(affine, torch.Tensor): + return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine # type: ignore + return np.diag(flip_diag).astype(affine.dtype) @ affine # type: ignore diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7bd5990b3b..38cfeb00c9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,7 +20,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor -from monai.data.utils import compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine +from monai.data.utils import AFFINE_TOL, compute_shape_offset, reorient_spatial_axes, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform from monai.transforms.croppad.array import CenterSpatialCrop, Pad @@ -54,8 +54,6 @@ from monai.utils.module import look_up_option from monai.utils.type_conversion import convert_data_type, convert_to_dst_type -AFFINE_TOL = 1e-3 - nib, has_nib = optional_import("nibabel") __all__ = [ @@ -132,7 +130,7 @@ def __call__( img: NdarrayOrTensor, src_affine: Optional[NdarrayOrTensor] = None, dst_affine: Optional[NdarrayOrTensor] = None, - spatial_size: Optional[Union[Sequence[int], int]] = None, + spatial_size: Optional[Union[Sequence[int], np.ndarray, int]] = None, mode: Union[GridSampleMode, str, None] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str, None] = GridSamplePadMode.BORDER, align_corners: Optional[bool] = False, @@ -175,18 +173,27 @@ def __call__( if src_affine is None: src_affine = np.eye(4, dtype=np.float64) spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + if spatial_size is not -1 and spatial_size is not None: + spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine) dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine dst_affine, *_ = convert_to_dst_type(dst_affine, dst_affine, dtype=torch.float32) - if allclose(src_affine, dst_affine, atol=AFFINE_TOL): + in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) + if spatial_size is -1: # using the input spatial size + spatial_size = in_spatial_size + elif spatial_size is None: # auto spatial size + spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore + spatial_size = np.asarray(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + + if allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): # no significant change, return original image output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst_affine if has_nib and isinstance(img, np.ndarray): spatial_ornt, dst_r = reorient_spatial_axes(img.shape[1 : spatial_rank + 1], src_affine, dst_affine) - if allclose(dst_r, dst_affine, atol=AFFINE_TOL): + if allclose(dst_r, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): # simple reorientation achieves the desired affine spatial_ornt[:, 0] += 1 spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) @@ -208,18 +215,12 @@ def __call__( raise ValueError(f"src affine is not invertible: {src_affine}") from e xform = to_affine_nd(spatial_rank, xform) # no resampling if it's identity transform - if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL): + if allclose(xform, np.diag(np.ones(len(xform))), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst_affine _dtype = dtype or self.dtype or img.dtype - spatial_size = ensure_tuple(spatial_size) - in_spatial_size = list(img.shape[1 : spatial_rank + 1]) - if spatial_size[0] == -1: # if the spatial_size == -1 - spatial_size = in_spatial_size - elif spatial_size[0] is None: - spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore - spatial_size = spatial_size[:spatial_rank] + in_spatial_size = in_spatial_size.tolist() chns, additional_dims = img.shape[0], img.shape[spatial_rank + 1 :] # beyond three spatial dims # resample img_ = convert_data_type(img, torch.Tensor, dtype=_dtype)[0] @@ -270,7 +271,7 @@ class Spacing(Transform): def __init__( self, - pixdim: Union[Sequence[float], float], + pixdim: Union[Sequence[float], float, np.ndarray], diagonal: bool = False, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -334,7 +335,7 @@ def __call__( padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: DtypeLike = None, - output_spatial_shape: Optional[Union[Sequence[int], int]] = None, + output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None, ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]]: """ Args: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index eabf309567..aff2c97a63 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -24,6 +24,7 @@ from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data.utils import affine_to_spacing from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad @@ -435,7 +436,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] orig_size = transform[TraceKeys.ORIG_SIZE] - orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] + orig_pixdim = affine_to_spacing(old_affine, -1) inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse d[key], _, new_affine = inverse_transform( diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..163fead76e --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch + +from monai.data import ITKWriter +from monai.utils import optional_import + +itk, has_itk = optional_import("itk") +nib, has_nibabel = optional_import("nibabel") + + +@unittest.skipUnless(has_itk, "Requires `itk` package.") +class TestITKWriter(unittest.TestCase): + def test_channel_shape(self): + with tempfile.TemporaryDirectory() as tempdir: + for c in (0, 1, 2, 3): + fname = os.path.join(tempdir, f"testing{c}.nii") + itk_writer = ITKWriter() + itk_writer.set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False) + itk_writer.set_metadata({}) + itk_writer.write(fname) + itk_obj = itk.imread(fname) + s = [1, 2, 3, 4] + s.pop(c) + np.testing.assert_allclose(itk.size(itk_obj), s) + + def test_rgb(self): + with tempfile.TemporaryDirectory() as tempdir: + fname = os.path.join(tempdir, "testing.png") + writer = ITKWriter(output_dtype=np.uint8) + writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0) + writer.set_metadata({"spatial_shape": (5, 5)}) + writer.write(fname) + + output = np.asarray(itk.imread(fname)) + np.testing.assert_allclose(output.shape, (5, 5, 3)) + np.testing.assert_allclose(output[1, 1], (5, 5, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py new file mode 100644 index 0000000000..4ed223bf5b --- /dev/null +++ b/tests/test_ori_ras_lps.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.utils import orientation_ras_lps +from tests.utils import TEST_NDARRAYS, assert_allclose + +TEST_CASES_AFFINE = [] +for p in TEST_NDARRAYS: + case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_1d) + case_2d_1 = p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]) + TEST_CASES_AFFINE.append(case_2d_1) + case_2d_2 = p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), p( + [[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]] + ) + TEST_CASES_AFFINE.append(case_2d_2) + case_3d = p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]), p( + [[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]] + ) + TEST_CASES_AFFINE.append(case_3d) + case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5]) + TEST_CASES_AFFINE.append(case_4d) + + +class TestITKWriter(unittest.TestCase): + @parameterized.expand(TEST_CASES_AFFINE) + def test_ras_to_lps(self, param, expected): + assert_allclose(orientation_ras_lps(param), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_spacing.py b/tests/test_spacing.py index 0dd10a54f3..80df981b73 100644 --- a/tests/test_spacing.py +++ b/tests/test_spacing.py @@ -15,6 +15,7 @@ import torch from parameterized import parameterized +from monai.data.utils import affine_to_spacing from monai.transforms import Spacing from monai.utils import ensure_tuple, fall_back_tuple from tests.utils import TEST_NDARRAYS @@ -218,7 +219,7 @@ def test_spacing(self, in_type, init_param, img, data_param, expected_output): init_param["pixdim"] = [init_param["pixdim"]] * sr init_pixdim = ensure_tuple(init_param["pixdim"]) init_pixdim = init_param["pixdim"][:sr] - norm = np.sqrt(np.sum(np.square(new_affine), axis=0))[:sr] + norm = affine_to_spacing(new_affine, sr) np.testing.assert_allclose(fall_back_tuple(init_pixdim, norm), norm)