diff --git a/docs/source/data.rst b/docs/source/data.rst index c95659bc6e..6ed6be9702 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -68,6 +68,12 @@ Generic Interfaces .. autoclass:: ImageDataset :members: :special-members: __getitem__ + +`NPZDictItemDataset` +~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: NPZDictItemDataset + :members: + :special-members: __getitem__ Patch-based dataset ------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 54beb53e3f..2a7647e527 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -17,6 +17,7 @@ CacheNTransDataset, Dataset, LMDBDataset, + NPZDictItemDataset, PersistentDataset, SmartCacheDataset, ZipDataset, @@ -32,7 +33,7 @@ from .png_writer import write_png from .samplers import DistributedSampler, DistributedWeightedRandomSampler from .synthetic import create_test_image_2d, create_test_image_3d -from .thread_buffer import ThreadBuffer +from .thread_buffer import ThreadBuffer, ThreadDataLoader from .utils import ( compute_importance_map, compute_shape_offset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c032e65af6..c10c500bf8 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -19,12 +19,13 @@ from copy import deepcopy from multiprocessing.pool import ThreadPool from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union +import numpy as np import torch from torch.utils.data import Dataset as _TorchDataset -from monai.data.utils import pickle_hashing +from monai.data.utils import first, pickle_hashing from monai.transforms import Compose, Randomizable, Transform, apply_transform from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed, min_version, optional_import @@ -931,3 +932,53 @@ def __getitem__(self, index: int): if isinstance(transform, RandomizableTransform): transform.set_random_state(seed=self._seed) return self.dataset[index] + + +class NPZDictItemDataset(Dataset): + """ + Represents a dataset from a loaded NPZ file. The members of the file to load are named in the keys of `keys` and + stored under the keyed name. All loaded arrays must have the same 0-dimension (batch) size. Items are always dicts + mapping names to an item extracted from the loaded arrays. + + Args: + npzfile: Path to .npz file or stream containing .npz file data + keys: Maps keys to load from file to name to store in dataset + transform: Transform to apply to batch dict + other_keys: secondary data to load from file and store in dict `other_keys`, not returned by __getitem__ + """ + + def __init__( + self, + npzfile: Union[str, IO], + keys: Dict[str, str], + transform: Optional[Callable] = None, + other_keys: Optional[Sequence[str]] = (), + ): + self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM" + self.keys: Dict[str, str] = dict(keys) + dat = np.load(npzfile) + + self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()} + self.length = self.arrays[first(self.keys.values())].shape[0] + + self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys} + + for k, v in self.arrays.items(): + if v.shape[0] != self.length: + raise ValueError( + "All loaded arrays must have the same first dimension " + f"size {self.length}, array `{k}` has size {v.shape[0]}" + ) + + super().__init__([], transform) + + def __len__(self): + return self.length + + def __getitem__(self, index: int): + data = {k: v[index] for k, v in self.arrays.items()} + + if self.transform is not None: + data = apply_transform(self.transform, data) + + return data diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 252fdd6a21..da5f864900 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -13,6 +13,8 @@ from queue import Empty, Full, Queue from threading import Thread +from monai.data import DataLoader, Dataset + class ThreadBuffer: """ @@ -73,3 +75,20 @@ def __iter__(self): pass # queue was empty this time, try again finally: self.stop() # ensure thread completion + + +class ThreadDataLoader(DataLoader): + """ + Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will + iterate over data from the loader as expected however the data is generated on a separate thread. Use this class + where a `DataLoader` instance is required and not just an iterable object. + """ + + def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs): + super().__init__(dataset, num_workers, **kwargs) + + # ThreadBuffer will use the inherited __iter__ instead of the one defined below + self.buffer = ThreadBuffer(super()) + + def __iter__(self): + yield from self.buffer diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 8f73f7f2fd..5669e8a9ee 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -17,7 +17,7 @@ from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice -from .metric_logger import MetricLogger +from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 758276d03d..c749d4bbab 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -10,8 +10,11 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Callable, DefaultDict, List +from enum import Enum +from threading import RLock +from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional +from monai.engines.utils import CommonKeys from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -21,12 +24,43 @@ Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") +def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS): + return output[loss_key].item() + + +class MetricLoggerKeys(Enum): + METRICS = "Metrics" + LOSS = "Loss" + + class MetricLogger: - def __init__(self, loss_transform: Callable = lambda x: x, metric_transform: Callable = lambda x: x) -> None: + """ + Collect per-iteration metrics and loss value from the attached trainer. This will also collect metric values from + a given evaluator object which is expected to perform evaluation at the end of training epochs. This class is + useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and + `load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training. + + Args: + loss_transform: Converts the `output` value from the trainer's state into a loss value + metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value + evaluator: Optional evaluator to consume metric results from at the end of its evaluation run + """ + + def __init__( + self, + loss_transform: Callable = _get_loss_from_output, + metric_transform: Callable = lambda x: x, + evaluator: Optional[Engine] = None, + ) -> None: self.loss_transform = loss_transform self.metric_transform = metric_transform self.loss: List = [] self.metrics: DefaultDict = defaultdict(list) + self.iteration = 0 + self.lock = RLock() + + if evaluator is not None: + self.attach_evaluator(evaluator) def attach(self, engine: Engine) -> None: """ @@ -35,21 +69,46 @@ def attach(self, engine: Engine) -> None: """ engine.add_event_handler(Events.ITERATION_COMPLETED, self) + def attach_evaluator(self, evaluator: Engine) -> None: + """ + Attach event handlers to the given evaluator to log metric values from it. + + Args: + evaluator: Ignite Engine implementing network evaluation + """ + evaluator.add_event_handler(Events.COMPLETED, self.log_metrics) + def __call__(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - self.loss.append(self.loss_transform(engine.state.output)) + with self.lock: + self.iteration = engine.state.iteration + lossval = self.loss_transform(engine.state.output) + + self.loss.append((self.iteration, lossval)) + self.log_metrics(engine) + + def log_metrics(self, engine: Engine) -> None: + """ + Log metrics from the given Engine's state member. + + Args: + engine: Ignite Engine to log from + """ + with self.lock: + for m, v in engine.state.metrics.items(): + v = self.metric_transform(v) + self.metrics[m].append((self.iteration, v)) - for m, v in engine.state.metrics.items(): - v = self.metric_transform(v) - # # metrics may not be added on the first timestep, pad the list if this is the case - # # so that each metric list is the same length as self.loss - # if len(self.metrics[m])==0: - # self.metrics[m].append([v[0]]*len(self.loss)) + def state_dict(self): + return {MetricLoggerKeys.LOSS: self.loss, MetricLoggerKeys.METRICS: self.metrics} - self.metrics[m].append(v) + def load_state_dict(self, state_dict): + self.loss[:] = state_dict[MetricLoggerKeys.LOSS] + self.metrics.clear() + self.metrics.update(state_dict[MetricLoggerKeys.METRICS]) metriclogger = MetricLogger diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 3c1e7efe24..4d272ac6ff 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -31,6 +31,7 @@ UpsampleMode, Weight, ) +from .jupyter_utils import StatusMembers, ThreadContainer from .misc import ( MAX_SEED, ImageMetaKey, diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py new file mode 100644 index 0000000000..a7e712619e --- /dev/null +++ b/monai/utils/jupyter_utils.py @@ -0,0 +1,351 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This set of utility function is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using +Matplotlib produce common plots for metrics and images. +""" + +from enum import Enum +from threading import RLock, Thread +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +# from monai.utils import exact_version, optional_import + +# if TYPE_CHECKING: +# import matplotlib.pyplot as plt +# from ignite.engine import Engine, Events + +# Figure = plt.Figure +# Axes = plt.Axes +# has_matplotlib = True +# else: +# Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") +# Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +# plt, has_matplotlib = optional_import("matplotlib.pyplot") +# Figure, _ = optional_import("matplotlib.pyplot", name="Figure") +# Axes, _ = optional_import("matplotlib.pyplot", name="Axes") + +try: + import matplotlib.pyplot as plt + + has_matplotlib = True +except ImportError: + has_matplotlib = False + +try: + from ignite.engine import Engine, Events + + has_ignite = True +except ImportError: + Engine = object + Events = object + has_ignite = False + +LOSS_NAME = "loss" + + +def plot_metric_graph( + ax, + title: str, + graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, +): + """ + Plot metrics on a single graph with running averages plotted for selected keys. The values in `graphmap` + should be lists of (timepoint, value) pairs as stored in MetricLogger objects. + + Args: + ax: Axes object to plot into + title: graph title + graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs + yscale: scale for y-axis compatible with `Axes.set_yscale` + avg_keys: tuple of keys in `graphmap` to provide running average plots for + window_fraction: what fraction of the graph value length to use as the running average window + """ + from matplotlib.ticker import MaxNLocator + + for n, v in graphmap.items(): + if len(v) > 0: + if isinstance(v[0], (tuple, list)): # values are (x,y) pairs + inds, vals = zip(*v) # separate values into list of indices in X dimension and values + else: + inds, vals = tuple(range(len(v))), tuple(v) # values are without indices, make indices for them + + ax.plot(inds, vals, label=f"{n} = {vals[-1]:.5g}") + + # if requested compute and plot a running average for the values using a fractional window size + if n in avg_keys and len(v) > window_fraction: + window = len(v) // window_fraction + kernel = np.ones((window,)) / window + ra = np.convolve((vals[0],) * (window - 1) + vals, kernel, mode="valid") + + ax.plot(inds, ra, label=f"{n} Avg = {ra[-1]:.5g}") + + ax.set_title(title) + ax.set_yscale(yscale) + ax.axis("on") + ax.legend(bbox_to_anchor=(1, 1), loc=1, borderaxespad=0.0) + ax.grid(True, "both", "both") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + + +def plot_metric_images( + fig, + title: str, + graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], + imagemap: Dict[str, np.ndarray], + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, +) -> List: + """ + Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be + metrics from a training run and the images to be the batch and output from the last iteration. This uses + `plot_metric_graph` to plot the metric graph. + + Args: + fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + title: graph title + graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs + imagemap: dictionary of named images to show with metric plot + yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale` + avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for + window_fraction: for metric plot, what fraction of the graph value length to use as the running average window + + Returns: + list of Axes objects for graph followed by images + """ + gridshape = (4, max(1, len(imagemap))) + + graph = plt.subplot2grid(gridshape, (0, 0), colspan=gridshape[1], fig=fig) + + plot_metric_graph(graph, title, graphmap, yscale, avg_keys, window_fraction) + + axes = [graph] + for i, n in enumerate(imagemap): + im = plt.subplot2grid(gridshape, (1, i), rowspan=2, fig=fig) + + if imagemap[n].shape[0] == 3: + im.imshow(imagemap[n].transpose([1, 2, 0])) + else: + im.imshow(np.squeeze(imagemap[n]), cmap="gray") + + im.set_title("%s\n%.3g -> %.3g" % (n, imagemap[n].min(), imagemap[n].max())) + im.axis("off") + axes.append(im) + + return axes + + +def tensor_to_images(name: str, tensor: torch.Tensor): + """ + Return an tuple of images derived from the given tensor. The `name` value indices which key from the + output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors + instead of dictionaries. Returns a tuple of 2D images of shape HW, or 3D images of shape CHW where C is + color channels RGB or RGBA. This allows multiple images to be created from a single tensor, ie. to show + each channel separately. + """ + if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2: + return tuple(tensor[0].cpu().data.numpy()) + elif tensor.ndim == 5 and tensor.shape[3] > 2 and tensor.shape[4] > 2: + dmid = tensor.shape[2] // 2 + return tuple(tensor[0, :, dmid].cpu().data.numpy()) + + return () + + +def plot_engine_status( + engine: Engine, + logger, + title: str = "Training Log", + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, + image_fn: Optional[Callable] = tensor_to_images, + fig=None, +) -> Tuple: + """ + Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics + taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are + converted to Numpy arrays suitable for input to `Axes.imshow` using `image_fn`, if this is None then no image + plotting is done. + + Args: + engine: Engine to extract images from + logger: MetricLogger to extract loss and metric data from + title: graph title + yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale` + avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for + window_fraction: for metric plot, what fraction of the graph value length to use as the running average window + image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot + fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + + Returns: + Figure object (or `fig` if given), list of Axes objects for graph and images + """ + if fig is not None: + fig.clf() + else: + fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor="white") + + graphmap = {LOSS_NAME: logger.loss} + graphmap.update(logger.metrics) + + imagemap = {} + + if image_fn is not None and engine.state is not None and engine.state.batch is not None: + for src in (engine.state.batch, engine.state.output): + if isinstance(src, dict): + for k, v in src.items(): + images = image_fn(k, v) + + for i, im in enumerate(images): + imagemap[f"{k}_{i}"] = im + else: + label = "Batch" if src is engine.state.batch else "Output" + images = image_fn(label, src) + + for i, im in enumerate(images): + imagemap[f"{label}_{i}"] = im + + axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) + + axes[0].axhline(logger.loss[-1][1], c="k", ls=":") # draw dotted horizontal line at last loss value + + return fig, axes + + +def _get_loss_from_output(output: Union[Dict[str, torch.Tensor], torch.Tensor]) -> float: + """Returns a single value from the network output, which is a dict or tensor.""" + if isinstance(output, dict): + return output["loss"].item() + else: + return output.item() + + +class StatusMembers(Enum): + """ + Named members of the status dictionary, others may be present for named metric values. + """ + + STATUS = "Status" + EPOCHS = "Epochs" + ITERS = "Iters" + LOSS = "Loss" + + +class ThreadContainer(Thread): + """ + Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This + allows an engine to begin a run in the background and allow the starting notebook cell to complete. A + user can thus start a run and then navigate away from the notebook without concern for loosing connection + with the running cell. All output is acquired through methods which synchronize with the running engine + using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented + from starting the next iteration. + + Args: + engine: wrapped `Engine` object, when the container is started its `run` method is called + loss_transform: callable to convert an output dict into a single numeric value + metric_transform: callable to convert a named metric value into a single numeric value + """ + + def __init__( + self, + engine: Engine, + loss_transform: Callable = _get_loss_from_output, + metric_transform: Callable = lambda name, value: value, + ): + super().__init__() + self.lock = RLock() + self.engine = engine + self._status_dict: Dict[str, Any] = {} + self.loss_transform = loss_transform + self.metric_transform = metric_transform + self.fig = None + + self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status) + + def run(self): + """Calls the `run` method of the wrapped engine.""" + self.engine.run() + + def stop(self): + """Stop the engine and join the thread.""" + self.engine.terminate() + self.join() + + def _update_status(self): + """Called as an event, updates the internal status dict at the end of iterations.""" + with self.lock: + state = self.engine.state + stats = { + StatusMembers.EPOCHS.value: 0, + StatusMembers.ITERS.value: 0, + StatusMembers.LOSS.value: float("nan"), + } + + if state is not None: + if state.max_epochs >= 1: + epoch = f"{state.epoch}/{state.max_epochs}" + else: + epoch = str(state.epoch) + + if state.epoch_length is not None: + iters = f"{state.iteration % state.epoch_length}/{state.epoch_length}" + else: + iters = str(state.iteration) + + stats[StatusMembers.EPOCHS.value] = epoch + stats[StatusMembers.ITERS.value] = iters + stats[StatusMembers.LOSS.value] = self.loss_transform(state.output) + + metrics = state.metrics or {} + for m, v in metrics.items(): + v = self.metric_transform(m, v) + if v is not None: + stats[m].append(v) + + self._status_dict.update(stats) + + @property + def status_dict(self) -> Dict[str, str]: + """A dictionary containing status information, current loss, and current metric values.""" + with self.lock: + stats = {StatusMembers.STATUS.value: "Running" if self.is_alive else "Stopped"} + stats.update(self._status_dict) + return stats + + def status(self) -> str: + """Returns a status string for the current state of the engine.""" + stats = self.status_dict + + msgs = [stats.pop(StatusMembers.STATUS.value), "Iters: " + str(stats.pop(StatusMembers.ITERS.value))] + msgs += ["%s: %s" % kv for kv in stats.items()] + + return ", ".join(msgs) + + def plot_status(self, logger, plot_func: Callable = plot_engine_status): + """ + Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`. + The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title, + `self.engine`, `logger`, and `self.fig` respectively. The return value must be a figure object (stored in + `self.fig`) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method, + which holds the internal lock during the plot generation. + """ + with self.lock: + self.fig, axes = plot_func(title=self.status(), engine=self.engine, logger=logger, fig=self.fig) + return self.fig diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py new file mode 100644 index 0000000000..5ec52f45a2 --- /dev/null +++ b/tests/test_npzdictitemdataset.py @@ -0,0 +1,55 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from io import BytesIO + +import numpy as np + +from monai.data import NPZDictItemDataset + + +class TestNPZDictItemDataset(unittest.TestCase): + def test_load_stream(self): + dat0 = np.random.rand(10, 1, 4, 4) + dat1 = np.random.rand(10, 1, 4, 4) + + npzfile = BytesIO() + npz = np.savez_compressed(npzfile, dat0=dat0, dat1=dat1) + npzfile.seek(0) + + npzds = NPZDictItemDataset(npzfile, {"dat0": "images", "dat1": "seg"}) + + item = npzds[0] + + np.testing.assert_allclose(item["images"].shape, (1, 4, 4)) + np.testing.assert_allclose(item["seg"].shape, (1, 4, 4)) + + def test_load_file(self): + dat0 = np.random.rand(10, 1, 4, 4) + dat1 = np.random.rand(10, 1, 4, 4) + + with tempfile.TemporaryDirectory() as tempdir: + npzfile = f"{tempdir}/test.npz" + + npz = np.savez_compressed(npzfile, dat0=dat0, dat1=dat1) + + npzds = NPZDictItemDataset(npzfile, {"dat0": "images", "dat1": "seg"}) + + item = npzds[0] + + np.testing.assert_allclose(item["images"].shape, (1, 4, 4)) + np.testing.assert_allclose(item["seg"].shape, (1, 4, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index d139b44c85..1b3ebb910d 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -13,7 +13,7 @@ import time import unittest -from monai.data import DataLoader, Dataset, ThreadBuffer +from monai.data import DataLoader, Dataset, ThreadBuffer, ThreadDataLoader from monai.transforms import Compose, SimulateDelayd from monai.utils import PerfContext @@ -41,6 +41,16 @@ def test_values(self): self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") self.assertEqual(d["label"][1], "spleen_label_31.nii.gz") + def test_dataloader(self): + dataset = Dataset(data=self.datalist, transform=self.transform) + dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0) + + for d in dataloader: + self.assertEqual(d["image"][0], "spleen_19.nii.gz") + self.assertEqual(d["image"][1], "spleen_31.nii.gz") + self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") + self.assertEqual(d["label"][1], "spleen_label_31.nii.gz") + def test_time(self): dataset = Dataset(data=self.datalist * 2, transform=self.transform) # contains data for 2 batches dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py new file mode 100644 index 0000000000..92a50a15aa --- /dev/null +++ b/tests/test_threadcontainer.py @@ -0,0 +1,58 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +import torch + +from monai.utils import optional_import + +try: + _, has_ignite = optional_import("ignite") + + from monai.engines import CommonKeys, SupervisedTrainer + from monai.utils import ThreadContainer +except ImportError: + has_ignite = False + +from monai.data import DataLoader + + +class TestThreadContainer(unittest.TestCase): + @unittest.skipIf(not has_ignite, "Ignite needed for this test") + def test_container(self): + net = torch.nn.Conv2d(1, 1, 3, padding=1) + + opt = torch.optim.Adam(net.parameters()) + + img = torch.rand(1, 16, 16) + data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img} + loader = DataLoader([data for _ in range(10)]) + + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=loader, + network=net, + optimizer=opt, + loss_function=torch.nn.L1Loss(), + ) + + con = ThreadContainer(trainer) + con.start() + time.sleep(1) # wait for trainer to start + + self.assertTrue(con.is_alive) + self.assertIsNotNone(con.status()) + self.assertTrue(len(con.status_dict) > 0) + + con.join()