Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,18 @@ def decollate_batch(data: Union[dict, list, torch.Tensor], batch_size: Optional[
"""

def torch_to_single(d: torch.Tensor):
"""If input is a torch.Tensor with only 1 element, return just the element."""
return d if d.numel() > 1 else d.item()
"""If input is a torch.Tensor with only 1 element, return just the element,
otherwise, detach the Tensor from graph and return.

"""
return d.detach() if d.numel() > 1 else d.item()

def decollate(data: Any, idx: int):
"""Recursively de-collate."""
if isinstance(data, dict):
return {k: decollate(v, idx) for k, v in data.items()}
if isinstance(data, torch.Tensor):
out = data[idx]
return torch_to_single(out)
if isinstance(data, torch.Tensor) and data.ndim > 0:
return torch_to_single(data[idx])
if isinstance(data, list):
if len(data) == 0:
return data
Expand All @@ -353,10 +355,13 @@ def decollate(data: Any, idx: int):

def _detect_batch_size(batch_data):
for v in batch_data:
if isinstance(v, torch.Tensor):
if isinstance(v, torch.Tensor) and v.ndim > 0:
return v.shape[0]
warnings.warn("batch_data is not a sequence of tensors in decollate, use `len(batch_data[0])` directly.")
return len(batch_data[0])
for v in batch_data:
if issequenceiterable(v):
warnings.warn("batch_data doesn't contain batched Tensor data, use the length of first sequence data.")
return len(v)
raise RuntimeError("failed to automatically detect the batch size.")

result: List[Any]
if isinstance(data, dict):
Expand All @@ -365,9 +370,9 @@ def _detect_batch_size(batch_data):
elif isinstance(data, list):
batch_size = _detect_batch_size(batch_data=data) if batch_size is None else batch_size
result = [[decollate(d, idx) for d in data] for idx in range(batch_size)]
elif isinstance(data, torch.Tensor):
elif isinstance(data, torch.Tensor) and data.ndim > 0:
batch_size = data.shape[0]
result = [data[idx] for idx in range(batch_size)]
result = [torch_to_single(data[idx]) for idx in range(batch_size)]
else:
raise NotImplementedError("Only currently implemented for dictionary, list or Tensor data.")

Expand Down