Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from monai.config import DtypeLike, KeysCollection, NdarrayTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utility.array import (
AddChannel,
AsChannelFirst,
Expand Down Expand Up @@ -972,14 +972,20 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
return d


class RandTorchVisiond(RandomizableTransform, MapTransform):
class RandTorchVisiond(Randomizable, MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for randomized transoforms.
Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for randomized transforms.
For deterministic non-randomized transforms of TorchVision use :py:class:`monai.transforms.TorchVisiond`.

Note:
As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.

- As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.
- This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform
computation. If the random factor of the underlying torchvision transform is not derived from `self.R`,
the results may not be deterministic.
See Also: :py:class:`monai.transforms.Randomizable`.

"""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_rotated.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne
self.segn[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=0, mode=_mode, prefilter=False
)
expected = np.stack(expected).astype(int)
self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 110)
self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130)


class TestRotated3DXY(NumpyImageTestCase3D):
Expand All @@ -113,7 +113,7 @@ def test_correct_results(self, angle, keep_size, mode, padding_mode, align_corne
self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False
)
expected = np.stack(expected).astype(int)
self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 110)
self.assertLessEqual(np.count_nonzero(expected != rotated["seg"][0]), 130)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_threadcontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_plot(self):
with tempfile.TemporaryDirectory() as tempdir:
tempimg = f"{tempdir}/threadcontainer_plot_test.png"
fig.savefig(tempimg)
comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-3)
comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-2)

self.assertIsNone(comp, comp) # None indicates test passed

Expand Down