Skip to content
45 changes: 33 additions & 12 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ def list_data_collate(batch: Sequence):
raise TypeError(re_str)


def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]:
def decollate_batch(data: Union[dict, list, torch.Tensor], batch_size: Optional[int] = None) -> List[dict]:
"""De-collate a batch of data (for example, as produced by a `DataLoader`).

Returns a list of dictionaries. Each dictionary will only contain the data for a given batch.
Returns a list with one entry per item in the batch; each entry is a dictionary, list or Tensor, matching the type of the input.

Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information,
such as metadata, may have been stored in a list (or a list inside nested dictionaries). In
Expand All @@ -315,19 +315,20 @@ def decollate_batch(data: dict, batch_size: Optional[int] = None) -> List[dict]:
print(out[0])
>>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}}

batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]
out = decollate_batch(batch_data)
print(out[0])
>>> [tensor([[[4.3549e-01...43e-01]]]), tensor([[[5.3435e-01...45e-01]]])]

batch_data = torch.rand((2,1,10,10))
out = decollate_batch(batch_data)
print(out[0])
>>> tensor([[[4.3549e-01...43e-01]]])

Args:
data: data to be de-collated.
batch_size: number of items in the batch (the size of the first dimension). If `None` is passed, try to infer it from the data.
"""
if not isinstance(data, dict):
raise RuntimeError("Only currently implemented for dictionary data (might be trivial to adapt).")
if batch_size is None:
for v in data.values():
if isinstance(v, torch.Tensor):
batch_size = v.shape[0]
break
if batch_size is None:
raise RuntimeError("Couldn't determine batch size, please specify as argument.")

def torch_to_single(d: torch.Tensor):
"""If input is a torch.Tensor with only 1 element, return just the element."""
Expand All @@ -350,7 +351,27 @@ def decollate(data: Any, idx: int):
return data[idx]
raise TypeError(f"Not sure how to de-collate type: {type(data)}")

return [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)]
def _detect_batch_size(batch_data):
for v in batch_data:
if isinstance(v, torch.Tensor):
return v.shape[0]
warnings.warn("batch_data is not a sequence of tensors in decollate, use `len(batch_data[0])` directly.")
return len(batch_data[0])

result: List[Any]
if isinstance(data, dict):
batch_size = _detect_batch_size(batch_data=data.values()) if batch_size is None else batch_size
result = [{key: decollate(data[key], idx) for key in data.keys()} for idx in range(batch_size)]
elif isinstance(data, list):
batch_size = _detect_batch_size(batch_data=data) if batch_size is None else batch_size
result = [[decollate(d, idx) for d in data] for idx in range(batch_size)]
elif isinstance(data, torch.Tensor):
batch_size = data.shape[0]
result = [data[idx] for idx in range(batch_size)]
else:
raise NotImplementedError("Only currently implemented for dictionary, list or Tensor data.")

return result


def pad_list_data_collate(
Expand Down
82 changes: 66 additions & 16 deletions tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,23 @@
import torch
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, create_test_image_2d
from monai.data import CacheDataset, DataLoader, Dataset, create_test_image_2d
from monai.data.utils import decollate_batch
from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord
from monai.transforms import (
AddChannel,
AddChanneld,
Compose,
LoadImage,
LoadImaged,
RandAffine,
RandFlip,
RandFlipd,
RandRotate90,
SpatialPad,
SpatialPadd,
ToTensor,
ToTensord,
)
from monai.transforms.post.dictionary import Decollated
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
Expand All @@ -31,18 +45,32 @@

KEYS = ["image"]

# NOTE(review): TESTS and TESTS_DICT hold equivalent cases; this looks like
# the same table before and after a rename -- confirm which is still used.
TESTS: List[Tuple] = [
    (SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)),
    (RandRotate90d(KEYS, prob=0.0, max_k=1),),
    (RandAffined(KEYS, prob=0.0, translate_range=10),),
]

