Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,24 +293,18 @@ def __call__( # type: ignore
dtype = dtype or self.dtype
src_affine = src_meta.get("affine")
dst_affine = dst_meta.get("affine")
ndim = len(img.shape[1:])
spatial_size = dst_meta.get("dim", [])[1 : ndim + 2]
img, updated_affine = super().__call__(
img=img,
src_affine=src_affine,
dst_affine=dst_affine,
spatial_size=spatial_size,
spatial_size=dst_meta.get("spatial_shape"),
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
dtype=dtype,
)
dst_meta = deepcopy(dst_meta)
dst_meta["affine"] = updated_affine
if "dim" in dst_meta:
dst_meta["dim"] = src_meta.get("dim", [])
if "pixdim" in dst_meta:
dst_meta["pixdim"] = src_meta.get("pixdim", [])
Comment thread
rijobro marked this conversation as resolved.
return img, dst_meta


Expand Down
25 changes: 20 additions & 5 deletions tests/test_resample_to_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import os
import tempfile
import unittest

import nibabel as nib
import numpy as np
from parameterized import parameterized

from monai.data.image_reader import ITKReader, NibabelReader
from monai.data.image_writer import ITKWriter, NibabelWriter
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged
from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config

TEST_CASES = ["itkreader", "nibabelreader"]


class TestResampleToMatch(unittest.TestCase):
def setUp(self):
Expand All @@ -28,18 +38,23 @@ def setUp(self):
download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val)
self.fnames.append(fname)

def test_correct(self):
@parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter]))
def test_correct(self, reader, writer):
with tempfile.TemporaryDirectory() as temp_dir:
loader = Compose([LoadImaged(("im1", "im2")), EnsureChannelFirstd(("im1", "im2"))])
loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))])
data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})

im_mod, meta = ResampleToMatch()(data["im2"], data["im2_meta_dict"], data["im1_meta_dict"])
# for visual inspection
saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False)
current_dims = copy.deepcopy(meta.get("dim"))
saver = SaveImaged("im3", output_dir=temp_dir, output_postfix="", separate_folder=False, writer=writer)
meta["filename_or_obj"] = "file3.nii.gz"
saver({"im3": im_mod, "im3_meta_dict": meta})

assert_allclose(im_mod.shape, data["im1"].shape)
saved = nib.load(os.path.join(temp_dir, meta["filename_or_obj"]))
assert_allclose(data["im1"].shape[1:], saved.shape)
assert_allclose(saved.header["dim"][:4], np.array([3, 384, 384, 19]))
if current_dims is not None:
assert_allclose(saved.header["dim"], current_dims)


if __name__ == "__main__":
Expand Down
14 changes: 14 additions & 0 deletions tests/test_resample_to_matchd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,21 @@ def test_correct(self):
]
)
data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
# check that output sizes match
assert_allclose(data["im1"].shape, data["im3"].shape)
# and that the meta data has been updated accordingly
assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False)
assert_allclose(data["im3_meta_dict"]["affine"], data["im1_meta_dict"]["affine"])
# check we're different from the original
self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
self.assertTrue(
any(
i != j
for i, j in zip(
data["im3_meta_dict"]["affine"].flatten(), data["im2_meta_dict"]["affine"].flatten()
)
)
)


if __name__ == "__main__":
Expand Down