Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import pickle
import warnings
from collections import defaultdict
from collections import abc, defaultdict
from itertools import product, starmap
from pathlib import PurePath
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -254,10 +254,20 @@ def list_data_collate(batch: Sequence):
elem = batch[0]
data = [i for k in batch for i in k] if isinstance(elem, list) else batch
try:
elem = batch[0]
key = None
if isinstance(elem, abc.Mapping):
ret = {}
for k in elem:
key = k
ret[k] = default_collate([d[k] for d in data])
return ret
return default_collate(data)
except RuntimeError as re:
re_str = str(re)
if "equal size" in re_str:
if key is not None:
re_str += f"\nCollate error on the key '{key}' of dictionary data."
re_str += (
"\n\nMONAI hint: if your transforms intentionally create images of different shapes, creating your "
+ "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its "
Expand All @@ -267,6 +277,8 @@ def list_data_collate(batch: Sequence):
except TypeError as re:
re_str = str(re)
if "numpy" in re_str and "Tensor" in re_str:
if key is not None:
re_str += f"\nCollate error on the key '{key}' of dictionary data."
re_str += (
"\n\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and numpy ndarray, "
+ "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem "
Expand Down
28 changes: 27 additions & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,27 @@
import sys
import unittest

from monai.data import CacheDataset, DataLoader
import numpy as np
import torch
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import Compose, DataStatsd, SimulateDelayd

TEST_CASE_1 = [
[
{"image": np.asarray([1, 2, 3])},
{"image": np.asarray([4, 5])},
]
]

TEST_CASE_2 = [
[
{"label": torch.as_tensor([[3], [2]])},
{"label": np.asarray([[1], [2]])},
]
]


class TestDataLoader(unittest.TestCase):
def test_values(self):
Expand All @@ -37,6 +55,14 @@ def test_values(self):
self.assertEqual(d["label"][0], "spleen_label_19.nii.gz")
self.assertEqual(d["label"][1], "spleen_label_31.nii.gz")

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_exception(self, datalist):
dataset = Dataset(data=datalist, transform=None)
dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
with self.assertRaisesRegex((TypeError, RuntimeError), "Collate error on the key"):
for _ in dataloader:
pass


if __name__ == "__main__":
unittest.main()