diff --git a/docs/source/data.rst b/docs/source/data.rst index eed4b30ded..3dffeb8977 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -77,6 +77,11 @@ Patch-based dataset .. autoclass:: GridPatchDataset :members: +`PatchIter` +~~~~~~~~~~~ +.. autoclass:: PatchIter + :members: + `PatchDataset` ~~~~~~~~~~~~~~ .. autoclass:: PatchDataset diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 54ee7908f4..9fa5c935e2 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -22,7 +22,7 @@ ZipDataset, ) from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties -from .grid_dataset import GridPatchDataset, PatchDataset +from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader from .iterable_dataset import IterableDataset diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index f85569d88a..3f373491ed 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -9,9 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, Dict, Optional, Sequence, Union +import numpy as np import torch from torch.utils.data import IterableDataset @@ -20,31 +20,25 @@ from monai.transforms import apply_transform from monai.utils import NumpyPadMode, ensure_tuple -__all__ = ["PatchDataset", "GridPatchDataset"] +__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"] -class GridPatchDataset(IterableDataset): +class PatchIter: """ - Yields patches from arrays read from an input dataset. The patches are chosen in a contiguous grid sampling scheme. + A class to return a patch generator with predefined properties such as `patch_size`. + Typically used with :py:class:`monai.data.GridPatchDataset`. """ def __init__( self, - dataset: Sequence, patch_size: Sequence[int], start_pos: Sequence[int] = (), mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP, **pad_opts: Dict, - ) -> None: + ): """ - Initializes this dataset in terms of the input dataset and patch size. The `patch_size` is the size of the - patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which - will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D - array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be - specified by a `patch_size` of (10, 10, 10). Args: - dataset: the dataset to read array data from patch_size: size of patches to generate slices for, 0/None selects whole dimension start_pos: starting position in the array, default is 0 for each dimension mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, @@ -52,32 +46,123 @@ def __init__( One of the listed string values or a user supplied function. Defaults to ``"wrap"``. See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html pad_opts: padding options, see numpy.pad - """ - self.dataset = dataset + Note: + The `patch_size` is the size of the + patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which + will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D + array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be + specified by a `patch_size` of (10, 10, 10). + + """ self.patch_size = (None,) + tuple(patch_size) self.start_pos = ensure_tuple(start_pos) self.mode: NumpyPadMode = NumpyPadMode(mode) self.pad_opts = pad_opts + def __call__(self, array): + """ + Args: + array: the image to generate patches from. + """ + yield from iter_patch( + array, + patch_size=self.patch_size, # expand to have the channel dim + start_pos=self.start_pos, + copy_back=False, + mode=self.mode, + **self.pad_opts, + ) + + +class GridPatchDataset(IterableDataset): + """ + Yields patches from images read from an image dataset. + Typically used with `PatchIter` so that the patches are chosen in a contiguous grid sampling scheme. + + .. code-block:: python + + import numpy as np + + from monai.data import GridPatchDataset, DataLoader, PatchIter + from monai.transforms import RandShiftIntensity + + # image-level dataset + images = [np.arange(16, dtype=float).reshape(1, 4, 4), + np.arange(16, dtype=float).reshape(1, 4, 4)] + # image-level patch generator, "grid sampling" + patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) + # patch-level intensity shifts + patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) + + # construct the dataset + ds = GridPatchDataset(dataset=images, + patch_iter=patch_iter, + transform=patch_intensity) + # use the grid patch dataset + for item in DataLoader(ds, batch_size=2, num_workers=2): + print("patch size:", item[0].shape) + print("coordinates:", item[1]) + + # >>> patch size: torch.Size([2, 1, 2, 2]) + # coordinates: tensor([[[0, 1], [0, 2], [0, 2]], + # [[0, 1], [2, 4], [0, 2]]]) + + """ + + def __init__( + self, + dataset: Sequence, + patch_iter: Callable, + transform: Optional[Callable] = None, + with_coordinates: bool = True, + ) -> None: + """ + Initializes this dataset in terms of the image dataset, patch generator, and an optional transform. + + Args: + dataset: the dataset to read image data from. + patch_iter: converts an input image (item from dataset) into a iterable of image patches. + `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates). + see also: :py:class:`monai.data.PatchIter`. + transform: a callable data transform operates on the patches. + with_coordinates: whether to yield the coordinates of each patch, default to `True`. + + """ + + self.dataset = dataset + self.patch_iter = patch_iter + self.transform = transform + self.with_coordinates = with_coordinates + def __iter__(self): worker_info = torch.utils.data.get_worker_info() - iter_start = 0 - iter_end = len(self.dataset) + iter_start, iter_end = 0, 1 + try: + iter_end = len(self.dataset) # TODO: support iterable self.dataset + except TypeError: + raise NotImplementedError("image dataset must implement `len()`.") if worker_info is not None: # split workload - per_worker = int(math.ceil((iter_end - iter_start) / float(worker_info.num_workers))) - worker_id = worker_info.id - iter_start = iter_start + worker_id * per_worker + per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers))) + iter_start = iter_start + worker_info.id * per_worker iter_end = min(iter_start + per_worker, iter_end) for index in range(iter_start, iter_end): - arrays = self.dataset[index] - - iters = [iter_patch(a, self.patch_size, self.start_pos, False, self.mode, **self.pad_opts) for a in arrays] - - yield from zip(*iters) + image = self.dataset[index] + if not self.with_coordinates: + for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch + out_patch = ( + patch if self.transform is None else apply_transform(self.transform, patch, map_items=False) + ) + yield out_patch + else: + for patch, slices, *_ in self.patch_iter(image): # patch_iter to yield at least 2 items: patch, coords + out_patch = ( + patch if self.transform is None else apply_transform(self.transform, patch, map_items=False) + ) + yield out_patch, slices class PatchDataset(Dataset): @@ -95,8 +180,8 @@ class PatchDataset(Dataset): from monai.transforms import RandSpatialCropSamples, RandShiftIntensity # image dataset - images = [np.arange(16, dtype=np.float).reshape(1, 4, 4), - np.arange(16, dtype=np.float).reshape(1, 4, 4)] + images = [np.arange(16, dtype=float).reshape(1, 4, 4), + np.arange(16, dtype=float).reshape(1, 4, 4)] # image patch sampler n_samples = 5 sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples, diff --git a/monai/data/utils.py b/monai/data/utils.py index 60250af441..2e2f8c00cb 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -174,7 +174,7 @@ def iter_patch( copy_back: bool = True, mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP, **pad_opts: Dict, -) -> Generator[np.ndarray, None, None]: +): """ Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr` but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative @@ -194,6 +194,15 @@ def iter_patch( Yields: Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is True these changes will be reflected in `arr` once the iteration completes. + + Note: + coordinate format is: + + [1st_dim_start, 1st_dim_end, + 2nd_dim_start, 2nd_dim_end, + ..., + Nth_dim_start, Nth_dim_end]] + """ # ensure patchSize and startPos are the right length patch_size_ = get_valid_patch_size(arr.shape, patch_size) @@ -210,7 +219,9 @@ def iter_patch( iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_)) for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded): - yield arrpad[slices] + # compensate original image padding + coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, patch_size_)) + yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader) # copy back data from the padded image if required if copy_back: @@ -411,6 +422,8 @@ def set_rnd(obj, seed: int) -> int: obj.set_random_state(seed=seed % MAX_SEED) return seed + 1 # a different seed for the next component for key in obj.__dict__: + if key.startswith("__"): # skip the private methods + continue seed = set_rnd(obj.__dict__[key], seed=seed) return seed diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py new file mode 100644 index 0000000000..6e0aa4023e --- /dev/null +++ b/tests/test_grid_dataset.py @@ -0,0 +1,83 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +from monai.data import DataLoader, GridPatchDataset, PatchIter +from monai.transforms import RandShiftIntensity +from monai.utils import set_determinism + + +def identity_generator(x): + # simple transform that returns the input itself + for idx, item in enumerate(x): + yield item, idx + + +class TestGridPatchDataset(unittest.TestCase): + def setUp(self): + set_determinism(seed=1234) + + def tearDown(self): + set_determinism(None) + + def test_shape(self): + test_dataset = ["vwxyz", "helloworld", "worldfoobar"] + result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False) + output = [] + n_workers = 0 if sys.platform == "win32" else 2 + for item in DataLoader(result, batch_size=3, num_workers=n_workers): + output.append("".join(item)) + expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] + self.assertEqual(sorted(output), sorted(expected)) + self.assertEqual(len("".join(expected)), len("".join(test_dataset))) + + def test_loading_array(self): + set_determinism(seed=1234) + # image dataset + images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] + # image level + patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) + patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) + ds = GridPatchDataset(dataset=images, patch_iter=patch_iter, transform=patch_intensity) + # use the grid patch dataset + for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose( + item[0], + np.array([[[[1.7413, 2.7413], [5.7413, 6.7413]]], [[[9.1419, 10.1419], [13.1419, 14.1419]]]]), + rtol=1e-5, + ) + np.testing.assert_allclose( + item[1], + np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), + rtol=1e-5, + ) + if sys.platform != "win32": + for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2): + np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) + np.testing.assert_allclose( + item[0], + np.array([[[[2.3944, 3.3944], [6.3944, 7.3944]]], [[[10.6551, 11.6551], [14.6551, 15.6551]]]]), + rtol=1e-3, + ) + np.testing.assert_allclose( + item[1], + np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), + rtol=1e-5, + ) + + +if __name__ == "__main__": + unittest.main()