Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def __call__(
self,
spatial_size: Optional[Sequence[int]] = None,
grid: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor]:
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
"""
Args:
spatial_size: output grid size.
Expand All @@ -988,21 +988,20 @@ def __call__(
affine = affine @ create_translate(spatial_dims, self.translate_params)
if self.scale_params:
affine = affine @ create_scale(spatial_dims, self.scale_params)
self.affine = affine
else:
affine = self.affine

self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device)
if isinstance(affine, np.ndarray):
affine = torch.as_tensor(np.ascontiguousarray(affine))

grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone()
if self.device:
affine = affine.to(self.device)
grid = grid.to(self.device)
grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:]))
if grid is None or not isinstance(grid, torch.Tensor):
raise ValueError("Unknown grid.")
return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy())

def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
"""Get the most recently applied transformation matrix"""
return self.affine
return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine


class RandAffineGrid(RandomizableTransform):
Expand Down Expand Up @@ -1094,8 +1093,7 @@ def __call__(
as_tensor_output=self.as_tensor_output,
device=self.device,
)
grid = affine_grid(spatial_size, grid)
self.affine = affine_grid.get_transformation_matrix()
grid, self.affine = affine_grid(spatial_size, grid)
return grid

def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]:
Expand Down Expand Up @@ -1309,7 +1307,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> Union[np.ndarray, torch.Tensor]:
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
Expand All @@ -1326,9 +1324,10 @@ def __call__(
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:])
grid = self.affine_grid(spatial_size=sp_size)
return self.resampler(
img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
grid, affine = self.affine_grid(spatial_size=sp_size)
return (
self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode),
affine,
)


Expand Down Expand Up @@ -1434,7 +1433,6 @@ def __call__(
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"""
self.randomize()

sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:])
if self._do_transform:
grid = self.rand_affine_grid(spatial_size=sp_size)
Expand Down
7 changes: 3 additions & 4 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,7 @@ def __call__(
d = dict(data)
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
orig_size = d[key].shape[1:]
d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode)
affine = self.affine.affine_grid.get_transformation_matrix()
d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode)
self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine})
return d

Expand All @@ -588,7 +587,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
grid: torch.Tensor = affine_grid(orig_size) # type: ignore
grid, _ = affine_grid(orig_size) # type: ignore

# Apply inverse transform
out = self.affine.resampler(d[key], grid, mode, padding_mode)
Expand Down Expand Up @@ -717,7 +716,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
inv_affine = np.linalg.inv(fwd_affine)

affine_grid = AffineGrid(affine=inv_affine)
grid: torch.Tensor = affine_grid(orig_size) # type: ignore
grid, _ = affine_grid(orig_size) # type: ignore

# Apply inverse transform
out = self.rand_affine.resampler(d[key], grid, mode, padding_mode)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class TestAffine(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_affine(self, input_param, input_data, expected_val):
g = Affine(**input_param)
result = g(**input_data)
result, _ = g(**input_data)
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_affine_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class TestAffineGrid(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_affine_grid(self, input_param, input_data, expected_val):
g = AffineGrid(**input_param)
result = g(**input_data)
result, _ = g(**input_data)
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
if isinstance(result, torch.Tensor):
np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)
Expand Down