From 690ceb0f76ed80a698f802302bde8127e96e8fcd Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 16 Mar 2021 13:15:14 +0000 Subject: [PATCH 01/11] Jupyter and other additions Signed-off-by: Eric Kerfoot --- monai/data/__init__.py | 2 +- monai/data/thread_buffer.py | 18 +++++ monai/handlers/__init__.py | 2 +- monai/handlers/metric_logger.py | 78 ++++++++++++++++--- monai/utils/__init__.py | 1 + monai/utils/jupyter_utils.py | 128 ++++++++++++++++++++++++++++++++ 6 files changed, 217 insertions(+), 12 deletions(-) create mode 100644 monai/utils/jupyter_utils.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 3dd0a980ef..1e2dbd2de0 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -31,7 +31,7 @@ from .png_saver import PNGSaver from .png_writer import write_png from .synthetic import create_test_image_2d, create_test_image_3d -from .thread_buffer import ThreadBuffer +from .thread_buffer import ThreadBuffer, ThreadDataLoader from .utils import ( DistributedSampler, compute_importance_map, diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 252fdd6a21..421dba3040 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -13,6 +13,8 @@ from queue import Empty, Full, Queue from threading import Thread +from monai.data import DataLoader, Dataset + class ThreadBuffer: """ @@ -73,3 +75,19 @@ def __iter__(self): pass # queue was empty this time, try again finally: self.stop() # ensure thread completion + + +class ThreadDataLoader(DataLoader): + """ + Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will + iterate over data from the loader as expected however the data is generated on a separate thread. Use this class + where a `DataLoader` instance is required and not just an iterable object. + """ + + def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs): + super().__init__(dataset, num_workers, **kwargs) + + # ThreadBuffer will use the inherited __iter__ instead of the one defined below + self.buffer = ThreadBuffer(super()) + self.__iter__ = self.buffer.__iter__ + \ No newline at end of file diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 8f73f7f2fd..5669e8a9ee 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -17,7 +17,7 @@ from .iteration_metric import IterationMetric from .lr_schedule_handler import LrScheduleHandler from .mean_dice import MeanDice -from .metric_logger import MetricLogger +from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index fdd60da57c..26329fba22 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -10,7 +10,9 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Callable, DefaultDict, List +from enum import Enum +from threading import RLock +from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional from monai.utils import exact_version, optional_import @@ -21,12 +23,43 @@ Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") +def _get_loss_from_output(output): + return output["loss"].item() + + +class MetricLoggerKeys(Enum): + METRICS = "Metrics" + LOSS = "Loss" + + class MetricLogger: - def __init__(self, loss_transform: Callable = lambda x: x, metric_transform: Callable = lambda x: x) -> None: + """ + Collect per-iteration metrics and loss value from the attached trainer. This will also collect metric values from + a given evaluator object which is expected to perform evaluation at the end of training epochs. This class is + useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and + `load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training. + + Args: + loss_transform: Converts the `output` value from the trainer's state into a loss value + metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value + evaluator: Optional evaluator to consume metric results from at the end of its evaluation run + """ + + def __init__( + self, + loss_transform: Callable = _get_loss_from_output, + metric_transform: Callable = lambda x: x, + evaluator: Optional[Engine] = None, + ) -> None: self.loss_transform = loss_transform self.metric_transform = metric_transform self.loss: List = [] self.metrics: DefaultDict = defaultdict(list) + self.iteration = 0 + self.lock = RLock() + + if evaluator is not None: + self.attach_evaluator(evaluator) def attach(self, engine: Engine) -> None: """ @@ -35,21 +68,46 @@ def attach(self, engine: Engine) -> None: """ engine.add_event_handler(Events.ITERATION_COMPLETED, self) + def attach_evaluator(self, evaluator: Engine) -> None: + """ + Attach event handlers to the given evaluator to log metric values from it. + + Args: + evaluator: Ignite Engine implementing network evaluation + """ + evaluator.add_event_handler(Events.COMPLETED, self.log_metrics) + def __call__(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - self.loss.append(self.loss_transform(engine.state.output)) + with self.lock: + self.iteration = engine.state.iteration + lossval = self.loss_transform(engine.state.output) + + self.loss.append((self.iteration, lossval)) + self.log_metrics(engine) + + def log_metrics(self, engine: Engine) -> None: + """ + Log metrics from the given Engine's state member. + + Args: + engine: Ignite Engine to log from + """ + with self.lock: + for m, v in engine.state.metrics.items(): + v = self.metric_transform(v) + self.metrics[m].append((self.iteration, v)) - for m, v in engine.state.metrics.items(): - v = self.metric_transform(v) - # # metrics may not be added on the first timestep, pad the list if this is the case - # # so that each metric list is the same length as self.loss - # if len(self.metrics[m])==0: - # self.metrics[m].append([v[0]]*len(self.loss)) + def state_dict(self): + return {MetricLoggerKeys.LOSS.value: self.loss, MetricLoggerKeys.METRICS: self.metrics.value} - self.metrics[m].append(v) + def load_state_dict(self, state_dict): + self.loss[:] = state_dict[MetricLoggerKeys.LOSS.value] + self.metrics.clear() + self.metrics.update(state_dict[MetricLoggerKeys.METRICS.value]) metriclogger = MetricLogger diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 1e17d44029..3a0e225c6f 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -67,3 +67,4 @@ ) from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end from .state_cacher import StateCacher +from .jupyter_utils import StatusMembers, ThreadContainer diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py new file mode 100644 index 0000000000..50431ae349 --- /dev/null +++ b/monai/utils/jupyter_utils.py @@ -0,0 +1,128 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from threading import RLock, Thread +from typing import TYPE_CHECKING, Callable, Dict + +from monai.utils import exact_version, optional_import + + +if TYPE_CHECKING: + from ignite.engine import Engine, Events + import matplotlib.pyplot as plt + + has_matplotlib = True +else: + Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") + Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") + plt, has_matplotlib = optional_import("matplotlib.pyplot") + + +def _get_loss_from_output(output): + return output["loss"].item() + + +class StatusMembers(Enum): + """ + Named members of the status dictionary, others may be present for named metric values. + """ + + STATUS = "Status" + EPOCHS = "Epochs" + ITERS = "Iters" + LOSS = "Loss" + + +class ThreadContainer(Thread): + """ + Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This + allows an engine to begin a run in the background and allow the starting notebook cell to complete. A + user can thus start a run and then navigate away from the notebook without concern for loosing connection + with the running cell. All output is acquired through methods which synchronize with the running engine + using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented + from starting the next iteration. + + Args: + engine: wrapped `Engine` object, when the container is started its `run` method is called + loss_transform: callable to convert an output dict into a single numeric value + metric_transform: callable to convert a metric value into a single numeric value + """ + + def __init__( + self, engine: Engine, loss_transform: Callable = _get_loss_from_output, metric_transform: Callable = lambda x: x + ): + super().__init__() + self.lock = RLock() + self.engine = engine + self._status_dict = {} + self.loss_transform = loss_transform + self.metric_transform = metric_transform + + self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status) + + def run(self): + """Calls the `run` method of the wrapped engine.""" + self.engine.run() + + def stop(self): + """Stop the engine and join the thread.""" + self.engine.terminate() + self.join() + + def is_running(self) -> bool: + """Returns True if the thread is still alive and the engine is still running.""" + return self.is_alive() + + def _update_status(self): + """Called as an event, updates the internal status dict at the end of iterations.""" + with self.lock: + state = self.engine.state + + if state is not None: + if state.max_epochs > 1: + epoch = f"{state.epoch}/{state.max_epochs}" + else: + epoch = str(state.epoch) + + if state.epoch_length is not None: + iters = f"{state.iteration % state.epoch_length}/{state.epoch_length}" + else: + iters = str(state.iteration) + + stats[StatusMembers.EPOCHS.value] = epoch + stats[StatusMembers.ITERS.value] = iters + stats[StatusMembers.LOSS.value] = self.loss_transform(state.output) + + metrics = state.metrics or {} + for m, v in metrics.items(): + v = self.metric_transform(v) + if v is not None: + stats[m].append(v) + + self._status_dict.update(stats) + + @property + def status_dict(self) -> Dict[str, str]: + """A dictionary containing status information, current loss, and current metric values.""" + with self.lock: + stats = {StatusMembers.STATUS.value: "Running" if self.is_running() else "Stopped"} + stats.update(self._status_dict) + return stats + + def status(self) -> str: + """Returns a status string for the current state of the engine.""" + stats = self.status_dict + + msgs = [stats.pop(StatusMembers.STATUS.value), "Iters: " + str(stats.pop(StatusMembers.ITERS.value))] + msgs += ["%s: %s" % kv for kv in stats.items()] + + return ", ".join(msgs) From 2911f2b319fc4cf6c8001bf74f1d0f934a907625 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 20:18:23 +0000 Subject: [PATCH 02/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- monai/data/__init__.py | 1 + monai/data/dataset.py | 55 ++++++- monai/data/thread_buffer.py | 7 +- monai/handlers/metric_logger.py | 14 +- monai/utils/__init__.py | 2 +- monai/utils/jupyter_utils.py | 259 ++++++++++++++++++++++++++++--- tests/test_npzdictitemdataset.py | 55 +++++++ tests/test_thread_buffer.py | 12 +- tests/test_threadcontainer.py | 49 ++++++ 9 files changed, 418 insertions(+), 36 deletions(-) create mode 100644 tests/test_npzdictitemdataset.py create mode 100644 tests/test_threadcontainer.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 6687f4d8be..2a7647e527 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -17,6 +17,7 @@ CacheNTransDataset, Dataset, LMDBDataset, + NPZDictItemDataset, PersistentDataset, SmartCacheDataset, ZipDataset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c032e65af6..c10c500bf8 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -19,12 +19,13 @@ from copy import deepcopy from multiprocessing.pool import ThreadPool from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union +import numpy as np import torch from torch.utils.data import Dataset as _TorchDataset -from monai.data.utils import pickle_hashing +from monai.data.utils import first, pickle_hashing from monai.transforms import Compose, Randomizable, Transform, apply_transform from monai.transforms.transform import RandomizableTransform from monai.utils import MAX_SEED, get_seed, min_version, optional_import @@ -931,3 +932,53 @@ def __getitem__(self, index: int): if isinstance(transform, RandomizableTransform): transform.set_random_state(seed=self._seed) return self.dataset[index] + + +class NPZDictItemDataset(Dataset): + """ + Represents a dataset from a loaded NPZ file. The members of the file to load are named in the keys of `keys` and + stored under the keyed name. All loaded arrays must have the same 0-dimension (batch) size. Items are always dicts + mapping names to an item extracted from the loaded arrays. + + Args: + npzfile: Path to .npz file or stream containing .npz file data + keys: Maps keys to load from file to name to store in dataset + transform: Transform to apply to batch dict + other_keys: secondary data to load from file and store in dict `other_keys`, not returned by __getitem__ + """ + + def __init__( + self, + npzfile: Union[str, IO], + keys: Dict[str, str], + transform: Optional[Callable] = None, + other_keys: Optional[Sequence[str]] = (), + ): + self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM" + self.keys: Dict[str, str] = dict(keys) + dat = np.load(npzfile) + + self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()} + self.length = self.arrays[first(self.keys.values())].shape[0] + + self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys} + + for k, v in self.arrays.items(): + if v.shape[0] != self.length: + raise ValueError( + "All loaded arrays must have the same first dimension " + f"size {self.length}, array `{k}` has size {v.shape[0]}" + ) + + super().__init__([], transform) + + def __len__(self): + return self.length + + def __getitem__(self, index: int): + data = {k: v[index] for k, v in self.arrays.items()} + + if self.transform is not None: + data = apply_transform(self.transform, data) + + return data diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 421dba3040..da5f864900 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -88,6 +88,7 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs): super().__init__(dataset, num_workers, **kwargs) # ThreadBuffer will use the inherited __iter__ instead of the one defined below - self.buffer = ThreadBuffer(super()) - self.__iter__ = self.buffer.__iter__ - \ No newline at end of file + self.buffer = ThreadBuffer(super()) + + def __iter__(self): + yield from self.buffer diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index f2abc87605..43a06722b2 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -35,10 +35,10 @@ class MetricLoggerKeys(Enum): class MetricLogger: """ Collect per-iteration metrics and loss value from the attached trainer. This will also collect metric values from - a given evaluator object which is expected to perform evaluation at the end of training epochs. This class is + a given evaluator object which is expected to perform evaluation at the end of training epochs. This class is useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and `load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training. - + Args: loss_transform: Converts the `output` value from the trainer's state into a loss value metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value @@ -71,7 +71,7 @@ def attach(self, engine: Engine) -> None: def attach_evaluator(self, evaluator: Engine) -> None: """ Attach event handlers to the given evaluator to log metric values from it. - + Args: evaluator: Ignite Engine implementing network evaluation """ @@ -92,7 +92,7 @@ def __call__(self, engine: Engine) -> None: def log_metrics(self, engine: Engine) -> None: """ Log metrics from the given Engine's state member. - + Args: engine: Ignite Engine to log from """ @@ -102,12 +102,12 @@ def log_metrics(self, engine: Engine) -> None: self.metrics[m].append((self.iteration, v)) def state_dict(self): - return {MetricLoggerKeys.LOSS.value: self.loss, MetricLoggerKeys.METRICS: self.metrics.value} + return {MetricLoggerKeys.LOSS: self.loss, MetricLoggerKeys.METRICS: self.metrics} def load_state_dict(self, state_dict): - self.loss[:] = state_dict[MetricLoggerKeys.LOSS.value] + self.loss[:] = state_dict[MetricLoggerKeys.LOSS] self.metrics.clear() - self.metrics.update(state_dict[MetricLoggerKeys.METRICS.value]) + self.metrics.update(state_dict[MetricLoggerKeys.METRICS]) metriclogger = MetricLogger diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c17f0bf532..4d272ac6ff 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -31,6 +31,7 @@ UpsampleMode, Weight, ) +from .jupyter_utils import StatusMembers, ThreadContainer from .misc import ( MAX_SEED, ImageMetaKey, @@ -68,4 +69,3 @@ ) from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end from .state_cacher import StateCacher -from .jupyter_utils import StatusMembers, ThreadContainer diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 50431ae349..ceecf11af4 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -9,26 +9,224 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This set of utility function is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using +Matplotlib produce common plots for metrics and images. +""" + from enum import Enum from threading import RLock, Thread -from typing import TYPE_CHECKING, Callable, Dict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from monai.utils import exact_version, optional_import +import numpy as np +import torch +# from monai.utils import exact_version, optional_import -if TYPE_CHECKING: - from ignite.engine import Engine, Events +# if TYPE_CHECKING: +# import matplotlib.pyplot as plt +# from ignite.engine import Engine, Events + +# Figure = plt.Figure +# Axes = plt.Axes +# has_matplotlib = True +# else: +# Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") +# Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +# plt, has_matplotlib = optional_import("matplotlib.pyplot") +# Figure, _ = optional_import("matplotlib.pyplot", name="Figure") +# Axes, _ = optional_import("matplotlib.pyplot", name="Axes") + +try: import matplotlib.pyplot as plt + from ignite.engine import Engine, Events has_matplotlib = True -else: - Engine, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Engine") - Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") - plt, has_matplotlib = optional_import("matplotlib.pyplot") +except ImportError: + has_matplotlib = False + +LOSS_NAME = "loss" + + +def plot_metric_graph( + ax, + title: str, + graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, +): + """ + Plot metrics on a single graph with running averages plotted for selected keys. The values in `graphmap` + should be lists of (timepoint, value) pairs as stored in MetricLogger objects. + + Args: + ax: Axes object to plot into + title: graph title + graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs + yscale: scale for y-axis compatible with `Axes.set_yscale` + avg_keys: tuple of keys in `graphmap` to provide running average plots for + window_fraction: what fraction of the graph value length to use as the running average window + """ + from matplotlib.ticker import MaxNLocator + + for n, v in graphmap.items(): + if len(v) > 0: + if isinstance(v[0], (tuple, list)): # values are (x,y) pairs + inds, vals = zip(*v) # separate values into list of indices in X dimension and values + else: + inds, vals = tuple(range(len(v))), tuple(v) # values are without indices, make indices for them + + ax.plot(inds, vals, label=f"{n} = {vals[-1]:.5g}") + + # if requested compute and plot a running average for the values using a fractional window size + if n in avg_keys and len(v) > window_fraction: + window = len(v) // window_fraction + kernel = np.ones((window,)) / window + ra = np.convolve((vals[0],) * (window - 1) + vals, kernel, mode="valid") + + ax.plot(inds, ra, label=f"{n} Avg = {ra[-1]:.5g}") + + ax.set_title(title) + ax.set_yscale(yscale) + ax.axis("on") + ax.legend(bbox_to_anchor=(1, 1), loc=1, borderaxespad=0.0) + ax.grid(True, "both", "both") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + + +def plot_metric_images( + fig, + title: str, + graphmap: Dict[str, Union[List[float], Tuple[List[float], List[float]]]], + imagemap: Dict[str, np.ndarray], + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, +) -> List: + """ + Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be + metrics from a training run and the images to be the batch and output from the last iteration. This uses + `plot_metric_graph` to plot the metric graph. + + Args: + fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + title: graph title + graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs + imagemap: dictionary of named images to show with metric plot + yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale` + avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for + window_fraction: for metric plot, what fraction of the graph value length to use as the running average window + + Returns: + list of Axes objects for graph followed by images + """ + gridshape = (4, max(1, len(imagemap))) + + graph = plt.subplot2grid(gridshape, (0, 0), colspan=gridshape[1], fig=fig) + + plot_metric_graph(graph, title, graphmap, yscale, avg_keys, window_fraction) + + axes = [graph] + for i, n in enumerate(imagemap): + im = plt.subplot2grid(gridshape, (1, i), rowspan=2, fig=fig) + if imagemap[n].shape[0] == 3: + im.imshow(imagemap[n].transpose([1, 2, 0])) + else: + im.imshow(np.squeeze(imagemap[n]), cmap="gray") -def _get_loss_from_output(output): - return output["loss"].item() + im.set_title("%s\n%.3g -> %.3g" % (n, imagemap[n].min(), imagemap[n].max())) + im.axis("off") + axes.append(im) + + return axes + + +def tensor_to_images(name: str, tensor: torch.Tensor): + """ + Return an tuple of images derived from the given tensor. The `name` value indices which key from the + output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors + instead of dictionaries. Returns a tuple of 2D images of shape HW, or 3D images of shape CHW where C is + color channels RGB or RGBA. This allows multiple images to be created from a single tensor, ie. to show + each channel separately. + """ + if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2: + return tuple(tensor[0].cpu().data.numpy()) + elif tensor.ndim == 5 and tensor.shape[3] > 2 and tensor.shape[4] > 2: + dmid = tensor.shape[2] // 2 + return tuple(tensor[0, :, dmid].cpu().data.numpy()) + + return () + + +def plot_engine_status( + engine: Engine, + logger, + title: str = "Training Log", + yscale: str = "log", + avg_keys: Tuple[str] = (LOSS_NAME,), + window_fraction: int = 20, + image_fn: Optional[Callable] = tensor_to_images, + fig=None, +) -> Tuple: + """ + Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics + taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are + converted to Numpy arrays suitable for input to `Axes.imshow` using `image_fn`, if this is None then no image + plotting is done. + + Args: + engine: Engine to extract images from + logger: MetricLogger to extract loss and metric data from + title: graph title + yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale` + avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for + window_fraction: for metric plot, what fraction of the graph value length to use as the running average window + image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot + fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing + + Returns: + Figure object (or `fig` if given), list of Axes objects for graph and images + """ + if fig is not None: + fig.clf() + else: + fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor="white") + + graphmap = {LOSS_NAME: logger.loss} + graphmap.update(logger.metrics) + + imagemap = {} + + if image_fn is not None and engine.state is not None and engine.state.batch is not None: + for src in (engine.state.batch, engine.state.output): + if isinstance(src, dict): + for k, v in src.items(): + images = image_fn(k, v) + + for i, im in enumerate(images): + imagemap[f"{k}_{i}"] = im + else: + label = "Batch" if src is engine.state.batch else "Output" + images = image_fn(label, src) + + for i, im in enumerate(images): + imagemap[f"{label}_{i}"] = im + + axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) + + axes[0].axhline(logger.loss[-1][1], c="k", ls=":") # draw dotted horizontal line at last loss value + + return fig, axes + + +def _get_loss_from_output(output: Union[Dict[str, torch.Tensor], torch.Tensor]) -> float: + """Returns a single value from the network output, which is a dict or tensor.""" + if isinstance(output, dict): + return output["loss"].item() + else: + return output.item() class StatusMembers(Enum): @@ -44,28 +242,32 @@ class StatusMembers(Enum): class ThreadContainer(Thread): """ - Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This + Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This allows an engine to begin a run in the background and allow the starting notebook cell to complete. A user can thus start a run and then navigate away from the notebook without concern for loosing connection with the running cell. All output is acquired through methods which synchronize with the running engine using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented from starting the next iteration. - + Args: engine: wrapped `Engine` object, when the container is started its `run` method is called loss_transform: callable to convert an output dict into a single numeric value - metric_transform: callable to convert a metric value into a single numeric value + metric_transform: callable to convert a named metric value into a single numeric value """ def __init__( - self, engine: Engine, loss_transform: Callable = _get_loss_from_output, metric_transform: Callable = lambda x: x + self, + engine: Engine, + loss_transform: Callable = _get_loss_from_output, + metric_transform: Callable = lambda name, value: value, ): super().__init__() self.lock = RLock() self.engine = engine - self._status_dict = {} + self._status_dict: Dict[str, Any] = {} self.loss_transform = loss_transform self.metric_transform = metric_transform + self.fig = None self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status) @@ -78,17 +280,18 @@ def stop(self): self.engine.terminate() self.join() - def is_running(self) -> bool: - """Returns True if the thread is still alive and the engine is still running.""" - return self.is_alive() - def _update_status(self): """Called as an event, updates the internal status dict at the end of iterations.""" with self.lock: state = self.engine.state + stats = { + StatusMembers.EPOCHS.value: 0, + StatusMembers.ITERS.value: 0, + StatusMembers.LOSS.value: float("nan"), + } if state is not None: - if state.max_epochs > 1: + if state.max_epochs >= 1: epoch = f"{state.epoch}/{state.max_epochs}" else: epoch = str(state.epoch) @@ -104,7 +307,7 @@ def _update_status(self): metrics = state.metrics or {} for m, v in metrics.items(): - v = self.metric_transform(v) + v = self.metric_transform(m, v) if v is not None: stats[m].append(v) @@ -114,7 +317,7 @@ def _update_status(self): def status_dict(self) -> Dict[str, str]: """A dictionary containing status information, current loss, and current metric values.""" with self.lock: - stats = {StatusMembers.STATUS.value: "Running" if self.is_running() else "Stopped"} + stats = {StatusMembers.STATUS.value: "Running" if self.is_alive else "Stopped"} stats.update(self._status_dict) return stats @@ -126,3 +329,15 @@ def status(self) -> str: msgs += ["%s: %s" % kv for kv in stats.items()] return ", ".join(msgs) + + def plot_status(self, logger, plot_func: Callable = plot_engine_status): + """ + Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`. + The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title, + `self.engine`, `logger`, and `self.fig` respectively. The return value must be a figure object (stored in + `self.fig`) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method, + which holds the internal lock during the plot generation. + """ + with self.lock: + self.fig, axes = plot_func(title=self.status(), engine=self.engine, logger=logger, fig=self.fig) + return self.fig diff --git a/tests/test_npzdictitemdataset.py b/tests/test_npzdictitemdataset.py new file mode 100644 index 0000000000..5ec52f45a2 --- /dev/null +++ b/tests/test_npzdictitemdataset.py @@ -0,0 +1,55 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from io import BytesIO + +import numpy as np + +from monai.data import NPZDictItemDataset + + +class TestNPZDictItemDataset(unittest.TestCase): + def test_load_stream(self): + dat0 = np.random.rand(10, 1, 4, 4) + dat1 = np.random.rand(10, 1, 4, 4) + + npzfile = BytesIO() + npz = np.savez_compressed(npzfile, dat0=dat0, dat1=dat1) + npzfile.seek(0) + + npzds = NPZDictItemDataset(npzfile, {"dat0": "images", "dat1": "seg"}) + + item = npzds[0] + + np.testing.assert_allclose(item["images"].shape, (1, 4, 4)) + np.testing.assert_allclose(item["seg"].shape, (1, 4, 4)) + + def test_load_file(self): + dat0 = np.random.rand(10, 1, 4, 4) + dat1 = np.random.rand(10, 1, 4, 4) + + with tempfile.TemporaryDirectory() as tempdir: + npzfile = f"{tempdir}/test.npz" + + npz = np.savez_compressed(npzfile, dat0=dat0, dat1=dat1) + + npzds = NPZDictItemDataset(npzfile, {"dat0": "images", "dat1": "seg"}) + + item = npzds[0] + + np.testing.assert_allclose(item["images"].shape, (1, 4, 4)) + np.testing.assert_allclose(item["seg"].shape, (1, 4, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index d139b44c85..1b3ebb910d 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -13,7 +13,7 @@ import time import unittest -from monai.data import DataLoader, Dataset, ThreadBuffer +from monai.data import DataLoader, Dataset, ThreadBuffer, ThreadDataLoader from monai.transforms import Compose, SimulateDelayd from monai.utils import PerfContext @@ -41,6 +41,16 @@ def test_values(self): self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") self.assertEqual(d["label"][1], "spleen_label_31.nii.gz") + def test_dataloader(self): + dataset = Dataset(data=self.datalist, transform=self.transform) + dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0) + + for d in dataloader: + self.assertEqual(d["image"][0], "spleen_19.nii.gz") + self.assertEqual(d["image"][1], "spleen_31.nii.gz") + self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") + self.assertEqual(d["label"][1], "spleen_label_31.nii.gz") + def test_time(self): dataset = Dataset(data=self.datalist * 2, transform=self.transform) # contains data for 2 batches dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py new file mode 100644 index 0000000000..2e04ca20f1 --- /dev/null +++ b/tests/test_threadcontainer.py @@ -0,0 +1,49 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +import torch + +from monai.data import DataLoader +from monai.engines import CommonKeys, SupervisedTrainer +from monai.utils import ThreadContainer + + +class TestThreadContainer(unittest.TestCase): + def test_container(self): + net = torch.nn.Conv2d(1, 1, 3, padding=1) + + opt = torch.optim.Adam(net.parameters()) + + img = torch.rand(1, 16, 16) + data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img} + loader = DataLoader([data for _ in range(10)]) + + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=loader, + network=net, + optimizer=opt, + loss_function=torch.nn.L1Loss(), + ) + + con = ThreadContainer(trainer) + con.start() + time.sleep(1) # wait for trainer to start + + self.assertTrue(con.is_alive) + self.assertIsNotNone(con.status()) + self.assertTrue(len(con.status_dict) > 0) + + con.join() From 7787978ad122a2762f7b27790d4897fef271f1cc Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 20:25:17 +0000 Subject: [PATCH 03/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- monai/utils/jupyter_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index ceecf11af4..48438e7f44 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -39,11 +39,16 @@ try: import matplotlib.pyplot as plt - from ignite.engine import Engine, Events has_matplotlib = True except ImportError: has_matplotlib = False + +try: + from ignite.engine import Engine, Events + has_ignite = True +except ImportError: + has_ignite = False LOSS_NAME = "loss" From 1679397dc57c09cfe8a3a9441f5ee1e0eaede7dc Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 20:29:18 +0000 Subject: [PATCH 04/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- docs/source/data.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/data.rst b/docs/source/data.rst index c95659bc6e..83e6d17c92 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -68,6 +68,12 @@ Generic Interfaces .. autoclass:: ImageDataset :members: :special-members: __getitem__ + +`NPZDictItemDataset` +~~~~~~~~~~~~~~ +.. autoclass:: NPZDictItemDataset + :members: + :special-members: __getitem__ Patch-based dataset ------------------- From 4c8053b321e8a634602598ee7a6d9cddd09aff09 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 20:41:39 +0000 Subject: [PATCH 05/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- monai/utils/jupyter_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 48438e7f44..a7e712619e 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -43,11 +43,14 @@ has_matplotlib = True except ImportError: has_matplotlib = False - + try: from ignite.engine import Engine, Events + has_ignite = True except ImportError: + Engine = object + Events = object has_ignite = False LOSS_NAME = "loss" From 433f93203d8d3842b67d8aa56cb51cf65b686aa2 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 20:49:13 +0000 Subject: [PATCH 06/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- docs/source/data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index 83e6d17c92..6ed6be9702 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -70,7 +70,7 @@ Generic Interfaces :special-members: __getitem__ `NPZDictItemDataset` -~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~ .. autoclass:: NPZDictItemDataset :members: :special-members: __getitem__ From 18112e138f11030f61817f4fae0866c3190e1cd5 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 21:13:58 +0000 Subject: [PATCH 07/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- tests/test_threadcontainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 2e04ca20f1..9c6964e4c5 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -17,9 +17,11 @@ from monai.data import DataLoader from monai.engines import CommonKeys, SupervisedTrainer from monai.utils import ThreadContainer +from tests.utils import skip_if_quick class TestThreadContainer(unittest.TestCase): + @skip_if_quick def test_container(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) From dd730db470a1ca428eed9c05cd1052abff6fed2b Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 21:33:25 +0000 Subject: [PATCH 08/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- tests/test_threadcontainer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 9c6964e4c5..abf6de89c1 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -14,14 +14,21 @@ import torch +try: + import ignite + + from monai.engines import CommonKeys, SupervisedTrainer + from monai.utils import ThreadContainer + + has_ignite = True +except ImportError: + has_ignite = False + from monai.data import DataLoader -from monai.engines import CommonKeys, SupervisedTrainer -from monai.utils import ThreadContainer -from tests.utils import skip_if_quick class TestThreadContainer(unittest.TestCase): - @skip_if_quick + @unittest.skipIf(not has_ignite) def test_container(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) From 57ba600175889c093b171b5f72e498975ad5a9e7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 21:41:59 +0000 Subject: [PATCH 09/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- tests/test_threadcontainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index abf6de89c1..594b818eff 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -14,13 +14,13 @@ import torch +from monai.utils import optional_import + try: - import ignite + _, has_ignite = optional_import("ignite") from monai.engines import CommonKeys, SupervisedTrainer from monai.utils import ThreadContainer - - has_ignite = True except ImportError: has_ignite = False From 6c0656e20638ab2a481866ed507b9b8ed3bc73a5 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 17 Mar 2021 21:48:42 +0000 Subject: [PATCH 10/11] Jupyter utilities update Signed-off-by: Eric Kerfoot --- tests/test_threadcontainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 594b818eff..92a50a15aa 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -28,7 +28,7 @@ class TestThreadContainer(unittest.TestCase): - @unittest.skipIf(not has_ignite) + @unittest.skipIf(not has_ignite, "Ignite needed for this test") def test_container(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) From cda427d18365272c1f89c88fdba9772253ff3b7b Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Fri, 19 Mar 2021 19:32:09 +0000 Subject: [PATCH 11/11] Update Signed-off-by: Eric Kerfoot --- monai/handlers/metric_logger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 43a06722b2..c749d4bbab 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -14,6 +14,7 @@ from threading import RLock from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional +from monai.engines.utils import CommonKeys from monai.utils import exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -23,8 +24,8 @@ Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") -def _get_loss_from_output(output): - return output["loss"].item() +def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS): + return output[loss_key].item() class MetricLoggerKeys(Enum):