From 370bd2841a82bebb2fcfb98747d09cec4eea2e0a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 4 Jun 2021 11:48:07 +0100 Subject: [PATCH 1/2] 2304 enhance image dataset to handle meta Signed-off-by: Wenqi Li --- monai/data/image_dataset.py | 40 ++++++++++++++++++++++++---------- monai/transforms/transform.py | 2 ++ tests/test_image_dataset.py | 41 +++++++++++++++++++++++++++++++---- 3 files changed, 67 insertions(+), 16 deletions(-) diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index 1568e082ee..fdde151b2f 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -26,7 +26,8 @@ class ImageDataset(Dataset, Randomizable): for the image and segmentation arrays separately. The difference between this dataset and `ArrayDataset` is that this dataset can apply transform chain to images and segs and return both the images and metadata, and no need to specify transform to load images from files. - + For more information, please see the image_dataset demo in the MONAI tutorial repo, + https://github.com/Project-MONAI/tutorials/blob/master/modules/image_dataset.ipynb """ def __init__( @@ -37,6 +38,7 @@ def __init__( transform: Optional[Callable] = None, seg_transform: Optional[Callable] = None, image_only: bool = True, + transform_with_metadata: bool = False, dtype: DtypeLike = np.float32, reader: Optional[Union[ImageReader, str]] = None, *args, @@ -53,6 +55,7 @@ def __init__( transform: transform to apply to image arrays seg_transform: transform to apply to segmentation arrays image_only: if True return only the image volume, otherwise, return image volume and the metadata + transform_with_metadata: if True, the metadata will be passed to the transforms whenever possible. dtype: if not None convert the loaded image to this data type reader: register reader to load image file and meta data, if None, will use the default readers. If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs` @@ -76,7 +79,10 @@ def __init__( self.labels = labels self.transform = transform self.seg_transform = seg_transform + if image_only and transform_with_metadata: + raise ValueError("transform_with_metadata=True requires image_only=False.") self.image_only = image_only + self.transform_with_metadata = transform_with_metadata self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs) self.set_random_state(seed=get_seed()) self._seed = 0 # transform synchronization seed @@ -89,10 +95,9 @@ def randomize(self, data: Optional[Any] = None) -> None: def __getitem__(self, index: int): self.randomize() - meta_data = None - seg = None - label = None + meta_data, seg_meta_data, seg, label = None, None, None, None + # load data and optionally meta if self.image_only: img = self.loader(self.image_files[index]) if self.seg_files is not None: @@ -100,29 +105,40 @@ def __getitem__(self, index: int): else: img, meta_data = self.loader(self.image_files[index]) if self.seg_files is not None: - seg, _ = self.loader(self.seg_files[index]) - - if self.labels is not None: - label = self.labels[index] + seg, seg_meta_data = self.loader(self.seg_files[index]) + # apply the transforms if self.transform is not None: if isinstance(self.transform, Randomizable): self.transform.set_random_state(seed=self._seed) - img = apply_transform(self.transform, img) - - data = [img] + img = apply_transform( + self.transform, (img, meta_data) if self.transform_with_metadata else img, map_items=False + ) + if self.transform_with_metadata: + img, meta_data = img if self.seg_transform is not None: if isinstance(self.seg_transform, Randomizable): self.seg_transform.set_random_state(seed=self._seed) - seg = apply_transform(self.seg_transform, seg) + seg = apply_transform( + self.seg_transform, (seg, seg_meta_data) if self.transform_with_metadata else seg, map_items=False + ) + if self.transform_with_metadata: + seg, seg_meta_data = seg + + if self.labels is not None: + label = self.labels[index] + # construct outputs + data = [img] if seg is not None: data.append(seg) if label is not None: data.append(label) if not self.image_only and meta_data is not None: data.append(meta_data) + if not self.image_only and seg_meta_data is not None: + data.append(seg_meta_data) if len(data) == 1: return data[0] # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e5715ee702..d8f12ee02e 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -46,6 +46,8 @@ def apply_transform(transform: Callable, data, map_items: bool = True): try: if isinstance(data, (list, tuple)) and map_items: return [transform(item) for item in data] + if isinstance(data, (list, tuple)) and not map_items: + return transform(*data) return transform(data) except Exception as e: diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index ec2cf77cd8..3b3c06c87c 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -17,7 +17,7 @@ import numpy as np from monai.data import ImageDataset -from monai.transforms.transform import RandomizableTransform +from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"] @@ -35,7 +35,38 @@ def __call__(self, data): return data + self._a +class _TestCompose(Compose): + def __call__(self, data, meta): + data = self.transforms[0](data, meta) # ensure channel first + data, _, meta["affine"] = self.transforms[1](data, meta["affine"]) # spacing + if len(self.transforms) == 3: + return self.transforms[2](data), meta # image contrast + return data, meta + + class TestImageDataset(unittest.TestCase): + def test_use_case(self): + with tempfile.TemporaryDirectory() as tempdir: + img_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)), np.eye(4)) + seg_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)), np.eye(4)) + img_name, seg_name = os.path.join(tempdir, "img.nii.gz"), os.path.join(tempdir, "seg.nii.gz") + nib.save(img_, img_name) + nib.save(seg_, seg_name) + img_list, seg_list = [img_name], [seg_name] + + img_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0)), RandAdjustContrast()]) + seg_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0), mode="nearest")]) + img_dataset = ImageDataset( + image_files=img_list, + seg_files=seg_list, + transform=img_xform, + seg_transform=seg_xform, + image_only=False, + transform_with_metadata=True, + ) + self.assertTupleEqual(img_dataset[0][0].shape, (1, 14, 14, 7)) + self.assertTupleEqual(img_dataset[0][1].shape, (1, 14, 14, 7)) + def test_dataset(self): with tempfile.TemporaryDirectory() as tempdir: full_names, ref_data = [], [] @@ -94,28 +125,30 @@ def test_dataset(self): image_only=False, ) for d_tuple, ref in zip(dataset, ref_data): - img, seg, meta = d_tuple + img, seg, meta, seg_meta = d_tuple np.testing.assert_allclose(img, ref + 1, atol=1e-3) np.testing.assert_allclose(seg, ref + 2, atol=1e-3) np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) + np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with meta dataset = ImageDataset( full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False ) for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): - img, seg, label, meta = d_tuple + img, seg, label, meta, seg_meta = d_tuple np.testing.assert_allclose(img, ref + 1, atol=1e-3) np.testing.assert_allclose(seg, ref, atol=1e-3) np.testing.assert_allclose(idx + 1, label) np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) + np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3) # loading image/label, with sync. transform dataset = ImageDataset( full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False ) for d_tuple, ref in zip(dataset, ref_data): - img, seg, meta = d_tuple + img, seg, meta, seg_meta = d_tuple np.testing.assert_allclose(img, seg, atol=1e-3) self.assertTrue(not np.allclose(img, ref)) np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) From 5d9383e3a97e1c42a81a7699ced1e52bd05ccb8e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 4 Jun 2021 11:57:01 +0100 Subject: [PATCH 2/2] fixes typo Signed-off-by: Wenqi Li --- monai/transforms/transform.py | 2 -- tests/test_image_dataset.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d8f12ee02e..e5715ee702 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -46,8 +46,6 @@ def apply_transform(transform: Callable, data, map_items: bool = True): try: if isinstance(data, (list, tuple)) and map_items: return [transform(item) for item in data] - if isinstance(data, (list, tuple)) and not map_items: - return transform(*data) return transform(data) except Exception as e: diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 3b3c06c87c..173d24f350 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -36,7 +36,8 @@ def __call__(self, data): class _TestCompose(Compose): - def __call__(self, data, meta): + def __call__(self, input_): + data, meta = input_ data = self.transforms[0](data, meta) # ensure channel first data, _, meta["affine"] = self.transforms[1](data, meta["affine"]) # spacing if len(self.transforms) == 3: