From f21b007d84a0a17eab3c443aa4ba13bd147cb4e7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 27 Sep 2021 08:12:20 +0800 Subject: [PATCH 1/2] [DLMED] update ignite 0.4.6 Signed-off-by: Nic Ma --- docs/requirements.txt | 2 +- requirements-dev.txt | 2 +- setup.cfg | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 47176c58a2..53eb6d3c0d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,6 @@ -f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl torch>=1.5 -pytorch-ignite==0.4.5 +pytorch-ignite==0.4.6 numpy>=1.17 itk>=5.2 nibabel diff --git a/requirements-dev.txt b/requirements-dev.txt index e28f85b244..254cb06d27 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Full requirements for developments -r requirements-min.txt -pytorch-ignite==0.4.5 +pytorch-ignite==0.4.6 gdown>=3.6.4 scipy itk>=5.2 diff --git a/setup.cfg b/setup.cfg index eddb898b0d..309169ebe8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,7 @@ all = pillow tensorboard gdown>=3.6.4 - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.6 torchvision itk>=5.2 tqdm>=4.47.0 @@ -56,7 +56,7 @@ tensorboard = gdown = gdown>=3.6.4 ignite = - pytorch-ignite==0.4.5 + pytorch-ignite==0.4.6 torchvision = torchvision itk = From 757f255f6c3752e2f133bc81464c6c258acf24e9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 28 Sep 2021 00:24:56 +0800 Subject: [PATCH 2/2] [DLMED] fix flake8 issues Signed-off-by: Nic Ma --- .../pathology/handlers/prob_map_producer.py | 7 ++++-- monai/engines/evaluator.py | 24 ++++++++++-------- monai/engines/multi_gpu_supervised_trainer.py | 4 +-- monai/engines/trainer.py | 24 ++++++++++++------ monai/engines/workflow.py | 25 +++++++++++-------- monai/handlers/checkpoint_loader.py | 2 +- monai/handlers/checkpoint_saver.py | 8 +++--- monai/handlers/decollate_batch.py | 4 +-- monai/handlers/garbage_collector.py | 1 + monai/handlers/ignite_metric.py | 4 +-- monai/handlers/metrics_saver.py | 10 +++++--- monai/handlers/nvtx_handlers.py | 22 ++++++---------- monai/handlers/stats_handler.py | 9 ++++--- monai/handlers/utils.py | 2 +- monai/utils/jupyter_utils.py | 18 ++++++------- 15 files changed, 91 insertions(+), 73 deletions(-) diff --git a/monai/apps/pathology/handlers/prob_map_producer.py b/monai/apps/pathology/handlers/prob_map_producer.py index 7ac4a0e45b..469e9d3c25 100644 --- a/monai/apps/pathology/handlers/prob_map_producer.py +++ b/monai/apps/pathology/handlers/prob_map_producer.py @@ -62,9 +62,10 @@ def attach(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - self.num_images = len(engine.data_loader.dataset.data) + data_loader = engine.data_loader # type: ignore + self.num_images = len(data_loader.dataset.data) - for sample in engine.data_loader.dataset.data: + for sample in data_loader.dataset.data: name = sample["name"] self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype) self.counter[name] = len(sample["mask_locations"]) @@ -84,6 +85,8 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ + if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict): + raise ValueError("engine.state.batch and engine.state.output must be dictionaries.") names = engine.state.batch["name"] locs = engine.state.batch["mask_location"] pred = engine.state.output["pred"] diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 1c37da71d4..bfe9d01e1f 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -219,7 +219,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -237,7 +237,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -246,15 +246,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore # execute forward computation with self.mode(self.network): if self.amp: with torch.cuda.amp.autocast(): - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore else: - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -349,7 +349,7 @@ def __init__( self.pred_keys = ensure_tuple(pred_keys) self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -370,7 +370,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -379,17 +379,21 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: with torch.cuda.amp.autocast(): + if isinstance(engine.state.output, dict): + engine.state.output.update( + {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + ) + else: + if isinstance(engine.state.output, dict): engine.state.output.update( {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} ) - else: - engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py index 3671dbcfd1..b6f516ff99 100644 --- a/monai/engines/multi_gpu_supervised_trainer.py +++ b/monai/engines/multi_gpu_supervised_trainer.py @@ -59,7 +59,7 @@ def create_multigpu_supervised_trainer( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_trainer` in Ignite. @@ -107,7 +107,7 @@ def create_multigpu_supervised_evaluator( prepare_batch: Callable = _prepare_batch, output_transform: Callable = _default_eval_transform, distributed: bool = False, -) -> Engine: +): """ Derived from `create_supervised_evaluator` in Ignite. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 44e265be1f..eeda143def 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -172,7 +172,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -180,7 +180,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore def _compute_pred_loss(): engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) @@ -198,13 +198,13 @@ def _compute_pred_loss(): if self.amp and self.scaler is not None: with torch.cuda.amp.autocast(): _compute_pred_loss() - self.scaler.scale(engine.state.output[Keys.LOSS]).backward() + self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.scaler.step(self.optimizer) self.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() + engine.state.output[Keys.LOSS].backward() # type: ignore engine.fire_event(IterationEvents.BACKWARD_COMPLETED) self.optimizer.step() engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -345,9 +345,14 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) + d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore batch_size = self.data_loader.batch_size # type: ignore - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) # Train Discriminator @@ -367,7 +372,12 @@ def _iteration( # Train Generator if self.g_update_latents: - g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking) + g_input = self.g_prepare_batch( + num_latents=batch_size, + latent_size=self.latent_shape, + device=engine.state.device, # type: ignore + non_blocking=engine.non_blocking, # type: ignore + ) g_output = self.g_inferer(g_input, self.g_network) if PT_BEFORE_1_7: self.g_optimizer.zero_grad() diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index ffb8ce05b3..3454095a02 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -152,15 +152,15 @@ def set_sampler_epoch(engine: Engine): self.scaler: Optional[torch.cuda.amp.GradScaler] = None if event_names is None: - event_names = [IterationEvents] + event_names = [IterationEvents] # type: ignore else: if not isinstance(event_names, list): raise ValueError("event_names must be a list or string or EventEnum.") - event_names += [IterationEvents] + event_names += [IterationEvents] # type: ignore for name in event_names: if isinstance(name, str): self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): + elif issubclass(name, EventEnum): # type: ignore self.register_events(*name, event_to_attr=event_to_attr) else: raise ValueError("event_names must be a list or string or EventEnum.") @@ -187,8 +187,10 @@ def _register_decollate(self): def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate transform = Decollated(keys=None, detach=True) - engine.state.batch = transform(engine.state.batch) - engine.state.output = transform(engine.state.output) + if isinstance(engine.state.batch, (list, dict)): + engine.state.batch = transform(engine.state.batch) + if isinstance(engine.state.output, (list, dict)): + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ @@ -226,12 +228,13 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): @self.on(Events.EPOCH_COMPLETED) def _compare_metrics(engine: Engine) -> None: - if engine.state.key_metric_name is not None: - current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): - self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric - engine.state.best_metric_epoch = engine.state.epoch + key_metric_name = engine.state.key_metric_name # type: ignore + if key_metric_name is not None: + current_val_metric = engine.state.metrics[key_metric_name] + if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): # type: ignore + self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}") + engine.state.best_metric = current_val_metric # type: ignore + engine.state.best_metric_epoch = engine.state.epoch # type: ignore def _register_handlers(self, handlers: Sequence): """ diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index f1f60abf63..7c30584b13 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -126,7 +126,7 @@ def __call__(self, engine: Engine) -> None: # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict) - if engine.state.epoch > prior_max_epochs: + if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs: raise ValueError( f"Epoch count ({engine.state.epoch}) in checkpoint is larger than " f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, " diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index f365ff73c4..d5aadadfed 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Mapping, Optional from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -126,7 +126,7 @@ def __init__(self, dirname: str, filename: Optional[str] = None): super().__init__(dirname=dirname, require_empty=False, atomic=False) self.filename = filename - def __call__(self, checkpoint: Dict, filename: str, metadata: Optional[Dict] = None) -> None: + def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None: if self.filename is not None: filename = self.filename super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata) @@ -154,8 +154,8 @@ def _final_func(engine: Engine): def _score_func(engine: Engine): if isinstance(key_metric_name, str): metric_name = key_metric_name - elif hasattr(engine.state, "key_metric_name") and isinstance(engine.state.key_metric_name, str): - metric_name = engine.state.key_metric_name + elif hasattr(engine.state, "key_metric_name"): + metric_name = engine.state.key_metric_name # type: ignore else: raise ValueError( f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}." diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 4e99fc6f04..0905ee6ebc 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -88,7 +88,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - if self.batch_transform is not None: + if self.batch_transform is not None and isinstance(engine.state.batch, (list, dict)): engine.state.batch = self.batch_transform(engine.state.batch) - if self.output_transform is not None: + if self.output_transform is not None and isinstance(engine.state.output, (list, dict)): engine.state.output = self.output_transform(engine.state.output) diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py index fffca2a740..1eb970e795 100644 --- a/monai/handlers/garbage_collector.py +++ b/monai/handlers/garbage_collector.py @@ -42,6 +42,7 @@ class GarbageCollector: """ def __init__(self, trigger_event: str = "epoch", log_level: int = 10): + self.trigger_event: Events if isinstance(trigger_event, Events): self.trigger_event = trigger_event elif trigger_event.lower() == "epoch": diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py index cbf84e4626..ea7bcd8eee 100644 --- a/monai/handlers/ignite_metric.py +++ b/monai/handlers/ignite_metric.py @@ -101,7 +101,7 @@ def compute(self) -> Any: if self.save_details: if self._engine is None or self._name is None: raise RuntimeError("please call the attach() function to connect expected engine first.") - self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() + self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() # type: ignore return result.item() if isinstance(result, torch.Tensor) else result @@ -120,4 +120,4 @@ def attach(self, engine: Engine, name: str) -> None: self._engine = engine self._name = name if self.save_details and not hasattr(engine.state, "metric_details"): - engine.state.metric_details = {} + engine.state.metric_details = {} # type: ignore diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 97b080b244..4c722eb35b 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -132,10 +132,12 @@ def __call__(self, engine: Engine) -> None: if self.metrics is not None and len(engine.state.metrics) > 0: _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics} _metric_details = {} - if self.metric_details is not None and len(engine.state.metric_details) > 0: - for k, v in engine.state.metric_details.items(): - if k in self.metric_details or "*" in self.metric_details: - _metric_details[k] = v + if hasattr(engine.state, "metric_details"): + details = engine.state.metric_details # type: ignore + if self.metric_details is not None and len(details) > 0: + for k, v in details.items(): + if k in self.metric_details or "*" in self.metric_details: + _metric_details[k] = v write_metrics_reports( save_dir=self.save_dir, diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py index aba7a7ec0e..847a3c0c47 100644 --- a/monai/handlers/nvtx_handlers.py +++ b/monai/handlers/nvtx_handlers.py @@ -97,9 +97,7 @@ def create_paired_events(self, event: str) -> Tuple[Events, Events]: ) def get_event(self, event: Union[str, Events]) -> Events: - if isinstance(event, str): - event = event.upper() - return Events[event] + return Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -126,10 +124,8 @@ class RangePushHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg @@ -156,10 +152,8 @@ class RangePopHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events]) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event def attach(self, engine: Engine) -> None: """ @@ -181,10 +175,8 @@ class MarkHandler: msg: ASCII message to associate with range """ - def __init__(self, event: Events, msg: Optional[str] = None) -> None: - if isinstance(event, str): - event = event.upper() - self.event = Events[event] + def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None: + self.event = Events[event.upper()] if isinstance(event, str) else event if msg is None: msg = self.event.name self.msg = msg diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index d5756074fc..c15abac542 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -172,8 +172,9 @@ def _default_epoch_print(self, engine: Engine) -> None: and hasattr(engine.state, "best_metric") and hasattr(engine.state, "best_metric_epoch") ): - out_str = f"Key metric: {engine.state.key_metric_name} " - out_str += f"best value: {engine.state.best_metric} at epoch: {engine.state.best_metric_epoch}" + out_str = f"Key metric: {engine.state.key_metric_name} " # type: ignore + out_str += f"best value: {engine.state.best_metric} " # type: ignore + out_str += f"at epoch: {engine.state.best_metric_epoch}" # type: ignore self.logger.info(out_str) def _default_iteration_print(self, engine: Engine) -> None: @@ -220,7 +221,9 @@ def _default_iteration_print(self, engine: Engine) -> None: return # no value to print num_iterations = engine.state.epoch_length - current_iteration = (engine.state.iteration - 1) % num_iterations + 1 + current_iteration = engine.state.iteration - 1 + if num_iterations is not None: + current_iteration %= num_iterations + 1 current_epoch = engine.state.epoch num_epochs = engine.state.max_epochs diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 60f95d458e..5d72c028f9 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -50,7 +50,7 @@ def stopping_fn_from_loss(): """ def stopping_fn(engine: Engine): - return -engine.state.output + return -engine.state.output # type:ignore return stopping_fn diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index d6d127b52f..f862452fb1 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -16,11 +16,14 @@ from enum import Enum from threading import RLock, Thread -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch +from monai.config import IgniteInfo +from monai.utils.module import min_version, optional_import + try: import matplotlib.pyplot as plt @@ -28,14 +31,11 @@ except ImportError: has_matplotlib = False -try: +if TYPE_CHECKING: from ignite.engine import Engine, Events - - has_ignite = True -except ImportError: - Engine = object - Events = object - has_ignite = False +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") LOSS_NAME = "loss" @@ -190,7 +190,7 @@ def plot_engine_status( graphmap = {LOSS_NAME: logger.loss} graphmap.update(logger.metrics) - imagemap = {} + imagemap: Dict = {} if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): label = "Batch" if src is engine.state.batch else "Output"