diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5f8aa3d0b0..e356a51a2a 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -698,7 +698,8 @@ def __call__( affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() else: grid = create_grid(spatial_size=sp_size) - affine = torch.eye(len(sp_size) + 1) + # to be consistent with the self._do_transform case (dtype and device) + affine = torch.as_tensor(np.eye(len(sp_size) + 1), device=self.rand_affine.rand_affine_grid.device) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): self.push_transform(d, key, extra_info={"affine": affine}) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index 5bde157343..3e07a8f0e2 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING import numpy as np +import torch from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d, pad_list_data_collate @@ -49,7 +50,13 @@ RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), - RandAffined(keys=KEYS, prob=0.5, rotate_range=np.pi), + RandAffined( + keys=KEYS, + prob=0.5, + rotate_range=np.pi, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + as_tensor_output=False, + ), ] ] @@ -62,7 +69,13 @@ RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(keys=KEYS, prob=0.5, range_x=np.pi), - RandAffined(keys=KEYS, prob=0.5, rotate_range=np.pi), + RandAffined( + keys=KEYS, + prob=0.5, + rotate_range=np.pi, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + as_tensor_output=False, + ), ] ] @@ -102,8 +115,8 @@ def test_collation(self, _, transform, collate_fn, ndim): else: modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100)]) - # num workers = 0 for mac - num_workers = 2 if sys.platform != "darwin" else 0 + # num workers = 0 for mac or gpu transforms + num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 dataset = CacheDataset(data, transform=modified_transform, progress=False) loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn)