diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 2e389d9e0b..ffad8dba88 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -187,23 +187,19 @@ def __init__( ) -> None: super().__init__(data=data, transform=None) self.patch_iter = patch_iter - self.transform = transform + self.patch_transform = transform self.with_coordinates = with_coordinates def __iter__(self): for image in super().__iter__(): - if not self.with_coordinates: - for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch - out_patch = ( - patch if self.transform is None else apply_transform(self.transform, patch, map_items=False) - ) + for patch, *others in self.patch_iter(image): + out_patch = patch + if self.patch_transform is not None: + out_patch = apply_transform(self.patch_transform, patch, map_items=False) + if self.with_coordinates and len(others) > 0: # patch_iter to yield at least 2 items: patch, coords + yield out_patch, others[0] + else: yield out_patch - else: - for patch, slices, *_ in self.patch_iter(image): # patch_iter to yield at least 2 items: patch, coords - out_patch = ( - patch if self.transform is None else apply_transform(self.transform, patch, map_items=False) - ) - yield out_patch, slices class PatchDataset(Dataset): diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index f292bf1593..f1906a80fe 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -11,7 +11,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union -import numpy as np from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info @@ -115,9 +114,6 @@ def _get_item(): def randomize(self, size: int) -> None: self._idx = self.R.randint(size) - def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None): - raise NotImplementedError(f"`set_random_state` is not available in {self.__class__.__name__}.") - class CSVIterableDataset(IterableDataset): """ diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 529e679142..30680c8e31 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -60,7 +60,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[7.4965, 8.4965], [11.4965, 12.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]), + np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -69,7 +69,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[7.2548, 8.2548], [11.2548, 12.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]), + np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]), rtol=1e-3, ) np.testing.assert_allclose( @@ -102,7 +102,7 @@ def test_loading_dict(self): self.assertListEqual(item[0]["metadata"], ["test string", "test string"]) np.testing.assert_allclose( item[0]["image"], - np.array([[[[7.4965, 8.4965], [11.4965, 12.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]), + np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]), rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) @@ -111,7 +111,7 @@ def test_loading_dict(self): np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2)) np.testing.assert_allclose( item[0]["image"], - np.array([[[[7.2548, 8.2548], [11.2548, 12.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]), + np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]), rtol=1e-3, ) np.testing.assert_allclose( diff --git a/tests/test_shuffle_buffer.py b/tests/test_shuffle_buffer.py new file mode 100644 index 0000000000..40012fbf93 --- /dev/null +++ b/tests/test_shuffle_buffer.py @@ -0,0 +1,34 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +from monai.data import DataLoader, ShuffleBuffer +from monai.utils import convert_data_type + + +class TestShuffleBuffer(unittest.TestCase): + def test_shape(self): + buffer = ShuffleBuffer([1, 2, 3, 4], seed=0) + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=buffer, batch_size=2, num_workers=num_workers) + output = [convert_data_type(x, np.ndarray)[0] for x in dataloader] + if num_workers == 0: + np.testing.assert_allclose(output, [[2, 1], [3, 4]]) + else: # multiprocess shuffle + np.testing.assert_allclose(output, [[2, 3], [1, 4]]) + + +if __name__ == "__main__": + unittest.main()