diff --git a/tests/test_make_nifti.py b/tests/test_make_nifti.py new file mode 100644 index 0000000000..951f079764 --- /dev/null +++ b/tests/test_make_nifti.py @@ -0,0 +1,43 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.data.synthetic import create_test_image_2d +from monai.utils import optional_import +from tests.utils import make_nifti_image + +_, has_nib = optional_import("nibabel") + +TESTS = [] +for affine in (None, np.eye(4), torch.eye(4)): + for dir in (None, tempfile.mkdtemp()): + for fname in (None, "fname"): + TESTS.append([{"affine": affine, "dir": dir, "fname": fname}]) + + +@unittest.skipUnless(has_nib, "Requires nibabel") +class TestMakeNifti(unittest.TestCase): + @parameterized.expand(TESTS) + def test_make_nifti(self, params): + im, _ = create_test_image_2d(100, 88) + created_file = make_nifti_image(im, verbose=True, **params) + self.assertTrue(os.path.isfile(created_file)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index f96f659353..8a07736d6c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -240,7 +240,7 @@ def has_cupy(): HAS_CUPY = has_cupy() -def make_nifti_image(array: NdarrayOrTensor, affine=None): +def make_nifti_image(array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False): """ Create a temporary nifti image on the disk and return the image name. User is responsible for deleting the temporary file when done with it. @@ -253,10 +253,23 @@ def make_nifti_image(array: NdarrayOrTensor, affine=None): affine = np.eye(4) test_image = nib.Nifti1Image(array, affine) - temp_f, image_name = tempfile.mkstemp(suffix=".nii.gz") - nib.save(test_image, image_name) - os.close(temp_f) - return image_name + # if dir not given, create random. Else, make sure it exists. + if dir is None: + dir = tempfile.mkdtemp() + else: + os.makedirs(dir, exist_ok=True) + + # If fname not given, get random one. Else, concat dir, fname and suffix. + if fname is None: + temp_f, fname = tempfile.mkstemp(suffix=suffix, dir=dir) + os.close(temp_f) + else: + fname = os.path.join(dir, fname + suffix) + + nib.save(test_image, fname) + if verbose: + print(f"File written: {fname}.") + return fname def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState] = None):