diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0fb96e88a1..f7327aa07b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -293,13 +293,11 @@ def __call__( # type: ignore dtype = dtype or self.dtype src_affine = src_meta.get("affine") dst_affine = dst_meta.get("affine") - ndim = len(img.shape[1:]) - spatial_size = dst_meta.get("dim", [])[1 : ndim + 2] img, updated_affine = super().__call__( img=img, src_affine=src_affine, dst_affine=dst_affine, - spatial_size=spatial_size, + spatial_size=dst_meta.get("spatial_shape"), mode=mode, padding_mode=padding_mode, align_corners=align_corners, @@ -307,10 +305,6 @@ def __call__( # type: ignore ) dst_meta = deepcopy(dst_meta) dst_meta["affine"] = updated_affine - if "dim" in dst_meta: - dst_meta["dim"] = src_meta.get("dim", []) - if "pixdim" in dst_meta: - dst_meta["pixdim"] = src_meta.get("pixdim", []) return img, dst_meta diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index cf8c98d588..b0a5027071 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -9,13 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +import itertools import os import tempfile import unittest +import nibabel as nib +import numpy as np +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader +from monai.data.image_writer import ITKWriter, NibabelWriter from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config +TEST_CASES = ["itkreader", "nibabelreader"] + class TestResampleToMatch(unittest.TestCase): def setUp(self): @@ -28,18 +38,23 @@ def setUp(self): download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val) self.fnames.append(fname) - def test_correct(self): + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_correct(self, reader, writer): with tempfile.TemporaryDirectory() as temp_dir: - loader = Compose([LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2"))]) + loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) - # for visual inspection - saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False) + current_dims = copy.deepcopy(meta.get("dim")) + saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer) meta["filename_or_obj"] = "file3.nii.gz" saver({"im3": im_mod, "im3_meta_dict": meta}) - assert_allclose(im_mod.shape, data["im1"].shape) + saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"])) + assert_allclose(data["im1"].shape[1:], saved.shape) + assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19])) + if current_dims is not None: + assert_allclose(saved.header["dim"], current_dims) if __name__ == "__main__": diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index 77bc2c17e6..bf88e0f98a 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -46,7 +46,21 @@ def test_correct(self): ] ) data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]}) + # check that output sizes match assert_allclose(data["im1"].shape, data["im3"].shape) + # and that the meta data has been updated accordingly + assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False) + assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"]) + # check we're different from the original + self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape))) + self.assertTrue( + any( + i != j + for i, j in zip( + data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten() + ) + ) + ) if __name__ == "__main__":