# Dictionary-transform cases, each a tuple of transforms applied to KEYS.
TESTS_DICT: List[Tuple] = [
    (SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)),
    (RandRotate90d(KEYS, prob=0.0, max_k=1),),
    (RandAffined(KEYS, prob=0.0, translate_range=10),),
]

# Array-transform counterparts of the cases above.
TESTS_LIST: List[Tuple] = [
    (SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)),
    (RandRotate90(prob=0.0, max_k=1),),
    (RandAffine(prob=0.0, translate_range=10),),
]


class _ListCompose(Compose):
    """Compose variant whose first transform yields an ``(image, metadata)`` pair.

    Only the image is passed through the remaining transforms; the metadata
    produced by the first transform is returned untouched next to the final
    image.
    """

    def __call__(self, input_):
        head = self.transforms[0]
        image, meta = head(input_)
        for transform in self.transforms[1:]:
            image = transform(image)
        return image, meta


class TestDeCollate(unittest.TestCase):
def setUp(self) -> None:
    """Call ``set_determinism(seed=0)`` and build six samples from one generated 2D test image.

    When ``has_nib`` is true each sample is the result of ``make_nifti_image(im)``
    (presumably a file on disk -- confirm against the helper); otherwise the
    array itself is used.
    """
    set_determinism(seed=0)

    im = create_test_image_2d(100, 101)[0]
    # NOTE(review): `self.data` and `self.data_dict` are built identically; this
    # looks like the same attribute before and after a rename -- confirm which is used.
    self.data = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]
    # Dictionary-style samples keyed by "image".
    self.data_dict = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]
    # Bare samples (no dictionary wrapper).
    self.data_list = [make_nifti_image(im) if has_nib else im for _ in range(6)]

def tearDown(self) -> None:
    """Reset the determinism setting applied in ``setUp`` by calling ``set_determinism(None)``."""
    set_determinism(None)
Expand All @@ -68,18 +96,10 @@ def check_match(self, in1, in2):
else:
raise RuntimeError(f"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}")

@parameterized.expand(TESTS)
def test_decollation(self, *transforms):

def check_decollate(self, dataset):
batch_size = 2
num_workers = 2

t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
# If nibabel present, read from disk
if has_nib:
t_compose = Compose([LoadImaged("image"), t_compose])

dataset = CacheDataset(self.data, t_compose, progress=False)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

for b, batch_data in enumerate(loader):
Expand All @@ -90,6 +110,36 @@ def test_decollation(self, *transforms):
for i, d in enumerate(decollated):
self.check_match(dataset[b * batch_size + i], d)

@parameterized.expand(TESTS_DICT)
def test_decollation_dict(self, *transforms):
    """Check decollation of dictionary samples run through the given dictionary transforms."""
    core = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
    # If nibabel present, read from disk
    pipeline = Compose([LoadImaged("image"), core]) if has_nib else core

    cached = CacheDataset(self.data_dict, pipeline, progress=False)
    self.check_decollate(dataset=cached)

@parameterized.expand(TESTS_LIST)
def test_decollation_tensor(self, *transforms):
    """Check decollation of bare samples run through the given array transforms."""
    core = Compose([AddChannel(), Compose(transforms), ToTensor()])
    # If nibabel present, read from disk
    pipeline = Compose([LoadImage(image_only=True), core]) if has_nib else core

    samples = Dataset(self.data_list, pipeline)
    self.check_decollate(dataset=samples)

@parameterized.expand(TESTS_LIST)
def test_decollation_list(self, *transforms):
    """Check decollation when the loader is asked for more than the image (``image_only=False``)."""
    core = Compose([AddChannel(), Compose(transforms), ToTensor()])
    # If nibabel present, read from disk
    pipeline = _ListCompose([LoadImage(image_only=False), core]) if has_nib else core

    samples = Dataset(self.data_list, pipeline)
    self.check_decollate(dataset=samples)


if __name__ == "__main__":
unittest.main()