From ca2faf1379ba6680004b029e4afc3687ac40da30 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Mar 2021 11:35:56 +0800 Subject: [PATCH 01/10] [DLMED] fix affine error and thread-safe issue Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 33 ++++++++++---------------- monai/transforms/spatial/dictionary.py | 11 ++++----- tests/test_affine_grid.py | 2 +- tests/test_rand_affine_grid.py | 2 +- 4 files changed, 20 insertions(+), 28 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index de9bba8e95..a50e19971f 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -961,7 +961,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Args: spatial_size: output grid size. @@ -988,21 +988,20 @@ def __call__( affine = affine @ create_translate(spatial_dims, self.translate_params) if self.scale_params: affine = affine @ create_scale(spatial_dims, self.scale_params) - self.affine = affine + else: + affine = self.affine - self.affine = torch.as_tensor(np.ascontiguousarray(self.affine), device=self.device) + if isinstance(affine, np.ndarray): + affine = torch.as_tensor(np.ascontiguousarray(affine)) grid = torch.tensor(grid) if not isinstance(grid, torch.Tensor) else grid.detach().clone() if self.device: + affine = affine.to(self.device) grid = grid.to(self.device) - grid = (self.affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) + grid = (affine.float() @ grid.reshape((grid.shape[0], -1)).float()).reshape([-1] + list(grid.shape[1:])) if grid is None or not isinstance(grid, torch.Tensor): raise ValueError("Unknown grid.") - return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()) - - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: - """Get the most recently applied transformation matrix""" - return self.affine + return grid if self.as_tensor_output else np.asarray(grid.cpu().numpy()), affine class RandAffineGrid(RandomizableTransform): @@ -1094,13 +1093,7 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - grid = affine_grid(spatial_size, grid) - self.affine = affine_grid.get_transformation_matrix() - return grid - - def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: - """Get the most recently applied transformation matrix""" - return self.affine + return affine_grid(spatial_size, grid) class RandDeformGrid(RandomizableTransform): @@ -1326,7 +1319,7 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid = self.affine_grid(spatial_size=sp_size) + grid, _ = self.affine_grid(spatial_size=sp_size) return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1437,7 +1430,7 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) if self._do_transform: - grid = self.rand_affine_grid(spatial_size=sp_size) + grid, _ = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) return self.resampler( @@ -1557,7 +1550,7 @@ def __call__( self.randomize(spatial_size=sp_size) if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) - grid = self.rand_affine_grid(grid=grid) + grid, _ = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=torch.as_tensor(grid).unsqueeze(0), @@ -1690,5 +1683,5 @@ def __call__( gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude - grid = self.rand_affine_grid(grid=grid) + grid, _ = self.rand_affine_grid(grid=grid) return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 5f8aa3d0b0..c8e11bfa73 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -588,7 +588,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid: torch.Tensor = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform out = self.affine.resampler(d[key], grid, mode, padding_mode) @@ -694,8 +694,7 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self._do_transform: - grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) - affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() + grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) affine = torch.eye(len(sp_size) + 1) @@ -716,7 +715,7 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) - grid: torch.Tensor = affine_grid(orig_size) # type: ignore + grid, _ = affine_grid(orig_size) # type: ignore # Apply inverse transform out = self.rand_affine.resampler(d[key], grid, mode, padding_mode) @@ -832,7 +831,7 @@ def __call__( if self._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) + grid, _ = self.rand_2d_elastic.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=grid.unsqueeze(0), @@ -956,7 +955,7 @@ def __call__( gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid, _ = self.rand_3d_elastic.rand_affine_grid(grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index 2906cd18b6..24772b9a21 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -92,7 +92,7 @@ class TestAffineGrid(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_affine_grid(self, input_param, input_data, expected_val): g = AffineGrid(**input_param) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 605d0a30ba..064cdc0621 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -186,7 +186,7 @@ class TestRandAffineGrid(unittest.TestCase): def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) From 7f78582d15dc9d8f7a9acbe18121deeceefae059 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Mar 2021 17:49:54 +0800 Subject: [PATCH 02/10] [DLMED] update CI tests Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 24 +++++++++++++++--------- monai/transforms/spatial/dictionary.py | 3 +-- tests/test_affine.py | 2 +- tests/test_rand_affine.py | 2 +- tests/test_rand_elastic_2d.py | 2 +- tests/test_rand_elastic_3d.py | 2 +- 6 files changed, 20 insertions(+), 15 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a50e19971f..568fd17652 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1052,7 +1052,6 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device - self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1319,10 +1318,10 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) - grid, _ = self.affine_grid(spatial_size=sp_size) + grid, affine = self.affine_grid(spatial_size=sp_size) return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ) + ), affine class RandAffine(RandomizableTransform): @@ -1430,12 +1429,13 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) if self._do_transform: - grid, _ = self.rand_affine_grid(spatial_size=sp_size) + grid, affine = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) + affine = None return self.resampler( img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ) + ), affine class Rand2DElastic(RandomizableTransform): @@ -1550,7 +1550,7 @@ def __call__( self.randomize(spatial_size=sp_size) if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) - grid, _ = self.rand_affine_grid(grid=grid) + grid, affine = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=torch.as_tensor(grid).unsqueeze(0), @@ -1561,7 +1561,10 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(np.asarray(grid[0])) else: grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + affine = None + return self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ), affine class Rand3DElastic(RandomizableTransform): @@ -1676,6 +1679,7 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) + affine = None if self._do_transform: if self.rand_offset is None: raise AssertionError @@ -1683,5 +1687,7 @@ def __call__( gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude - grid, _ = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) + grid, affine = self.rand_affine_grid(grid=grid) + return self.resampler( + img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + ), affine diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index c8e11bfa73..2548689054 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -572,8 +572,7 @@ def __call__( d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): orig_size = d[key].shape[1:] - d[key] = self.affine(d[key], mode=mode, padding_mode=padding_mode) - affine = self.affine.affine_grid.get_transformation_matrix() + d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) self.push_transform(d, key, orig_size=orig_size, extra_info={"affine": affine}) return d diff --git a/tests/test_affine.py b/tests/test_affine.py index ea146e0fbd..1b6c19596b 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -78,7 +78,7 @@ class TestAffine(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_affine(self, input_param, input_data, expected_val): g = Affine(**input_param) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index 68126f5c8e..ee3c00cffa 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -73,7 +73,7 @@ class TestRandAffine(unittest.TestCase): def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index aa408f0fdc..fc6aeac0ed 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -94,7 +94,7 @@ class TestRand2DElastic(unittest.TestCase): def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index 8cd74c6be7..b9c9dec6e9 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -73,7 +73,7 @@ class TestRand3DElastic(unittest.TestCase): def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) - result = g(**input_data) + result, _ = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) From 6b1dd86ae118e2644809b6d175f236117d8fb969 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Mar 2021 17:56:37 +0800 Subject: [PATCH 03/10] [DLMED] update typehints Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 568fd17652..1733b963db 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1074,7 +1074,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Args: spatial_size: output grid size. @@ -1301,7 +1301,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1409,7 +1409,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1532,7 +1532,7 @@ def __call__( spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: """ Args: img: shape must be (num_channels, H, W), @@ -1662,7 +1662,7 @@ def __call__( spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[np.ndarray, torch.Tensor]: + ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: """ Args: img: shape must be (num_channels, H, W, D), From b02e05be06414df3e35cd3def6b62b8d45c4dd34 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 24 Mar 2021 10:10:11 +0000 Subject: [PATCH 04/10] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/spatial/array.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 1733b963db..7dc59c4da1 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1319,9 +1319,10 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) grid, affine = self.affine_grid(spatial_size=sp_size) - return self.resampler( - img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ), affine + return ( + self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), + affine, + ) class RandAffine(RandomizableTransform): @@ -1433,9 +1434,10 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) affine = None - return self.resampler( - img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ), affine + return ( + self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), + affine, + ) class Rand2DElastic(RandomizableTransform): @@ -1562,9 +1564,7 @@ def __call__( else: grid = create_grid(spatial_size=sp_size) affine = None - return self.resampler( - img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ), affine + return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine class Rand3DElastic(RandomizableTransform): @@ -1688,6 +1688,4 @@ def __call__( offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude grid, affine = self.rand_affine_grid(grid=grid) - return self.resampler( - img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode - ), affine + return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine From 6937095c869328b6c001c25da1407038de940765 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Mar 2021 18:38:23 +0800 Subject: [PATCH 05/10] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 7dc59c4da1..228ce65545 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1427,13 +1427,12 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ self.randomize() - + affine = None sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) if self._do_transform: grid, affine = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) - affine = None return ( self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine, @@ -1550,6 +1549,7 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(spatial_size=sp_size) + affine = None if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) grid, affine = self.rand_affine_grid(grid=grid) @@ -1563,7 +1563,6 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(np.asarray(grid[0])) else: grid = create_grid(spatial_size=sp_size) - affine = None return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine From c165e71481c4632ffa58194fdf8548d6cd9378af Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Mar 2021 22:34:27 +0800 Subject: [PATCH 06/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/spatial/array.py | 34 ++++++++++++++------------ monai/transforms/spatial/dictionary.py | 7 +++--- tests/test_rand_affine.py | 2 +- tests/test_rand_affine_grid.py | 2 +- tests/test_rand_elastic_2d.py | 2 +- tests/test_rand_elastic_3d.py | 2 +- 6 files changed, 26 insertions(+), 23 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 228ce65545..471b171312 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1052,6 +1052,7 @@ def __init__( self.as_tensor_output = as_tensor_output self.device = device + self.affine: Optional[Union[np.ndarray, torch.Tensor]] = None def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1074,7 +1075,7 @@ def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[Union[np.ndarray, torch.Tensor]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: spatial_size: output grid size. @@ -1092,7 +1093,12 @@ def __call__( as_tensor_output=self.as_tensor_output, device=self.device, ) - return affine_grid(spatial_size, grid) + grid, self.affine = affine_grid(spatial_size, grid) + return grid + + def get_transformation_matrix(self) -> Optional[Union[np.ndarray, torch.Tensor]]: + """Get the most recently applied transformation matrix""" + return self.affine class RandDeformGrid(RandomizableTransform): @@ -1410,7 +1416,7 @@ def __call__( spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1427,15 +1433,13 @@ def __call__( See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample """ self.randomize() - affine = None sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) if self._do_transform: - grid, affine = self.rand_affine_grid(spatial_size=sp_size) + grid = self.rand_affine_grid(spatial_size=sp_size) else: grid = create_grid(spatial_size=sp_size) - return ( - self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), - affine, + return self.resampler( + img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode ) @@ -1533,7 +1537,7 @@ def __call__( spatial_size: Optional[Union[Tuple[int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: img: shape must be (num_channels, H, W), @@ -1549,10 +1553,9 @@ def __call__( """ sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(spatial_size=sp_size) - affine = None if self._do_transform: grid = self.deform_grid(spatial_size=sp_size) - grid, affine = self.rand_affine_grid(grid=grid) + grid = self.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=torch.as_tensor(grid).unsqueeze(0), @@ -1563,7 +1566,7 @@ def __call__( grid = CenterSpatialCrop(roi_size=sp_size)(np.asarray(grid[0])) else: grid = create_grid(spatial_size=sp_size) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine + return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) class Rand3DElastic(RandomizableTransform): @@ -1661,7 +1664,7 @@ def __call__( spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Tuple[Union[np.ndarray, torch.Tensor], Optional[Union[np.ndarray, torch.Tensor]]]: + ) -> Union[np.ndarray, torch.Tensor]: """ Args: img: shape must be (num_channels, H, W, D), @@ -1678,7 +1681,6 @@ def __call__( sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:]) self.randomize(grid_size=sp_size) grid = create_grid(spatial_size=sp_size) - affine = None if self._do_transform: if self.rand_offset is None: raise AssertionError @@ -1686,5 +1688,5 @@ def __call__( gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=self.device) offset = torch.as_tensor(self.rand_offset, device=self.device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.magnitude - grid, affine = self.rand_affine_grid(grid=grid) - return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode), affine + grid = self.rand_affine_grid(grid=grid) + return self.resampler(img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 2548689054..527a44b54f 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -693,7 +693,8 @@ def __call__( sp_size = fall_back_tuple(self.rand_affine.spatial_size, data[self.keys[0]].shape[1:]) if self._do_transform: - grid, affine = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + grid = self.rand_affine.rand_affine_grid(spatial_size=sp_size) + affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() else: grid = create_grid(spatial_size=sp_size) affine = torch.eye(len(sp_size) + 1) @@ -830,7 +831,7 @@ def __call__( if self._do_transform: grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size) - grid, _ = self.rand_2d_elastic.rand_affine_grid(grid=grid) + grid = self.rand_2d_elastic.rand_affine_grid(grid=grid) grid = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=grid.unsqueeze(0), @@ -954,7 +955,7 @@ def __call__( gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device) offset = torch.tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0) grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude - grid, _ = self.rand_3d_elastic.rand_affine_grid(grid=grid) + grid = self.rand_3d_elastic.rand_affine_grid(grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py index ee3c00cffa..68126f5c8e 100644 --- a/tests/test_rand_affine.py +++ b/tests/test_rand_affine.py @@ -73,7 +73,7 @@ class TestRandAffine(unittest.TestCase): def test_rand_affine(self, input_param, input_data, expected_val): g = RandAffine(**input_param) g.set_random_state(123) - result, _ = g(**input_data) + result = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_affine_grid.py b/tests/test_rand_affine_grid.py index 064cdc0621..605d0a30ba 100644 --- a/tests/test_rand_affine_grid.py +++ b/tests/test_rand_affine_grid.py @@ -186,7 +186,7 @@ class TestRandAffineGrid(unittest.TestCase): def test_rand_affine_grid(self, input_param, input_data, expected_val): g = RandAffineGrid(**input_param) g.set_random_state(123) - result, _ = g(**input_data) + result = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_elastic_2d.py b/tests/test_rand_elastic_2d.py index fc6aeac0ed..aa408f0fdc 100644 --- a/tests/test_rand_elastic_2d.py +++ b/tests/test_rand_elastic_2d.py @@ -94,7 +94,7 @@ class TestRand2DElastic(unittest.TestCase): def test_rand_2d_elastic(self, input_param, input_data, expected_val): g = Rand2DElastic(**input_param) g.set_random_state(123) - result, _ = g(**input_data) + result = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) diff --git a/tests/test_rand_elastic_3d.py b/tests/test_rand_elastic_3d.py index b9c9dec6e9..8cd74c6be7 100644 --- a/tests/test_rand_elastic_3d.py +++ b/tests/test_rand_elastic_3d.py @@ -73,7 +73,7 @@ class TestRand3DElastic(unittest.TestCase): def test_rand_3d_elastic(self, input_param, input_data, expected_val): g = Rand3DElastic(**input_param) g.set_random_state(123) - result, _ = g(**input_data) + result = g(**input_data) self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor)) if isinstance(result, torch.Tensor): np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4) From d029fc3bbadc5e784fde76cc717abd16695db1d1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Mar 2021 00:32:43 +0800 Subject: [PATCH 07/10] [DLMED] make cachedataset to be thread-safe Signed-off-by: Nic Ma --- monai/data/dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 813008e3a8..93525936e3 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -520,13 +520,15 @@ def _load_cache_item(self, idx: int): idx: the index of the input data sequence. """ item = self.data[idx] - if not isinstance(self.transform, Compose): + # copy to make sure thread-safe + transform_ = deepcopy(self.transform) + if not isinstance(transform_, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - for _transform in self.transform.transforms: + for _trans in transform_.transforms: # execute all the deterministic transforms - if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): + if isinstance(_trans, RandomizableTransform) or not isinstance(_trans, Transform): break - item = apply_transform(_transform, item) + item = apply_transform(_trans, item) return item def __getitem__(self, index): From 1066f4e7d1efb090a421700cf9a8aacf8e47d689 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Mar 2021 00:53:13 +0800 Subject: [PATCH 08/10] [DLMED] remove inverse ID check Signed-off-by: Nic Ma --- monai/transforms/inverse.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 3e5b68e8e4..aba58ab0dd 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -89,14 +89,8 @@ def push_transform( data[key_transform].append(info) def check_transforms_match(self, transform: dict) -> None: - """Check transforms are of same instance.""" - if transform[InverseKeys.ID] == id(self): - return - # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if ( - torch.multiprocessing.get_start_method(allow_none=False) == "spawn" - and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ - ): + """Check whether match the transform class.""" + if transform[InverseKeys.CLASS_NAME] == self.__class__.__name__: return raise RuntimeError("Should inverse most recently applied invertible transform first") From 4974efbc556b6d2958647f9533b24c1d9360d5ff Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Mar 2021 01:08:48 +0800 Subject: [PATCH 09/10] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/transforms/inverse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index aba58ab0dd..cb5c67100c 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -12,7 +12,6 @@ from typing import Dict, Hashable, Optional, Tuple import numpy as np -import torch from monai.transforms.transform import RandomizableTransform, Transform from monai.utils.enums import InverseKeys From a4ca02eafc5813713d21fb274e88c72a41199673 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Mar 2021 21:46:17 +0800 Subject: [PATCH 10/10] [DLMED] restore CacheDataset and inverse transform Signed-off-by: Nic Ma --- monai/data/dataset.py | 10 ++++------ monai/transforms/inverse.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 93525936e3..813008e3a8 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -520,15 +520,13 @@ def _load_cache_item(self, idx: int): idx: the index of the input data sequence. """ item = self.data[idx] - # copy to make sure thread-safe - transform_ = deepcopy(self.transform) - if not isinstance(transform_, Compose): + if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") - for _trans in transform_.transforms: + for _transform in self.transform.transforms: # execute all the deterministic transforms - if isinstance(_trans, RandomizableTransform) or not isinstance(_trans, Transform): + if isinstance(_transform, RandomizableTransform) or not isinstance(_transform, Transform): break - item = apply_transform(_trans, item) + item = apply_transform(_transform, item) return item def __getitem__(self, index): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index cb5c67100c..3e5b68e8e4 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -12,6 +12,7 @@ from typing import Dict, Hashable, Optional, Tuple import numpy as np +import torch from monai.transforms.transform import RandomizableTransform, Transform from monai.utils.enums import InverseKeys @@ -88,8 +89,14 @@ def push_transform( data[key_transform].append(info) def check_transforms_match(self, transform: dict) -> None: - """Check whether match the transform class.""" - if transform[InverseKeys.CLASS_NAME] == self.__class__.__name__: + """Check transforms are of same instance.""" + if transform[InverseKeys.ID] == id(self): + return + # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) + if ( + torch.multiprocessing.get_start_method(allow_none=False) == "spawn" + and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ + ): return raise RuntimeError("Should inverse most recently applied invertible transform first")