Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
torch>=1.5
pytorch-ignite==0.4.5
pytorch-ignite==0.4.6
numpy>=1.17
itk>=5.2
nibabel
Expand Down
7 changes: 5 additions & 2 deletions monai/apps/pathology/handlers/prob_map_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def attach(self, engine: Engine) -> None:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""

self.num_images = len(engine.data_loader.dataset.data)
data_loader = engine.data_loader # type: ignore
self.num_images = len(data_loader.dataset.data)

for sample in engine.data_loader.dataset.data:
for sample in data_loader.dataset.data:
name = sample["name"]
self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype)
self.counter[name] = len(sample["mask_locations"])
Expand All @@ -84,6 +85,8 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict):
raise ValueError("engine.state.batch and engine.state.output must be dictionaries.")
names = engine.state.batch["name"]
locs = engine.state.batch["mask_location"]
pred = engine.state.output["pred"]
Expand Down
24 changes: 14 additions & 10 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
self.network = network
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
Expand All @@ -237,7 +237,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
Expand All @@ -246,15 +246,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore

# execute forward computation
with self.mode(self.network):
if self.amp:
with torch.cuda.amp.autocast():
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
else:
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

Expand Down Expand Up @@ -349,7 +349,7 @@ def __init__(
self.pred_keys = ensure_tuple(pred_keys)
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
Expand All @@ -370,7 +370,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
Expand All @@ -379,17 +379,21 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore

for idx, network in enumerate(self.networks):
with self.mode(network):
if self.amp:
with torch.cuda.amp.autocast():
if isinstance(engine.state.output, dict):
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
)
else:
if isinstance(engine.state.output, dict):
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
)
else:
engine.state.output.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

Expand Down
4 changes: 2 additions & 2 deletions monai/engines/multi_gpu_supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def create_multigpu_supervised_trainer(
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = _default_transform,
distributed: bool = False,
) -> Engine:
):
"""
Derived from `create_supervised_trainer` in Ignite.

Expand Down Expand Up @@ -107,7 +107,7 @@ def create_multigpu_supervised_evaluator(
prepare_batch: Callable = _prepare_batch,
output_transform: Callable = _default_eval_transform,
distributed: bool = False,
) -> Engine:
):
"""
Derived from `create_supervised_evaluator` in Ignite.

Expand Down
24 changes: 17 additions & 7 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,15 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
kwargs: Dict = {}
else:
inputs, targets, args, kwargs = batch
# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore

def _compute_pred_loss():
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
Expand All @@ -198,13 +198,13 @@ def _compute_pred_loss():
if self.amp and self.scaler is not None:
with torch.cuda.amp.autocast():
_compute_pred_loss()
self.scaler.scale(engine.state.output[Keys.LOSS]).backward()
self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
_compute_pred_loss()
engine.state.output[Keys.LOSS].backward()
engine.state.output[Keys.LOSS].backward() # type: ignore
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
self.optimizer.step()
engine.fire_event(IterationEvents.MODEL_COMPLETED)
Expand Down Expand Up @@ -345,9 +345,14 @@ def _iteration(
if batchdata is None:
raise ValueError("must provide batch data for current iteration.")

d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
d_input = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking) # type: ignore
batch_size = self.data_loader.batch_size # type: ignore
g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking)
g_input = self.g_prepare_batch(
num_latents=batch_size,
latent_size=self.latent_shape,
device=engine.state.device, # type: ignore
non_blocking=engine.non_blocking, # type: ignore
)
g_output = self.g_inferer(g_input, self.g_network)

