Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
GanKeys,
IterationEvents,
default_make_latent,
default_metric_cmp_fn,
default_prepare_batch,
engine_apply_transform,
get_devices_spec,
Expand Down
17 changes: 16 additions & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.utils.data import DataLoader

from monai.config import IgniteInfo
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode, train_mode
Expand Down Expand Up @@ -53,6 +53,9 @@ class Evaluator(Workflow):
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision evaluation, default is False.
Expand All @@ -79,6 +82,7 @@ def __init__(
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
Expand All @@ -97,6 +101,7 @@ def __init__(
postprocessing=postprocessing,
key_metric=key_val_metric,
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
handlers=val_handlers,
amp=amp,
event_names=event_names,
Expand Down Expand Up @@ -150,6 +155,9 @@ class SupervisedEvaluator(Evaluator):
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision evaluation, default is False.
Expand Down Expand Up @@ -178,6 +186,7 @@ def __init__(
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
Expand All @@ -195,6 +204,7 @@ def __init__(
postprocessing=postprocessing,
key_val_metric=key_val_metric,
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
val_handlers=val_handlers,
amp=amp,
mode=mode,
Expand Down Expand Up @@ -272,6 +282,9 @@ class EnsembleEvaluator(Evaluator):
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision evaluation, default is False.
Expand Down Expand Up @@ -301,6 +314,7 @@ def __init__(
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
Expand All @@ -318,6 +332,7 @@ def __init__(
postprocessing=postprocessing,
key_val_metric=key_val_metric,
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
val_handlers=val_handlers,
amp=amp,
mode=mode,
Expand Down
18 changes: 17 additions & 1 deletion monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from torch.utils.data import DataLoader

from monai.config import IgniteInfo
from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch
from monai.engines.utils import (
GanKeys,
IterationEvents,
default_make_latent,
default_metric_cmp_fn,
default_prepare_batch,
)
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
Expand Down Expand Up @@ -79,6 +85,9 @@ class SupervisedTrainer(Trainer):
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision training, default is False.
Expand Down Expand Up @@ -108,6 +117,7 @@ def __init__(
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
train_handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
Expand All @@ -125,6 +135,7 @@ def __init__(
postprocessing=postprocessing,
key_metric=key_train_metric,
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
handlers=train_handlers,
amp=amp,
event_names=event_names,
Expand Down Expand Up @@ -232,6 +243,9 @@ class GanTrainer(Trainer):
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
decollate: whether to decollate the batch-first data to a list of data after model computation,
Expand Down Expand Up @@ -264,6 +278,7 @@ def __init__(
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
train_handlers: Optional[Sequence] = None,
decollate: bool = True,
):
Expand All @@ -281,6 +296,7 @@ def __init__(
iteration_update=iteration_update,
key_metric=key_train_metric,
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
handlers=train_handlers,
postprocessing=postprocessing,
decollate=decollate,
Expand Down
13 changes: 13 additions & 0 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"default_prepare_batch",
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
]


Expand Down Expand Up @@ -158,3 +159,15 @@ def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dic
output = apply_transform(transform, output)

return batch, output


def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
    """
    The default function to compare metric values between current metric and previous best metric.

    Args:
        current_metric: metric value of current round computation.
        prev_best: the best metric value of previous rounds to compare with.

    Returns:
        `True` if `current_metric` is strictly greater than `prev_best`, otherwise `False`.

    """
    # Coerce to a builtin bool so the declared return type holds even when the
    # comparison yields a non-bool truthy object (e.g. a numpy or tensor scalar).
    return bool(current_metric > prev_best)
9 changes: 7 additions & 2 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from monai.config import IgniteInfo
from monai.data import decollate_batch, rep_scalar_to_batch
from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.utils import ensure_tuple, min_version, optional_import

from .utils import engine_apply_transform
Expand Down Expand Up @@ -63,6 +63,9 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
metric_cmp_fn: function to compare current key metric with previous best key metric value,
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
amp: whether to enable auto-mixed-precision training or inference, default is False.
Expand Down Expand Up @@ -94,6 +97,7 @@ def __init__(
postprocessing: Optional[Callable] = None,
key_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
Expand Down Expand Up @@ -142,6 +146,7 @@ def set_sampler_epoch(engine: Engine):
self.data_loader = data_loader
self.non_blocking = non_blocking
self.prepare_batch = prepare_batch
self.metric_cmp_fn = metric_cmp_fn
self.amp = amp

if event_names is None:
Expand Down Expand Up @@ -214,7 +219,7 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):
def _compare_metrics(engine: Engine) -> None:
if engine.state.key_metric_name is not None:
current_val_metric = engine.state.metrics[engine.state.key_metric_name]
if current_val_metric > engine.state.best_metric:
if self.metric_cmp_fn(current_val_metric, engine.state.best_metric):
self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch
Expand Down
11 changes: 9 additions & 2 deletions monai/handlers/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
import warnings
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, Union

from monai.config import IgniteInfo
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -62,6 +62,9 @@ class CheckpointSaver:
typically, it's used to resume training and compare current metric with previous N values.
key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise,
save the first equally scored model. default to `False`.
key_metric_score_sign: sign of the score, available values: `[1.0, -1.0, 1, -1]`. For error-like metrics,
e.g. smaller is better, a negative score sign should be used (objects with larger score are retained).
default to `1.0`.
epoch_level: save checkpoint during training for every N epochs or every N iterations.
`True` is epoch level, `False` is iteration level.
save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
Expand Down Expand Up @@ -93,6 +96,7 @@ def __init__(
key_metric_filename: Optional[str] = None,
key_metric_save_state: bool = False,
key_metric_greater_or_equal: bool = False,
key_metric_score_sign: Union[float, int] = 1.0,
epoch_level: bool = True,
save_interval: int = 0,
n_saved: Optional[int] = None,
Expand Down Expand Up @@ -155,7 +159,10 @@ def _score_func(engine: Engine):
raise ValueError(
f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}."
)
return round(engine.state.metrics[metric_name], 4)
if key_metric_score_sign not in (1.0, -1.0, 1, -1):
raise ValueError("available value of key_metric_score_sign: `[1.0, -1.0, 1, -1]`.")

return key_metric_score_sign * engine.state.metrics[metric_name]

if key_metric_filename is not None and key_metric_n_saved > 1:
raise ValueError("if using fixed filename to save the best metric model, we should only save 1 model.")
Expand Down
23 changes: 16 additions & 7 deletions tests/test_handler_checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
None,
False,
False,
1.0,
True,
0,
None,
Expand All @@ -46,6 +47,7 @@
None,
False,
True,
1,
False,
0,
None,
Expand All @@ -61,6 +63,7 @@
None,
False,
True,
1.0,
True,
2,
2,
Expand All @@ -76,6 +79,7 @@
None,
False,
False,
1.0,
False,
10,
2,
Expand All @@ -91,18 +95,19 @@
None,
False,
False,
1.0,
True,
0,
None,
["test_checkpoint_final_iteration=40.pt"],
True,
]

TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, False, True, 0, None, ["final_model.pt"]]
TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, False, 1.0, True, 0, None, ["final_model.pt"]]

TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, False, True, 0, None, ["model.pt"]]
TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, False, 1.0, True, 0, None, ["model.pt"]]

TEST_CASE_8 = [False, None, True, "val_loss", 1, "model.pt", False, True, True, 0, None, ["model.pt"]]
TEST_CASE_8 = [False, None, True, "val_loss", 1, "model.pt", False, True, 1.0, True, 0, None, ["model.pt"]]


class TestHandlerCheckpointSaver(unittest.TestCase):
Expand All @@ -128,6 +133,7 @@ def test_file(
key_metric_filename,
key_metric_save_state,
key_metric_greater_or_equal,
key_metric_score_sign,
epoch_level,
save_interval,
n_saved,
Expand Down Expand Up @@ -162,6 +168,7 @@ def _train_func(engine, batch):
key_metric_filename,
key_metric_save_state,
key_metric_greater_or_equal,
key_metric_score_sign,
epoch_level,
save_interval,
n_saved,
Expand Down Expand Up @@ -211,24 +218,26 @@ def _train_func(engine, batch):
key_metric_name="val_loss",
key_metric_n_saved=2,
key_metric_save_state=True,
key_metric_score_sign=-1,
).attach(engine)
engine.run(range(3), max_epochs=2)
engine.run(range(3), max_epochs=3)

saver = CheckpointSaver(
save_dir=tempdir,
save_dict={"net": net},
save_key_metric=True,
key_metric_name="val_loss",
key_metric_n_saved=2,
key_metric_score_sign=-1,
)
engine = Engine(_train_func)
CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine)
CheckpointLoader(os.path.join(tempdir, "net_key_metric=-6.pt"), {"checkpointer": saver}).attach(engine)
engine.run(range(1), max_epochs=1)

resumed = saver._key_metric_checkpoint._saved
for i in range(2):
self.assertEqual(resumed[i].priority, 3 * (i + 1))
self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt")
self.assertEqual(resumed[1 - i].priority, -3 * (i + 1))
self.assertEqual(resumed[1 - i].filename, f"net_key_metric=-{3 * (i + 1)}.pt")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _forward_completed(self, engine):
"val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
},
additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric
val_handlers=val_handlers,
amp=True if amp else False,
)
Expand Down