diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index fd7fabe5a1..7d3d118196 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -23,6 +23,7 @@ from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np +import torch from monai.config import IndexSelection, KeysCollection from monai.config.type_definitions import NdarrayOrTensor @@ -315,7 +316,10 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) + orig_size = transform[TraceKeys.ORIG_SIZE] + if isinstance(orig_size[0], torch.Tensor): + orig_size = torch.as_tensor(orig_size) + orig_size = np.asarray(orig_size) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start