# Train Discriminator
Expand All @@ -367,7 +372,12 @@ def _iteration(

# Train Generator
if self.g_update_latents:
g_input = self.g_prepare_batch(batch_size, self.latent_shape, engine.state.device, engine.non_blocking)
g_input = self.g_prepare_batch(
num_latents=batch_size,
latent_size=self.latent_shape,
device=engine.state.device, # type: ignore
non_blocking=engine.non_blocking, # type: ignore
)
g_output = self.g_inferer(g_input, self.g_network)
if PT_BEFORE_1_7:
self.g_optimizer.zero_grad()
Expand Down
25 changes: 14 additions & 11 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ def set_sampler_epoch(engine: Engine):
self.scaler: Optional[torch.cuda.amp.GradScaler] = None

if event_names is None:
event_names = [IterationEvents]
event_names = [IterationEvents] # type: ignore
else:
if not isinstance(event_names, list):
raise ValueError("event_names must be a list or string or EventEnum.")
event_names += [IterationEvents]
event_names += [IterationEvents] # type: ignore
for name in event_names:
if isinstance(name, str):
self.register_events(name, event_to_attr=event_to_attr)
elif issubclass(name, EventEnum):
elif issubclass(name, EventEnum): # type: ignore
self.register_events(*name, event_to_attr=event_to_attr)
else:
raise ValueError("event_names must be a list or string or EventEnum.")
Expand All @@ -187,8 +187,10 @@ def _register_decollate(self):
def _decollate_data(engine: Engine) -> None:
# replicate the scalar values to make sure all the items have batch dimension, then decollate
transform = Decollated(keys=None, detach=True)
engine.state.batch = transform(engine.state.batch)
engine.state.output = transform(engine.state.output)
if isinstance(engine.state.batch, (list, dict)):
engine.state.batch = transform(engine.state.batch)
if isinstance(engine.state.output, (list, dict)):
engine.state.output = transform(engine.state.output)

def _register_postprocessing(self, posttrans: Callable):
"""
Expand Down Expand Up @@ -226,12 +228,13 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):

@self.on(Events.EPOCH_COMPLETED)
def _compare_metrics(engine: Engine) -> None:
if engine.state.key_metric_name is not None:
current_val_metric = engine.state.metrics[engine.state.key_metric_name]
if self.metric_cmp_fn(current_val_metric, engine.state.best_metric):
self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch
key_metric_name = engine.state.key_metric_name # type: ignore
if key_metric_name is not None:
current_val_metric = engine.state.metrics[key_metric_name]
if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): # type: ignore
self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric # type: ignore
engine.state.best_metric_epoch = engine.state.epoch # type: ignore

def _register_handlers(self, handlers: Sequence):
"""
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __call__(self, engine: Engine) -> None:
# save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
prior_max_epochs = engine.state.max_epochs
Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
if engine.state.epoch > prior_max_epochs:
if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:
raise ValueError(
f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, "
Expand Down
8 changes: 4 additions & 4 deletions monai/handlers/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
import warnings
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Mapping, Optional

from monai.config import IgniteInfo
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self, dirname: str, filename: Optional[str] = None):
super().__init__(dirname=dirname, require_empty=False, atomic=False)
self.filename = filename

def __call__(self, checkpoint: Dict, filename: str, metadata: Optional[Dict] = None) -> None:
def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
if self.filename is not None:
filename = self.filename
super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata)
Expand Down Expand Up @@ -154,8 +154,8 @@ def _final_func(engine: Engine):
def _score_func(engine: Engine):
if isinstance(key_metric_name, str):
metric_name = key_metric_name
elif hasattr(engine.state, "key_metric_name") and isinstance(engine.state.key_metric_name, str):
metric_name = engine.state.key_metric_name
elif hasattr(engine.state, "key_metric_name"):
metric_name = engine.state.key_metric_name # type: ignore
else:
raise ValueError(
f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}."
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/decollate_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if self.batch_transform is not None:
if self.batch_transform is not None and isinstance(engine.state.batch, (list, dict)):
engine.state.batch = self.batch_transform(engine.state.batch)
if self.output_transform is not None:
if self.output_transform is not None and isinstance(engine.state.output, (list, dict)):
engine.state.output = self.output_transform(engine.state.output)
1 change: 1 addition & 0 deletions monai/handlers/garbage_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class GarbageCollector:
"""

def __init__(self, trigger_event: str = "epoch", log_level: int = 10):
self.trigger_event: Events
if isinstance(trigger_event, Events):
self.trigger_event = trigger_event
elif trigger_event.lower() == "epoch":
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/ignite_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def compute(self) -> Any:
if self.save_details:
if self._engine is None or self._name is None:
raise RuntimeError("please call the attach() function to connect expected engine first.")
self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer()
self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer() # type: ignore

return result.item() if isinstance(result, torch.Tensor) else result

Expand All @@ -120,4 +120,4 @@ def attach(self, engine: Engine, name: str) -> None:
self._engine = engine
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {}
engine.state.metric_details = {} # type: ignore
10 changes: 6 additions & 4 deletions monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,12 @@ def __call__(self, engine: Engine) -> None:
if self.metrics is not None and len(engine.state.metrics) > 0:
_metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or "*" in self.metrics}
_metric_details = {}
if self.metric_details is not None and len(engine.state.metric_details) > 0:
for k, v in engine.state.metric_details.items():
if k in self.metric_details or "*" in self.metric_details:
_metric_details[k] = v
if hasattr(engine.state, "metric_details"):
details = engine.state.metric_details # type: ignore
if self.metric_details is not None and len(details) > 0:
for k, v in details.items():
if k in self.metric_details or "*" in self.metric_details:
_metric_details[k] = v

write_metrics_reports(
save_dir=self.save_dir,
Expand Down
22 changes: 7 additions & 15 deletions monai/handlers/nvtx_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def create_paired_events(self, event: str) -> Tuple[Events, Events]:
)

def get_event(self, event: Union[str, Events]) -> Events:
if isinstance(event, str):
event = event.upper()
return Events[event]
return Events[event.upper()] if isinstance(event, str) else event

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -126,10 +124,8 @@ class RangePushHandler:
msg: ASCII message to associate with range
"""

def __init__(self, event: Events, msg: Optional[str] = None) -> None:
if isinstance(event, str):
event = event.upper()
self.event = Events[event]
def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None:
self.event = Events[event.upper()] if isinstance(event, str) else event
if msg is None:
msg = self.event.name
self.msg = msg
Expand All @@ -156,10 +152,8 @@ class RangePopHandler:
msg: ASCII message to associate with range
"""

def __init__(self, event: Events) -> None:
if isinstance(event, str):
event = event.upper()
self.event = Events[event]
def __init__(self, event: Union[str, Events]) -> None:
self.event = Events[event.upper()] if isinstance(event, str) else event

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -181,10 +175,8 @@ class MarkHandler:
msg: ASCII message to associate with range
"""

def __init__(self, event: Events, msg: Optional[str] = None) -> None:
if isinstance(event, str):
event = event.upper()
self.event = Events[event]
def __init__(self, event: Union[str, Events], msg: Optional[str] = None) -> None:
self.event = Events[event.upper()] if isinstance(event, str) else event
if msg is None:
msg = self.event.name
self.msg = msg
Expand Down
Loading