Skip to content
6 changes: 3 additions & 3 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,9 @@ def __init__(

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "Randomizable":
super().set_random_state(seed=seed, state=state)
self.cropper.set_random_state(state=self.R)
) -> "RandSpatialCropSamples":
super().set_random_state(seed, state)
self.cropper.set_random_state(seed, state)
return self

def randomize(self, data: Optional[Any] = None) -> None:
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,9 @@ def __init__(

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "Randomizable":
super().set_random_state(seed=seed, state=state)
self.cropper.set_random_state(state=self.R)
) -> "RandSpatialCropSamplesd":
super().set_random_state(seed, state)
self.cropper.set_random_state(seed, state)
return self

def randomize(self, data: Optional[Any] = None) -> None:
Expand Down
33 changes: 17 additions & 16 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from copy import deepcopy
from enum import Enum
from typing import Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -40,7 +40,6 @@
RandAxisFlip,
RandFlip,
RandRotate,
RandRotate90,
RandZoom,
Resize,
Rotate,
Expand Down Expand Up @@ -438,7 +437,7 @@ class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform):
in the plane specified by `spatial_axes`.
"""

backend = RandRotate90.backend
backend = Rotate90.backend

def __init__(
self,
Expand All @@ -462,25 +461,27 @@ def __init__(
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.rand_rotate90 = RandRotate90(prob=1.0, max_k=max_k, spatial_axes=spatial_axes)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "RandRotate90d":
super().set_random_state(seed, state)
self.rand_rotate90.set_random_state(seed, state)
return self
self.max_k = max_k
self.spatial_axes = spatial_axes

self._rand_k = 0

def randomize(self, data: Optional[Any] = None) -> None:
self._rand_k = self.R.randint(self.max_k) + 1
super().randomize(None)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
self.randomize()
d = dict(data)
self.randomize(None)

# all the keys share the same random factor
self.rand_rotate90.randomize()
# FIXME: here we didn't use array version `RandRotate90` transform as others, because we need
# to be compatible with the random status of some previous integration tests
rotator = Rotate90(self._rand_k, self.spatial_axes)
for key in self.key_iterator(d):
if self._do_transform:
d[key] = self.rand_rotate90(d[key], randomize=False)
self.push_transform(d, key, extra_info={"rand_k": self.rand_rotate90._rand_k})
d[key] = rotator(d[key])
self.push_transform(d, key, extra_info={"rand_k": self._rand_k})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
Expand All @@ -492,7 +493,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
# Create inverse transform
num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"]
num_times_to_rotate = 4 - num_times_rotated
inverse_transform = Rotate90(num_times_to_rotate, self.rand_rotate90.spatial_axes)
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
Expand Down
4 changes: 2 additions & 2 deletions tests/test_rand_rotate90d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_k(self):
key = "test"
rotate = RandRotate90d(keys=key, max_k=2)
for p in TEST_NDARRAYS:
rotate.set_random_state(123)
rotate.set_random_state(234)
rotated = rotate({key: p(self.imt[0])})
expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
expected = np.stack(expected)
Expand All @@ -42,7 +42,7 @@ def test_spatial_axes(self):
key = "test"
rotate = RandRotate90d(keys=key, spatial_axes=(0, 1))
for p in TEST_NDARRAYS:
rotate.set_random_state(123)
rotate.set_random_state(234)
rotated = rotate({key: p(self.imt[0])})
expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
expected = np.stack(expected)
Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def is_tf32_env():
if (
torch.cuda.is_available()
and not version_leq(f"{torch.version.cuda}", "10.100")
and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" # at least 11.0
and torch.cuda.device_count() > 0
and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0"
and torch.cuda.device_count() > 0 # at least 11.0
):
try:
# with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
Expand Down