Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Generic Interfaces
.. autoclass:: ImageDataset
:members:
:special-members: __getitem__

`NPZDictItemDataset`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: NPZDictItemDataset
:members:
:special-members: __getitem__

Patch-based dataset
-------------------
Expand Down
3 changes: 2 additions & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CacheNTransDataset,
Dataset,
LMDBDataset,
NPZDictItemDataset,
PersistentDataset,
SmartCacheDataset,
ZipDataset,
Expand All @@ -32,7 +33,7 @@
from .png_writer import write_png
from .samplers import DistributedSampler, DistributedWeightedRandomSampler
from .synthetic import create_test_image_2d, create_test_image_3d
from .thread_buffer import ThreadBuffer
from .thread_buffer import ThreadBuffer, ThreadDataLoader
from .utils import (
compute_importance_map,
compute_shape_offset,
Expand Down
55 changes: 53 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
from torch.utils.data import Dataset as _TorchDataset

from monai.data.utils import pickle_hashing
from monai.data.utils import first, pickle_hashing
from monai.transforms import Compose, Randomizable, Transform, apply_transform
from monai.transforms.transform import RandomizableTransform
from monai.utils import MAX_SEED, get_seed, min_version, optional_import
Expand Down Expand Up @@ -931,3 +932,53 @@ def __getitem__(self, index: int):
if isinstance(transform, RandomizableTransform):
transform.set_random_state(seed=self._seed)
return self.dataset[index]


class NPZDictItemDataset(Dataset):
    """
    Represents a dataset from a loaded NPZ file. The members of the file to load are named by the keys of `keys`,
    and each loaded array is stored under the corresponding value of `keys`. All loaded arrays must have the same
    0-dimension (batch) size. Items are always dicts mapping the stored names to the entry at the requested index
    of each loaded array.

    Args:
        npzfile: Path to .npz file or stream containing .npz file data
        keys: Maps keys to load from file to name to store in dataset
        transform: Transform to apply to batch dict
        other_keys: secondary data to load from file and store in dict `other_keys`, not returned by __getitem__

    Raises:
        ValueError: When the loaded arrays do not all share the same first dimension size.
    """

    def __init__(
        self,
        npzfile: Union[str, IO],
        keys: Dict[str, str],
        transform: Optional[Callable] = None,
        other_keys: Optional[Sequence[str]] = (),
    ):
        # only a path is remembered; any other source is recorded with a placeholder name
        self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM"
        self.keys: Dict[str, str] = dict(keys)

        # The object returned by `np.load` for an .npz archive keeps the archive open until it is closed, so
        # read every required member inside a `with` block to release the handle deterministically.
        with np.load(npzfile) as dat:
            self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()}
            self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys}

        # the first stored array defines the dataset length, all others are checked against it
        self.length = self.arrays[first(self.keys.values())].shape[0]

        for k, v in self.arrays.items():
            if v.shape[0] != self.length:
                raise ValueError(
                    "All loaded arrays must have the same first dimension "
                    f"size {self.length}, array `{k}` has size {v.shape[0]}"
                )

        # items come from `self.arrays`, so the base class is given an empty data list
        super().__init__([], transform)

    def __len__(self):
        """Return the shared first dimension size of the loaded arrays."""
        return self.length

    def __getitem__(self, index: int):
        """Return a dict of the `index`-th entry of every loaded array, passed through `transform` if one is set."""
        data = {k: v[index] for k, v in self.arrays.items()}

        if self.transform is not None:
            data = apply_transform(self.transform, data)

        return data
19 changes: 19 additions & 0 deletions monai/data/thread_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from queue import Empty, Full, Queue
from threading import Thread

from monai.data import DataLoader, Dataset


