From 7db4144e9e11cb339fef1daf51d9e8a8cf0828e1 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Feb 2022 11:42:43 +0000 Subject: [PATCH 1/6] enhance to use spatial_size Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 +++--- tests/test_resample_to_match.py | 14 +++++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0fb96e88a1..c874dafdea 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -293,13 +293,11 @@ def __call__( # type: ignore dtype = dtype or self.dtype src_affine = src_meta.get("affine") dst_affine = dst_meta.get("affine") - ndim = len(img.shape[1:]) - spatial_size = dst_meta.get("dim", [])[1 : ndim + 2] img, updated_affine = super().__call__( img=img, src_affine=src_affine, dst_affine=dst_affine, - spatial_size=spatial_size, + spatial_size=dst_meta.get("spatial_shape"), mode=mode, padding_mode=padding_mode, align_corners=align_corners, @@ -307,6 +305,8 @@ def __call__( # type: ignore ) dst_meta = deepcopy(dst_meta) dst_meta["affine"] = updated_affine + if "spatial_shape" in dst_meta: + dst_meta["spatial_shape"] = src_meta.get("spatial_shape") if "dim" in dst_meta: dst_meta["dim"] = src_meta.get("dim", []) if "pixdim" in dst_meta: diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index cf8c98d588..451329c3b0 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -9,13 +9,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import os import tempfile import unittest +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader +from monai.data.image_writer import ITKWriter, NibabelWriter from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config +TEST_CASES = ["itkreader", "nibabelreader"] + class TestResampleToMatch(unittest.TestCase): def setUp(self): @@ -28,14 +35,15 @@ def setUp(self): download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val) self.fnames.append(fname) - def test_correct(self): + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_correct(self, reader, writer): with tempfile.TemporaryDirectory() as temp_dir: - loader = Compose([LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2"))]) + loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))]) data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) # for visual inspection - saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False) + saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer) meta["filename_or_obj"] = "file3.nii.gz" saver({"im3": im_mod, "im3_meta_dict": meta}) From e518e4b072bdf6b1163e3bafd990a847c7a3dfbf Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Feb 2022 13:29:39 +0000 Subject: [PATCH 2/6] update tests Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 6 ------ tests/test_resample_to_match.py | 6 +++++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c874dafdea..f7327aa07b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -305,12 +305,6 @@ def __call__( # type: ignore ) dst_meta = deepcopy(dst_meta) dst_meta["affine"] = updated_affine - if "spatial_shape" in dst_meta: - dst_meta["spatial_shape"] = src_meta.get("spatial_shape") - if "dim" in dst_meta: - dst_meta["dim"] = src_meta.get("dim", []) - if "pixdim" in dst_meta: - dst_meta["pixdim"] = src_meta.get("pixdim", []) return img, dst_meta diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index 451329c3b0..742bf338fe 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -14,6 +14,8 @@ import tempfile import unittest +import nibabel as nib +import numpy as np from parameterized import parameterized from monai.data.image_reader import ITKReader, NibabelReader @@ -47,7 +49,9 @@ def test_correct(self, reader, writer): meta["filename_or_obj"] = "file3.nii.gz" saver({"im3": im_mod, "im3_meta_dict": meta}) - assert_allclose(im_mod.shape, data["im1"].shape) + saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"])) + assert_allclose(im_mod.shape[1:], saved.shape) + assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19])) if __name__ == "__main__": From a43ae3988bc373083858e2ddb8d79ee4c9843028 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Feb 2022 15:31:52 +0000 Subject: [PATCH 3/6] fixes unit test Signed-off-by: Wenqi Li --- tests/test_resample_to_match.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index 742bf338fe..484bb1e449 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -44,13 +44,12 @@ def test_correct(self, reader, writer): data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) - # for visual inspection saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer) meta["filename_or_obj"] = "file3.nii.gz" saver({"im3": im_mod, "im3_meta_dict": meta}) saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"])) - assert_allclose(im_mod.shape[1:], saved.shape) + assert_allclose(data["im1"].shape[1:], saved.shape) assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19])) From 12bb4a28672be301cce4d8017e6b5c13c278325a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 24 Feb 2022 14:28:18 +0000 Subject: [PATCH 4/6] adds unit tests Signed-off-by: Wenqi Li --- tests/test_resample_to_match.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py index 484bb1e449..b0a5027071 100644 --- a/tests/test_resample_to_match.py +++ b/tests/test_resample_to_match.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import itertools import os import tempfile @@ -44,6 +45,7 @@ def test_correct(self, reader, writer): data = loader({"im1": self.fnames[0], "im2": self.fnames[1]}) im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"]) + current_dims = copy.deepcopy(meta.get("dim")) saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer) meta["filename_or_obj"] = "file3.nii.gz" saver({"im3": im_mod, "im3_meta_dict": meta}) @@ -51,6 +53,8 @@ def test_correct(self, reader, writer): saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"])) assert_allclose(data["im1"].shape[1:], saved.shape) assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19])) + if current_dims is not None: + assert_allclose(saved.header["dim"], current_dims) if __name__ == "__main__": From 0c00f9e8f97c0fcb7035dee97e25889a508d144e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 24 Feb 2022 17:09:40 +0000 Subject: [PATCH 5/6] check that metadata matches Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_resample_to_matchd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index 77bc2c17e6..6071405607 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -47,6 +47,7 @@ def test_correct(self): ) data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]}) assert_allclose(data["im1"].shape, data["im3"].shape) + assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False) if __name__ == "__main__": From 8fbfc85d6dd73c37766935437100387f4652e0b6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 24 Feb 2022 17:20:58 +0000 Subject: [PATCH 6/6] add checks to test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_resample_to_matchd.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py index 6071405607..bf88e0f98a 100644 --- a/tests/test_resample_to_matchd.py +++ b/tests/test_resample_to_matchd.py @@ -46,8 +46,21 @@ def test_correct(self): ] ) data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]}) + # check that output sizes match assert_allclose(data["im1"].shape, data["im3"].shape) + # and that the meta data has been updated accordingly assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False) + assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"]) + # check we're different from the original + self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape))) + self.assertTrue( + any( + i != j + for i, j in zip( + data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten() + ) + ) + ) if __name__ == "__main__":