diff --git a/monai/data/utils.py b/monai/data/utils.py index 63e630fe17..47108c68ef 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -15,7 +15,7 @@ import os import pickle import warnings -from collections import defaultdict +from collections import abc, defaultdict from itertools import product, starmap from pathlib import PurePath from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union @@ -254,10 +254,20 @@ def list_data_collate(batch: Sequence): elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch try: + elem = batch[0] + key = None + if isinstance(elem, abc.Mapping): + ret = {} + for k in elem: + key = k + ret[k] = default_collate([d[k] for d in data]) + return ret return default_collate(data) except RuntimeError as re: re_str = str(re) if "equal size" in re_str: + if key is not None: + re_str += f"\nCollate error on the key '{key}' of dictionary data." re_str += ( "\n\nMONAI hint: if your transforms intentionally create images of different shapes, creating your " + "`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its " @@ -267,6 +277,8 @@ def list_data_collate(batch: Sequence): except TypeError as re: re_str = str(re) if "numpy" in re_str and "Tensor" in re_str: + if key is not None: + re_str += f"\nCollate error on the key '{key}' of dictionary data." re_str += ( "\n\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and numpy ndarray, " + "creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem " diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 072a4a01c0..53e7c89f67 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -12,9 +12,27 @@ import sys import unittest -from monai.data import CacheDataset, DataLoader +import numpy as np +import torch +from parameterized import parameterized + +from monai.data import CacheDataset, DataLoader, Dataset from monai.transforms import Compose, DataStatsd, SimulateDelayd +TEST_CASE_1 = [ + [ + {"image": np.asarray([1, 2, 3])}, + {"image": np.asarray([4, 5])}, + ] +] + +TEST_CASE_2 = [ + [ + {"label": torch.as_tensor([[3], [2]])}, + {"label": np.asarray([[1], [2]])}, + ] +] + class TestDataLoader(unittest.TestCase): def test_values(self): @@ -37,6 +55,14 @@ def test_values(self): self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") self.assertEqual(d["label"][1], "spleen_label_31.nii.gz") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_exception(self, datalist): + dataset = Dataset(data=datalist, transform=None) + dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) + with self.assertRaisesRegex((TypeError, RuntimeError), "Collate error on the key"): + for _ in dataloader: + pass + if __name__ == "__main__": unittest.main()