diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index de9bba8e95..471b171312 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -961,7 +961,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Args: spatial_size: output grid size. @@ -988,21 +988,20 @@ def __call__( affine = affine @ create_translate(spatial_dims, self.translate_params) if self.scale_params: affine = affine @ create_scale(spatial_dims, self.scale_params) - self.affine = affine + else: + affine = self.affine - self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device) + if isinstance(affine, np.ndarray): + affine = torch.as_tensor(np.ascontiguousarray(affine)) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() if self.device: + affine = affine.to(self.device) grid = grid.to(self.device) - grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) + grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) - - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: - """Get the most recently applied transformation matrix""" - return self.affine + return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine class RandAffineGrid(RandomizableTransform): @@ -1094,8 +1093,7 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - grid = affine_grid(spatial_size, grid) - self.affine = affine_grid.get_transformation_matrix() + grid, self.affine = affine_grid(spatial_size, grid) return grid def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: @@ -1309,7 +1307,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1326,9 +1324,10 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) - return self.resampler( - img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + grid, affine = self.affine_grid(spatial_size=sp_size) + return ( + self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), + affine, ) @@ -1434,7 +1433,6 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ self.randomize() - sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) if self._do_transform: grid = self.rand_affine_grid(spatial_size=sp_size) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index e356a51a2a..86c94302a1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -572,8 +572,7 @@ def __call__( d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] - d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode) - affine = self.affine.affine_grid.get_transformation_matrix() + d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) return d @@ -588,7 +587,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid: torch.Tensor = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform out = self.affine.resampler(d[key], grid, mode, padding_mode) @@ -717,7 +716,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid: torch.Tensor = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) diff --git a/tests/test_affine.py b/tests/test_affine.py index ea146e0fbd..1b6c19596b 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -78,7 +78,7 @@ class TestAffine(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 2906cd18b6..24772b9a21 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -92,7 +92,7 @@ class TestAffineGrid(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)