From 4bd3d05ae6e2cf9cb25c1dafc7646428ce79b72c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Jun 2021 11:31:59 +0800 Subject: [PATCH] [DLMED] enhance decollate_batch Signed-off-by: Nic Ma --- monai/data/utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 49f64df1b9..2d10b0d003 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -331,16 +331,18 @@ def decollate_batch(data: Union[dict, list, torch.Tensor], batch_size: Optional[ """ def torch_to_single(d: torch.Tensor): - """If input is a torch.Tensor with only 1 element, return just the element.""" - return d if d.numel() > 1 else d.item() + """If input is a torch.Tensor with only 1 element, return just the element, + otherwise, detach the Tensor from graph and return. + + """ + return d.detach() if d.numel() > 1 else d.item() def decollate(data: Any, idx: int): """Recursively de-collate.""" if isinstance(data, dict): return {k: decollate(v, idx) for k, v in data.items()} - if isinstance(data, torch.Tensor): - out = data[idx] - return torch_to_single(out) + if isinstance(data, torch.Tensor) and data.ndim > 0: + return torch_to_single(data[idx]) if isinstance(data, list): if len(data) == 0: return data @@ -353,10 +355,13 @@ def decollate(data: Any, idx: int): def _detect_batch_size(batch_data): for v in batch_data: - if isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor) and v.ndim > 0: return v.shape[0] - warnings.warn("batch_data is not a sequence of tensors in decollate, use `len(batch_data[0])` directly.") - return len(batch_data[0]) + for v in batch_data: + if issequenceiterable(v): + warnings.warn("batch_data doesn't contain batched Tensor data, use the length of first sequence data.") + return len(v) + raise RuntimeError("failed to automatically detect the batch size.") result: List[Any] if isinstance(data, dict): @@ -365,9 +370,9 @@ def _detect_batch_size(batch_data): elif isinstance(data, list): batch_size = _detect_batch_size(batch_data=data) if batch_size is None else batch_size result = [[decollate(d, idx) for d in data] for idx in range(batch_size)] - elif isinstance(data, torch.Tensor): + elif isinstance(data, torch.Tensor) and data.ndim > 0: batch_size = data.shape[0] - result = [data[idx] for idx in range(batch_size)] + result = [torch_to_single(data[idx]) for idx in range(batch_size)] else: raise NotImplementedError("Only currently implemented for dictionary, list or Tensor data.")