diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index b3ab865c52..ce0105bbe1 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -645,9 +645,9 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamples": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index eea00a6c4b..bf0b9ef04d 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -756,9 +756,9 @@ def __init__( def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "Randomizable": - super().set_random_state(seed=seed, state=state) - self.cropper.set_random_state(state=self.R) + ) -> "RandSpatialCropSamplesd": + super().set_random_state(seed, state) + self.cropper.set_random_state(seed, state) return self def randomize(self, data: Optional[Any] = None) -> None: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index a136f174bb..061424f523 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -17,7 +17,7 @@ from copy import deepcopy from enum import Enum -from typing import Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -40,7 +40,6 @@ RandAxisFlip, RandFlip, RandRotate, - RandRotate90, RandZoom, Resize, Rotate, @@ -438,7 +437,7 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): in the plane specified by `spatial_axes`. """ - backend = RandRotate90.backend + backend = Rotate90.backend def __init__( self, @@ -462,25 +461,27 @@ def __init__( """ MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.rand_rotate90 = RandRotate90(prob=1.0, max_k=max_k, spatial_axes=spatial_axes) - def set_random_state( - self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None - ) -> "RandRotate90d": - super().set_random_state(seed, state) - self.rand_rotate90.set_random_state(seed, state) - return self + self.max_k = max_k + self.spatial_axes = spatial_axes + + self._rand_k = 0 + + def randomize(self, data: Optional[Any] = None) -> None: + self._rand_k = self.R.randint(self.max_k) + 1 + super().randomize(None) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + self.randomize() d = dict(data) - self.randomize(None) - # all the keys share the same random factor - self.rand_rotate90.randomize() + # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need + # to be compatible with the random status of some previous integration tests + rotator = Rotate90(self._rand_k, self.spatial_axes) for key in self.key_iterator(d): if self._do_transform: - d[key] = self.rand_rotate90(d[key], randomize=False) - self.push_transform(d, key, extra_info={"rand_k": self.rand_rotate90._rand_k}) + d[key] = rotator(d[key]) + self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: @@ -492,7 +493,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd # Create inverse transform num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated - inverse_transform = Rotate90(num_times_to_rotate, self.rand_rotate90.spatial_axes) + inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Apply inverse d[key] = inverse_transform(d[key]) # Remove the applied transform diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py index a62b98163a..3071aa82c8 100644 --- a/tests/test_rand_rotate90d.py +++ b/tests/test_rand_rotate90d.py @@ -32,7 +32,7 @@ def test_k(self): key = "test" rotate = RandRotate90d(keys=key, max_k=2) for p in TEST_NDARRAYS: - rotate.set_random_state(123) + rotate.set_random_state(234) rotated = rotate({key: p(self.imt[0])}) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) @@ -42,7 +42,7 @@ def test_spatial_axes(self): key = "test" rotate = RandRotate90d(keys=key, spatial_axes=(0, 1)) for p in TEST_NDARRAYS: - rotate.set_random_state(123) + rotate.set_random_state(234) rotated = rotate({key: p(self.imt[0])}) expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]] expected = np.stack(expected) diff --git a/tests/utils.py b/tests/utils.py index 49d5a13af3..8b18b8315d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -116,8 +116,8 @@ def is_tf32_env(): if ( torch.cuda.is_available() and not version_leq(f"{torch.version.cuda}", "10.100") - and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" # at least 11.0 - and torch.cuda.device_count() > 0 + and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" + and torch.cuda.device_count() > 0 # at least 11.0 ): try: # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result