From f220f35cc66443f6d1c45dbf774880d5ae47501b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 21 Apr 2021 08:00:20 +0800 Subject: [PATCH 1/2] [DLMED] check the label data for Tensor Signed-off-by: Nic Ma --- monai/engines/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index d16ab3cfbb..670a65b611 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -107,7 +107,7 @@ def default_prepare_batch( """ if not isinstance(batchdata, dict): raise AssertionError("default prepare_batch expects dictionary input data.") - if CommonKeys.LABEL in batchdata: + if torch.is_tensor(batchdata.get(CommonKeys.LABEL, None)): return ( batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking), From dc2ff07123676cd1ce478e0ec4acc4d61efe8b68 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 21 Apr 2021 08:27:20 +0800 Subject: [PATCH 2/2] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 670a65b611..265a63ee0c 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -107,7 +107,7 @@ def default_prepare_batch( """ if not isinstance(batchdata, dict): raise AssertionError("default prepare_batch expects dictionary input data.") - if torch.is_tensor(batchdata.get(CommonKeys.LABEL, None)): + if isinstance(batchdata.get(CommonKeys.LABEL, None), torch.Tensor): return ( batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking), batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking),