Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class ImageDataset(Dataset, Randomizable):
for the image and segmentation arrays separately.
The difference between this dataset and `ArrayDataset` is that this dataset can apply transform chain to images
and segs and return both the images and metadata, and no need to specify transform to load images from files.

For more information, please see the image_dataset demo in the MONAI tutorial repo,
https://github.com/Project-MONAI/tutorials/blob/master/modules/image_dataset.ipynb
"""

def __init__(
Expand All @@ -37,6 +38,7 @@ def __init__(
transform: Optional[Callable] = None,
seg_transform: Optional[Callable] = None,
image_only: bool = True,
transform_with_metadata: bool = False,
dtype: DtypeLike = np.float32,
reader: Optional[Union[ImageReader, str]] = None,
*args,
Expand All @@ -53,6 +55,7 @@ def __init__(
transform: transform to apply to image arrays
seg_transform: transform to apply to segmentation arrays
image_only: if True return only the image volume, otherwise, return image volume and the metadata
transform_with_metadata: if True, the metadata will be passed to the transforms whenever possible.
dtype: if not None convert the loaded image to this data type
reader: register reader to load image file and meta data, if None, will use the default readers.
If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs`
Expand All @@ -76,7 +79,10 @@ def __init__(
self.labels = labels
self.transform = transform
self.seg_transform = seg_transform
if image_only and transform_with_metadata:
raise ValueError("transform_with_metadata=True requires image_only=False.")
self.image_only = image_only
self.transform_with_metadata = transform_with_metadata
self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
self.set_random_state(seed=get_seed())
self._seed = 0 # transform synchronization seed
Expand All @@ -89,40 +95,50 @@ def randomize(self, data: Optional[Any] = None) -> None:

def __getitem__(self, index: int):
self.randomize()
meta_data = None
seg = None
label = None
meta_data, seg_meta_data, seg, label = None, None, None, None

# load data and optionally meta
if self.image_only:
img = self.loader(self.image_files[index])
if self.seg_files is not None:
seg = self.loader(self.seg_files[index])
else:
img, meta_data = self.loader(self.image_files[index])
if self.seg_files is not None:
seg, _ = self.loader(self.seg_files[index])

if self.labels is not None:
label = self.labels[index]
seg, seg_meta_data = self.loader(self.seg_files[index])

# apply the transforms
if self.transform is not None:
if isinstance(self.transform, Randomizable):
self.transform.set_random_state(seed=self._seed)
img = apply_transform(self.transform, img)

data = [img]
img = apply_transform(
self.transform, (img, meta_data) if self.transform_with_metadata else img, map_items=False
)
if self.transform_with_metadata:
img, meta_data = img

if self.seg_transform is not None:
if isinstance(self.seg_transform, Randomizable):
self.seg_transform.set_random_state(seed=self._seed)
seg = apply_transform(self.seg_transform, seg)
seg = apply_transform(
self.seg_transform, (seg, seg_meta_data) if self.transform_with_metadata else seg, map_items=False
)
if self.transform_with_metadata:
seg, seg_meta_data = seg

if self.labels is not None:
label = self.labels[index]

# construct outputs
data = [img]
if seg is not None:
data.append(seg)
if label is not None:
data.append(label)
if not self.image_only and meta_data is not None:
data.append(meta_data)
if not self.image_only and seg_meta_data is not None:
data.append(seg_meta_data)
if len(data) == 1:
return data[0]
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
Expand Down
42 changes: 38 additions & 4 deletions tests/test_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from monai.data import ImageDataset
from monai.transforms.transform import RandomizableTransform
from monai.transforms import Compose, EnsureChannelFirst, RandAdjustContrast, RandomizableTransform, Spacing

FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"]

Expand All @@ -35,7 +35,39 @@ def __call__(self, data):
return data + self._a


class _TestCompose(Compose):
def __call__(self, input_):
data, meta = input_
data = self.transforms[0](data, meta) # ensure channel first
data, _, meta["affine"] = self.transforms[1](data, meta["affine"]) # spacing
if len(self.transforms) == 3:
return self.transforms[2](data), meta # image contrast
return data, meta


class TestImageDataset(unittest.TestCase):
def test_use_case(self):
with tempfile.TemporaryDirectory() as tempdir:
img_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)), np.eye(4))
seg_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)), np.eye(4))
img_name, seg_name = os.path.join(tempdir, "img.nii.gz"), os.path.join(tempdir, "seg.nii.gz")
nib.save(img_, img_name)
nib.save(seg_, seg_name)
img_list, seg_list = [img_name], [seg_name]

img_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0)), RandAdjustContrast()])
seg_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0), mode="nearest")])
img_dataset = ImageDataset(
image_files=img_list,
seg_files=seg_list,
transform=img_xform,
seg_transform=seg_xform,
image_only=False,
transform_with_metadata=True,
)
self.assertTupleEqual(img_dataset[0][0].shape, (1, 14, 14, 7))
self.assertTupleEqual(img_dataset[0][1].shape, (1, 14, 14, 7))

def test_dataset(self):
with tempfile.TemporaryDirectory() as tempdir:
full_names, ref_data = [], []
Expand Down Expand Up @@ -94,28 +126,30 @@ def test_dataset(self):
image_only=False,
)
for d_tuple, ref in zip(dataset, ref_data):
img, seg, meta = d_tuple
img, seg, meta, seg_meta = d_tuple
np.testing.assert_allclose(img, ref + 1, atol=1e-3)
np.testing.assert_allclose(seg, ref + 2, atol=1e-3)
np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)
np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3)

# loading image/label, with meta
dataset = ImageDataset(
full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False
)
for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)):
img, seg, label, meta = d_tuple
img, seg, label, meta, seg_meta = d_tuple
np.testing.assert_allclose(img, ref + 1, atol=1e-3)
np.testing.assert_allclose(seg, ref, atol=1e-3)
np.testing.assert_allclose(idx + 1, label)
np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)
np.testing.assert_allclose(seg_meta["original_affine"], np.eye(4), atol=1e-3)

# loading image/label, with sync. transform
dataset = ImageDataset(
full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False
)
for d_tuple, ref in zip(dataset, ref_data):
img, seg, meta = d_tuple
img, seg, meta, seg_meta = d_tuple
np.testing.assert_allclose(img, seg, atol=1e-3)
self.assertTrue(not np.allclose(img, ref))
np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)
Expand Down