diff --git a/monai/data/utils.py b/monai/data/utils.py index 368f51be5f..49f64df1b9 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -288,10 +288,10 @@ def list_data_collate(batch: Sequence): raise TypeError(re_str) -def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]: +def decollate_batch(data: Union[dict, list, torch.Tensor], batch_size: Optional[int] = None) -> List[dict]: """De-collate a batch of data (for example, as produced by a `DataLoader`). - Returns a list of dictionaries. Each dictionary will only contain the data for a given batch. + Returns a list of dictionaries, lists or Tensors, each holding the data for one item of the batch. Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information, such as metadata, may have been stored in a list (or a list inside nested dictionaries). In @@ -315,19 +315,20 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]: print(out[0]) >>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}} + batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))] + out = decollate_batch(batch_data) + print(out[0]) + >>> [tensor([[[4.3549e-01...43e-01]]]), tensor([[[5.3435e-01...45e-01]]])] + + batch_data = torch.rand((2,1,10,10)) + out = decollate_batch(batch_data) + print(out[0]) + >>> tensor([[[4.3549e-01...43e-01]]]) + Args: data: data to be de-collated. batch_size: number of batches in data. If `None` is passed, try to figure out batch size. 
""" - if not isinstance(data, dict): - raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).") - if batch_size is None: - for v in data.values(): - if isinstance(v, torch.Tensor): - batch_size = v.shape[0] - break - if batch_size is None: - raise RuntimeError("Couldn't determine batch size, please specify as argument.") def torch_to_single(d: torch.Tensor): """If input is a torch.Tensor with only 1 element, return just the element.""" @@ -350,7 +351,27 @@ def decollate(data: Any, idx: int): return data[idx] raise TypeError(f"Not sure how to de-collate type: {type(data)}") - return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + def _detect_batch_size(batch_data): + for v in batch_data: + if isinstance(v, torch.Tensor): + return v.shape[0] + warnings.warn("batch_data is not a sequence of tensors in decollate, use `len(batch_data[0])` directly.") + return len(batch_data[0]) + + result: List[Any] + if isinstance(data, dict): + batch_size = _detect_batch_size(batch_data=data.values()) if batch_size is None else batch_size + result = [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)] + elif isinstance(data, list): + batch_size = _detect_batch_size(batch_data=data) if batch_size is None else batch_size + result = [[decollate(d, idx) for d in data] for idx in range(batch_size)] + elif isinstance(data, torch.Tensor): + batch_size = data.shape[0] + result = [data[idx] for idx in range(batch_size)] + else: + raise NotImplementedError("Only currently implemented for dictionary, list or Tensor data.") + + return result def pad_list_data_collate( diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 5b78bbbcf6..ad0a0ecdc3 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -18,9 +18,23 @@ import torch from parameterized import parameterized -from monai.data import CacheDataset, DataLoader, create_test_image_2d +from monai.data 
import CacheDataset, DataLoader, Dataset, create_test_image_2d from monai.data.utils import decollate_batch -from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord +from monai.transforms import ( + AddChannel, + AddChanneld, + Compose, + LoadImage, + LoadImaged, + RandAffine, + RandFlip, + RandFlipd, + RandRotate90, + SpatialPad, + SpatialPadd, + ToTensor, + ToTensord, +) from monai.transforms.post.dictionary import Decollated from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism @@ -31,10 +45,23 @@ KEYS = ["image"] -TESTS: List[Tuple] = [] -TESTS.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1))) -TESTS.append((RandRotate90d(KEYS, prob=0.0, max_k=1),)) -TESTS.append((RandAffined(KEYS, prob=0.0, translate_range=10),)) +TESTS_DICT: List[Tuple] = [] +TESTS_DICT.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1))) +TESTS_DICT.append((RandRotate90d(KEYS, prob=0.0, max_k=1),)) +TESTS_DICT.append((RandAffined(KEYS, prob=0.0, translate_range=10),)) + +TESTS_LIST: List[Tuple] = [] +TESTS_LIST.append((SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1))) +TESTS_LIST.append((RandRotate90(prob=0.0, max_k=1),)) +TESTS_LIST.append((RandAffine(prob=0.0, translate_range=10),)) + + +class _ListCompose(Compose): + def __call__(self, input_): + img, metadata = self.transforms[0](input_) + for t in self.transforms[1:]: + img = t(img) + return img, metadata class TestDeCollate(unittest.TestCase): @@ -42,7 +69,8 @@ def setUp(self) -> None: set_determinism(seed=0) im = create_test_image_2d(100, 101)[0] - self.data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] + self.data_dict = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)] + self.data_list = [make_nifti_image(im) if has_nib else im for _ in range(6)] def tearDown(self) -> None: set_determinism(None) @@ -68,18 +96,10 @@ 
def check_match(self, in1, in2): else: raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}") - @parameterized.expand(TESTS) - def test_decollation(self, *transforms): - + def check_decollate(self, dataset): batch_size = 2 num_workers = 2 - t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)]) - # If nibabel present, read from disk - if has_nib: - t_compose = Compose([LoadImaged("image"), t_compose]) - - dataset = CacheDataset(self.data, t_compose, progress=False) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) for b, batch_data in enumerate(loader): @@ -90,6 +110,36 @@ def test_decollation(self, *transforms): for i, d in enumerate(decollated): self.check_match(dataset[b * batch_size + i], d) + @parameterized.expand(TESTS_DICT) + def test_decollation_dict(self, *transforms): + t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)]) + # If nibabel present, read from disk + if has_nib: + t_compose = Compose([LoadImaged("image"), t_compose]) + + dataset = CacheDataset(self.data_dict, t_compose, progress=False) + self.check_decollate(dataset=dataset) + + @parameterized.expand(TESTS_LIST) + def test_decollation_tensor(self, *transforms): + t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) + # If nibabel present, read from disk + if has_nib: + t_compose = Compose([LoadImage(image_only=True), t_compose]) + + dataset = Dataset(self.data_list, t_compose) + self.check_decollate(dataset=dataset) + + @parameterized.expand(TESTS_LIST) + def test_decollation_list(self, *transforms): + t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()]) + # If nibabel present, read from disk + if has_nib: + t_compose = _ListCompose([LoadImage(image_only=False), t_compose]) + + dataset = Dataset(self.data_list, t_compose) + self.check_decollate(dataset=dataset) + if __name__ == "__main__": unittest.main()