Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions monai/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,54 @@


class DataLoader(_TorchDataLoader):
"""Generates images/labels for train/validation/testing from dataset.
It inherits from PyTorch DataLoader and adds default callbacks for `collate`
and `worker_fn` if the user doesn't set them.
"""
Provides an iterable over the given `dataset`. It inherits from the PyTorch
DataLoader and adds an enhanced `collate_fn` and `worker_init_fn` by default.

Although this class could be configured to be the same as
`torch.utils.data.DataLoader`, its default configuration is
recommended, mainly for the following extra features:

For more information about PyTorch DataLoader, please check:
- It handles MONAI randomizable objects with appropriate random state
management for deterministic behaviour.
- It is aware of samples generated by patch-based transforms (such as
:py:class:`monai.transforms.RandSpatialCropSamplesDict`) during
preprocessing, and collates them with enhanced data collating behaviour.
See: :py:class:`monai.transforms.Compose`.

For more details about :py:class:`torch.utils.data.DataLoader`, please see:
https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

For example, to construct a randomized dataset and iterate with the data loader:

.. code-block:: python

import torch

from monai.data import DataLoader
from monai.transforms import Randomizable


class RandomDataset(torch.utils.data.Dataset, Randomizable):
def __getitem__(self, index):
return self.R.randint(0, 1000, (1,))

def __len__(self):
return 16


dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for epoch in range(2):
for i, batch in enumerate(dataloader):
print(epoch, i, batch.data.numpy().flatten().tolist())

Args:
dataset: dataset from which to load the data.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
kwargs: other parameters for PyTorch DataLoader.

"""

def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
Expand Down
5 changes: 4 additions & 1 deletion monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class PadListDataCollate(InvertibleTransform):
pass the inverse through multiprocessing.

Args:
batch: batch of data to pad-collate
method: padding method (see :py:class:`monai.transforms.SpatialPad`)
mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
"""
Expand All @@ -72,6 +71,10 @@ def __init__(
self.mode = mode

def __call__(self, batch: Any):
"""
Args:
batch: batch of data to pad-collate
"""
# data is either list of dicts or list of lists
is_list_of_dicts = isinstance(batch[0], dict)
# loop over items inside of each element in a batch
Expand Down
8 changes: 6 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ def _log_stats(data, prefix: Optional[str] = "Data"):

class Randomizable(ABC):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
An interface for handling random state locally, currently based on a class
variable `R`, which is an instance of `np.random.RandomState`. This
provides the flexibility of component-specific determinism without
affecting the global states. It is recommended to use this API with
:py:class:`monai.data.DataLoader` for deterministic behaviour of the
preprocessing pipelines.
"""

R: np.random.RandomState = np.random.RandomState()
Expand Down
32 changes: 31 additions & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import Compose, DataStatsd, SimulateDelayd
from monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd
from monai.utils import set_determinism

TEST_CASE_1 = [
[
Expand Down Expand Up @@ -64,5 +65,34 @@ def test_exception(self, datalist):
pass


class _RandomDataset(torch.utils.data.Dataset, Randomizable):
    """Fixed-length dataset whose items are drawn from the random state ``self.R``."""

    def __len__(self):
        """Return the constant dataset length."""
        return 8

    def __getitem__(self, index):
        """Return a length-1 integer array drawn from ``[0, 1000)``; ``index`` is unused."""
        sample = self.R.randint(0, 1000, size=(1,))
        return sample


class TestLoaderRandom(unittest.TestCase):
    """
    Checks that the data loader yields a reproducible sequence when it is
    driven by a dataset implementing the randomizable interface.
    """

    def setUp(self):
        # Turn determinism on (argument 0) before every test.
        set_determinism(0)

    def tearDown(self):
        # Undo the setting applied in setUp so other tests are unaffected.
        set_determinism(None)

    def test_randomize(self):
        """Iterate the loader twice and compare all values with the reference list."""
        expected = [594, 170, 524, 778, 370, 906, 292, 589, 762, 763, 156, 886, 42, 405, 221, 166]
        source = _RandomDataset()
        loader = DataLoader(source, batch_size=2, num_workers=3)
        # Two full passes over the loader, flattening every batch in order.
        collected = [
            value
            for _pass in range(2)
            for batch in loader
            for value in batch.data.numpy().flatten().tolist()
        ]
        self.assertListEqual(collected, expected)


if __name__ == "__main__":
unittest.main()