Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ Patch-based dataset
.. autoclass:: GridPatchDataset
:members:

`PatchIter`
~~~~~~~~~~~
.. autoclass:: PatchIter
:members:

`PatchDataset`
~~~~~~~~~~~~~~
.. autoclass:: PatchDataset
Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ZipDataset,
)
from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
from .grid_dataset import GridPatchDataset, PatchDataset
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .iterable_dataset import IterableDataset
Expand Down
137 changes: 111 additions & 26 deletions monai/data/grid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable, Dict, Optional, Sequence, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset

Expand All @@ -20,64 +20,149 @@
from monai.transforms import apply_transform
from monai.utils import NumpyPadMode, ensure_tuple

__all__ = ["PatchDataset", "GridPatchDataset"]
__all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"]


class GridPatchDataset(IterableDataset):
class PatchIter:
Comment thread
wyli marked this conversation as resolved.
"""
Yields patches from arrays read from an input dataset. The patches are chosen in a contiguous grid sampling scheme.
A class to return a patch generator with predefined properties such as `patch_size`.
Typically used with :py:class:`monai.data.GridPatchDataset`.
"""

def __init__(
self,
dataset: Sequence,
patch_size: Sequence[int],
start_pos: Sequence[int] = (),
mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
**pad_opts: Dict,
) -> None:
):
"""
Initializes this dataset in terms of the input dataset and patch size. The `patch_size` is the size of the
patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which
will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D
array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be
specified by a `patch_size` of (10, 10, 10).

Args:
dataset: the dataset to read array data from
patch_size: size of patches to generate slices for, 0/None selects whole dimension
start_pos: starting position in the array, default is 0 for each dimension
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"wrap"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
pad_opts: padding options, see numpy.pad
"""

self.dataset = dataset
Note:
The `patch_size` is the size of the
patch to sample from the input arrays. It is assumed the array's first dimension is the channel dimension, which
will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D
array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be
specified by a `patch_size` of (10, 10, 10).

"""
self.patch_size = (None,) + tuple(patch_size)
self.start_pos = ensure_tuple(start_pos)
self.mode: NumpyPadMode = NumpyPadMode(mode)
self.pad_opts = pad_opts

def __call__(self, array):
"""
Args:
array: the image to generate patches from.
"""
yield from iter_patch(
array,
patch_size=self.patch_size, # expand to have the channel dim
start_pos=self.start_pos,
copy_back=False,
mode=self.mode,
**self.pad_opts,
)


class GridPatchDataset(IterableDataset):
"""
Yields patches from images read from an image dataset.
Typically used with `PatchIter` so that the patches are chosen in a contiguous grid sampling scheme.

.. code-block:: python

import numpy as np

from monai.data import GridPatchDataset, DataLoader, PatchIter
from monai.transforms import RandShiftIntensity

# image-level dataset
images = [np.arange(16, dtype=float).reshape(1, 4, 4),
np.arange(16, dtype=float).reshape(1, 4, 4)]
# image-level patch generator, "grid sampling"
patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
# patch-level intensity shifts
patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)

# construct the dataset
ds = GridPatchDataset(dataset=images,
patch_iter=patch_iter,
transform=patch_intensity)
# use the grid patch dataset
for item in DataLoader(ds, batch_size=2, num_workers=2):
print("patch size:", item[0].shape)
print("coordinates:", item[1])

# >>> patch size: torch.Size([2, 1, 2, 2])
# coordinates: tensor([[[0, 1], [0, 2], [0, 2]],
# [[0, 1], [2, 4], [0, 2]]])

"""

def __init__(
self,
dataset: Sequence,
patch_iter: Callable,
transform: Optional[Callable] = None,
with_coordinates: bool = True,
) -> None:
"""
Initializes this dataset in terms of the image dataset, patch generator, and an optional transform.

Args:
dataset: the dataset to read image data from.
patch_iter: converts an input image (item from dataset) into an iterable of image patches.
`patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
see also: :py:class:`monai.data.PatchIter`.
transform: a callable data transform that operates on the patches.
with_coordinates: whether to yield the coordinates of each patch, defaults to `True`.

"""

self.dataset = dataset
self.patch_iter = patch_iter
self.transform = transform
self.with_coordinates = with_coordinates

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
iter_start = 0
iter_end = len(self.dataset)
iter_start, iter_end = 0, 1
try:
iter_end = len(self.dataset) # TODO: support iterable self.dataset
except TypeError:
raise NotImplementedError("image dataset must implement `len()`.")

if worker_info is not None:
# split workload
per_worker = int(math.ceil((iter_end - iter_start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = iter_start + worker_id * per_worker
per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers)))
iter_start = iter_start + worker_info.id * per_worker
iter_end = min(iter_start + per_worker, iter_end)

for index in range(iter_start, iter_end):
arrays = self.dataset[index]

iters = [iter_patch(a, self.patch_size, self.start_pos, False, self.mode, **self.pad_opts) for a in arrays]

yield from zip(*iters)
image = self.dataset[index]
if not self.with_coordinates:
for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch
out_patch = (
patch if self.transform is None else apply_transform(self.transform, patch, map_items=False)
)
yield out_patch
else:
for patch, slices, *_ in self.patch_iter(image): # patch_iter to yield at least 2 items: patch, coords
out_patch = (
patch if self.transform is None else apply_transform(self.transform, patch, map_items=False)
)
yield out_patch, slices


class PatchDataset(Dataset):
Expand All @@ -95,8 +180,8 @@ class PatchDataset(Dataset):
from monai.transforms import RandSpatialCropSamples, RandShiftIntensity

# image dataset
images = [np.arange(16, dtype=np.float).reshape(1, 4, 4),
np.arange(16, dtype=np.float).reshape(1, 4, 4)]
images = [np.arange(16, dtype=float).reshape(1, 4, 4),
np.arange(16, dtype=float).reshape(1, 4, 4)]
# image patch sampler
n_samples = 5
sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples,
Expand Down
17 changes: 15 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def iter_patch(
copy_back: bool = True,
mode: Union[NumpyPadMode, str] = NumpyPadMode.WRAP,
**pad_opts: Dict,
) -> Generator[np.ndarray, None, None]:
):
"""
Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`
but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative
Expand All @@ -194,6 +194,15 @@ def iter_patch(
Yields:
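Tuples of `(patch, coords)`. The patches are array data from `arr` which are views into a padded array which can
be modified, if `copy_back` is True these changes will be reflected in `arr` once the iteration completes.
`coords` is a numpy array holding the patch's start/end positions, with the padding offset removed.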

Note:
coordinate format is:

[[1st_dim_start, 1st_dim_end],
[2nd_dim_start, 2nd_dim_end],
...,
[Nth_dim_start, Nth_dim_end]]

"""
# ensure patchSize and startPos are the right length
patch_size_ = get_valid_patch_size(arr.shape, patch_size)
Expand All @@ -210,7 +219,9 @@ def iter_patch(
iter_size = tuple(s + p for s, p in zip(arr.shape, patch_size_))

for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded):
yield arrpad[slices]
# compensate original image padding
coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, patch_size_))
yield arrpad[slices], np.asarray(coords_no_pad) # data and coords (in numpy; works with torch loader)

# copy back data from the padded image if required
if copy_back:
Expand Down Expand Up @@ -411,6 +422,8 @@ def set_rnd(obj, seed: int) -> int:
obj.set_random_state(seed=seed % MAX_SEED)
return seed + 1 # a different seed for the next component
for key in obj.__dict__:
if key.startswith("__"): # skip the private methods
continue
seed = set_rnd(obj.__dict__[key], seed=seed)
return seed

Expand Down
83 changes: 83 additions & 0 deletions tests/test_grid_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np

from monai.data import DataLoader, GridPatchDataset, PatchIter
from monai.transforms import RandShiftIntensity
from monai.utils import set_determinism


def identity_generator(x):
    """Yield an ``(item, index)`` pair for every element of `x`, leaving the elements untouched."""
    # the index takes the place of patch coordinates in the yielded pair
    yield from ((element, position) for position, element in enumerate(x))


class TestGridPatchDataset(unittest.TestCase):
    """Tests for `GridPatchDataset` using a custom patch generator and `PatchIter`."""

    def setUp(self):
        # seed before every test so the expected shifted intensities are reproducible
        set_determinism(seed=1234)

    def tearDown(self):
        # called with None after every test (presumably clears the seeding -- confirm)
        set_determinism(None)

    def test_shape(self):
        """Batches built from a dataset of strings contain every character of the dataset."""
        test_dataset = ["vwxyz", "helloworld", "worldfoobar"]
        patch_ds = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False)
        # no worker processes on Windows, two everywhere else
        n_workers = 2 if sys.platform != "win32" else 0
        loader = DataLoader(patch_ds, batch_size=3, num_workers=n_workers)
        output = ["".join(batch) for batch in loader]
        expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"]
        # batch order depends on the workers, so compare order-independently
        self.assertEqual(sorted(output), sorted(expected))
        self.assertEqual(len("".join(expected)), len("".join(test_dataset)))

    def test_loading_array(self):
        """Grid patches of two 4x4 images have the expected shape, shifted values and coordinates."""
        set_determinism(seed=1234)
        # image dataset: two identical single-channel 4x4 arrays
        images = [np.arange(16, dtype=float).reshape(1, 4, 4) for _ in range(2)]
        # patch-level transform and image-level grid patch generator
        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
        ds = GridPatchDataset(dataset=images, patch_iter=patch_iter, transform=patch_intensity)

        # single-process loading: every batch is shape-checked, the last batch is value-checked
        for batch in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):
            np.testing.assert_equal(tuple(batch[0].shape), (2, 1, 2, 2))
        np.testing.assert_allclose(
            batch[0],
            np.array([[[[1.7413, 2.7413], [5.7413, 6.7413]]], [[[9.1419, 10.1419], [13.1419, 14.1419]]]]),
            rtol=1e-5,
        )
        np.testing.assert_allclose(
            batch[1],
            np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]),
            rtol=1e-5,
        )

        if sys.platform == "win32":
            return

        # multi-process loading: same checks, with different expected intensities
        for batch in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):
            np.testing.assert_equal(tuple(batch[0].shape), (2, 1, 2, 2))
        np.testing.assert_allclose(
            batch[0],
            np.array([[[[2.3944, 3.3944], [6.3944, 7.3944]]], [[[10.6551, 11.6551], [14.6551, 15.6551]]]]),
            rtol=1e-3,
        )
        np.testing.assert_allclose(
            batch[1],
            np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]),
            rtol=1e-5,
        )


# run this module's test cases when the file is executed as a script
if __name__ == "__main__":
    unittest.main()