class ThreadBuffer:
"""
Expand Down Expand Up @@ -73,3 +75,20 @@ def __iter__(self):
pass # queue was empty this time, try again
finally:
self.stop() # ensure thread completion


class ThreadDataLoader(DataLoader):
    """
    Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will
    iterate over data from the loader as expected however the data is generated on a separate thread. Use this class
    where a `DataLoader` instance is required and not just an iterable object.

    Args:
        dataset: dataset to load items from, passed to `DataLoader`
        num_workers: number of worker processes, passed to `DataLoader`
        kwargs: other arguments passed to `DataLoader`
    """

    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
        super().__init__(dataset, num_workers, **kwargs)

        # Holds the buffer used by the most recent iteration; None until `__iter__` is first called.
        self.buffer = None

    def __iter__(self):
        # A bare `super()` proxy cannot be iterated: implicit `iter()` looks `__iter__` up on the proxy's own type,
        # not on the parent class. Call the inherited `__iter__` explicitly and buffer the iterator it returns.
        # A new buffer is built on every call so that each pass over the loader starts from a fresh iterator.
        self.buffer = ThreadBuffer(super().__iter__())
        yield from self.buffer
2 changes: 1 addition & 1 deletion monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .iteration_metric import IterationMetric
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
from .metric_logger import MetricLogger
from .metric_logger import MetricLogger, MetricLoggerKeys
from .metrics_saver import MetricsSaver
from .roc_auc import ROCAUC
from .segmentation_saver import SegmentationSaver
Expand Down
79 changes: 69 additions & 10 deletions monai/handlers/metric_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
# limitations under the License.

from collections import defaultdict
from typing import TYPE_CHECKING, Callable, DefaultDict, List
from enum import Enum
from threading import RLock
from typing import TYPE_CHECKING, Callable, DefaultDict, List, Optional

from monai.engines.utils import CommonKeys
from monai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Expand All @@ -21,12 +24,43 @@
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")


def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS):
    """
    Look up `loss_key` in `output` and return the result of calling `.item()` on the stored value.
    """
    loss_entry = output[loss_key]
    return loss_entry.item()


class MetricLoggerKeys(Enum):
    """
    Keys of the dictionary produced by `MetricLogger.state_dict` and read by `MetricLogger.load_state_dict`.
    """

    # key under which the collected metric values are stored
    METRICS = "Metrics"
    # key under which the collected loss values are stored
    LOSS = "Loss"


class MetricLogger:
def __init__(self, loss_transform: Callable = lambda x: x, metric_transform: Callable = lambda x: x) -> None:
"""
Collect per-iteration metrics and loss value from the attached trainer. This will also collect metric values from
a given evaluator object which is expected to perform evaluation at the end of training epochs. This class is
useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and
`load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training.

Args:
loss_transform: Converts the `output` value from the trainer's state into a loss value
metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value
evaluator: Optional evaluator to consume metric results from at the end of its evaluation run
"""

def __init__(
self,
loss_transform: Callable = _get_loss_from_output,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be overkill but could be:

Suggested change
loss_transform: Callable = _get_loss_from_output,
loss_transform: Callable = partial(_get_loss_from_output, loss_key=MetricLoggerKeys.LOSS),

metric_transform: Callable = lambda x: x,
evaluator: Optional[Engine] = None,
) -> None:
self.loss_transform = loss_transform
self.metric_transform = metric_transform
self.loss: List = []
self.metrics: DefaultDict = defaultdict(list)
self.iteration = 0
self.lock = RLock()

if evaluator is not None:
self.attach_evaluator(evaluator)

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -35,21 +69,46 @@ def attach(self, engine: Engine) -> None:
"""
engine.add_event_handler(Events.ITERATION_COMPLETED, self)

def attach_evaluator(self, evaluator: Engine) -> None:
    """
    Register `log_metrics` on the given evaluator for the `Events.COMPLETED` event so that metric values are
    logged from it.

    Args:
        evaluator: Ignite Engine implementing network evaluation
    """
    on_completed = self.log_metrics
    evaluator.add_event_handler(Events.COMPLETED, on_completed)

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
self.loss.append(self.loss_transform(engine.state.output))
with self.lock:
self.iteration = engine.state.iteration
lossval = self.loss_transform(engine.state.output)

self.loss.append((self.iteration, lossval))
self.log_metrics(engine)

def log_metrics(self, engine: Engine) -> None:
    """
    Record every entry of `engine.state.metrics`, paired with the current iteration number.

    Each value is passed through `metric_transform` and appended to the list kept for that metric's name as an
    `(iteration, value)` tuple. The whole update is performed while holding `self.lock`.

    Args:
        engine: Ignite Engine to log from
    """
    with self.lock:
        for name, raw_value in engine.state.metrics.items():
            stored_value = self.metric_transform(raw_value)
            history = self.metrics[name]
            history.append((self.iteration, stored_value))

for m, v in engine.state.metrics.items():
v = self.metric_transform(v)
# # metrics may not be added on the first timestep, pad the list if this is the case
# # so that each metric list is the same length as self.loss
# if len(self.metrics[m])==0:
# self.metrics[m].append([v[0]]*len(self.loss))
def state_dict(self):
    """
    Return the logged values as a dict keyed by `MetricLoggerKeys.LOSS` and `MetricLoggerKeys.METRICS`.

    The returned dict refers to the logger's own `loss` and `metrics` containers, not to copies.
    """
    saved = {}
    saved[MetricLoggerKeys.LOSS] = self.loss
    saved[MetricLoggerKeys.METRICS] = self.metrics
    return saved

self.metrics[m].append(v)
def load_state_dict(self, state_dict):
    """
    Replace the logged loss and metric values with those stored in `state_dict`.

    The existing `loss` list and `metrics` dict objects are updated in place rather than rebound.

    Args:
        state_dict: dict with entries under `MetricLoggerKeys.LOSS` and `MetricLoggerKeys.METRICS`
    """
    loss_history = state_dict[MetricLoggerKeys.LOSS]
    self.loss[:] = loss_history

    self.metrics.clear()
    metric_history = state_dict[MetricLoggerKeys.METRICS]
    self.metrics.update(metric_history)


metriclogger = MetricLogger
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UpsampleMode,
Weight,
)
from .jupyter_utils import StatusMembers, ThreadContainer
from .misc import (
MAX_SEED,
ImageMetaKey,
Expand Down
Loading