From e3f6b42a3899985c88d727a88107642b7cc2c130 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 30 Jun 2021 19:20:51 +0800 Subject: [PATCH 1/3] [DLMED] add support to compare metrics Signed-off-by: Nic Ma --- monai/engines/__init__.py | 1 + monai/engines/evaluator.py | 14 +++++++++++++- monai/engines/trainer.py | 16 +++++++++++++++- monai/engines/utils.py | 13 +++++++++++++ monai/engines/workflow.py | 8 ++++++-- monai/handlers/checkpoint_saver.py | 8 +++++++- tests/test_handler_checkpoint_saver.py | 23 ++++++++++++++++------- tests/test_integration_workflows.py | 1 + 8 files changed, 72 insertions(+), 12 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 36719ae61c..89ebc8b47c 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -16,6 +16,7 @@ GanKeys, IterationEvents, default_make_latent, + default_metric_cmp_fn, default_prepare_batch, engine_apply_transform, get_devices_spec, diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 6da75cb951..190ad98f7c 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo -from monai.engines.utils import IterationEvents, default_prepare_batch +from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.networks.utils import eval_mode, train_mode @@ -53,6 +53,8 @@ class Evaluator(Workflow): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. @@ -79,6 +81,7 @@ def __init__( postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, @@ -97,6 +100,7 @@ def __init__( postprocessing=postprocessing, key_metric=key_val_metric, additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, handlers=val_handlers, amp=amp, event_names=event_names, @@ -150,6 +154,8 @@ class SupervisedEvaluator(Evaluator): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. @@ -178,6 +184,7 @@ def __init__( postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, @@ -195,6 +202,7 @@ def __init__( postprocessing=postprocessing, key_val_metric=key_val_metric, additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, val_handlers=val_handlers, amp=amp, mode=mode, @@ -272,6 +280,8 @@ class EnsembleEvaluator(Evaluator): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. @@ -301,6 +311,7 @@ def __init__( postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, val_handlers: Optional[Sequence] = None, amp: bool = False, mode: Union[ForwardMode, str] = ForwardMode.EVAL, @@ -318,6 +329,7 @@ def __init__( postprocessing=postprocessing, key_val_metric=key_val_metric, additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, val_handlers=val_handlers, amp=amp, mode=mode, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index d8b4ec9a26..c77a20d7c2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -16,7 +16,13 @@ from torch.utils.data import DataLoader from monai.config import IgniteInfo -from monai.engines.utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch +from monai.engines.utils import ( + GanKeys, + IterationEvents, + default_make_latent, + default_metric_cmp_fn, + default_prepare_batch, +) from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform @@ -79,6 +85,8 @@ class SupervisedTrainer(Trainer): engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. @@ -108,6 +116,7 @@ def __init__( postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Optional[Sequence] = None, amp: bool = False, event_names: Optional[List[Union[str, EventEnum]]] = None, @@ -125,6 +134,7 @@ def __init__( postprocessing=postprocessing, key_metric=key_train_metric, additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, handlers=train_handlers, amp=amp, event_names=event_names, @@ -232,6 +242,8 @@ class GanTrainer(Trainer): engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. decollate: whether to decollate the batch-first data to a list of data after model computation, @@ -264,6 +276,7 @@ def __init__( postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Optional[Sequence] = None, decollate: bool = True, ): @@ -281,6 +294,7 @@ def __init__( iteration_update=iteration_update, key_metric=key_train_metric, additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, handlers=train_handlers, postprocessing=postprocessing, decollate=decollate, diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 26038ab8a5..d2afa5c37e 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -30,6 +30,7 @@ "default_prepare_batch", "default_make_latent", "engine_apply_transform", + "default_metric_cmp_fn", ] @@ -158,3 +159,15 @@ def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dic output = apply_transform(transform, output) return batch, output + + +def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: + """ + The default function to compare metric values between current metric and previous best metric. + + Args: + current_metric: metric value of current round computation. + prev_best: the best metric value of previous rounds to compare with. + + """ + return current_metric > prev_best diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index f39c720bcc..945e1f4afa 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -19,7 +19,7 @@ from monai.config import IgniteInfo from monai.data import decollate_batch, rep_scalar_to_batch -from monai.engines.utils import IterationEvents, default_prepare_batch +from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.utils import ensure_tuple, min_version, optional_import from .utils import engine_apply_transform @@ -63,6 +63,8 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value. + it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training or inference, default is False. @@ -94,6 +96,7 @@ def __init__( postprocessing: Optional[Callable] = None, key_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, handlers: Optional[Sequence] = None, amp: bool = False, event_names: Optional[List[Union[str, EventEnum]]] = None, @@ -142,6 +145,7 @@ def set_sampler_epoch(engine: Engine): self.data_loader = data_loader self.non_blocking = non_blocking self.prepare_batch = prepare_batch + self.metric_cmp_fn = metric_cmp_fn self.amp = amp if event_names is None: @@ -214,7 +218,7 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): def _compare_metrics(engine: Engine) -> None: if engine.state.key_metric_name is not None: current_val_metric = engine.state.metrics[engine.state.key_metric_name] - if current_val_metric > engine.state.best_metric: + if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): self.logger.info(f"Got new best metric of {engine.state.key_metric_name}: {current_val_metric}") engine.state.best_metric = current_val_metric engine.state.best_metric_epoch = engine.state.epoch diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index bd725e1853..e27a287db3 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -62,6 +62,8 @@ class CheckpointSaver: typically, it's used to resume training and compare current metric with previous N values. key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, save the the first equally scored model. default to `False`. + key_metric_score_sign: sign of the score: `1.0` or `-1.0`. For error-like metrics, e.g. smaller is better, + a negative score sign should be used (objects with larger score are retained). default to `1.0`. epoch_level: save checkpoint during training for every N epochs or every N iterations. `True` is epoch level, `False` is iteration level. save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint. @@ -93,6 +95,7 @@ def __init__( key_metric_filename: Optional[str] = None, key_metric_save_state: bool = False, key_metric_greater_or_equal: bool = False, + key_metric_score_sign: float = 1.0, epoch_level: bool = True, save_interval: int = 0, n_saved: Optional[int] = None, @@ -155,7 +158,10 @@ def _score_func(engine: Engine): raise ValueError( f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}." ) - return round(engine.state.metrics[metric_name], 4) + if key_metric_score_sign not in (1.0, -1.0): + raise ValueError("Argument score_sign should be 1.0 or -1.0.") + + return key_metric_score_sign * engine.state.metrics[metric_name] if key_metric_filename is not None and key_metric_n_saved > 1: raise ValueError("if using fixed filename to save the best metric model, we should only save 1 model.") diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 14474054df..c8bf3fbb6f 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -31,6 +31,7 @@ None, False, False, + 1.0, True, 0, None, @@ -46,6 +47,7 @@ None, False, True, + 1, False, 0, None, @@ -61,6 +63,7 @@ None, False, True, + 1.0, True, 2, 2, @@ -76,6 +79,7 @@ None, False, False, + 1.0, False, 10, 2, @@ -91,6 +95,7 @@ None, False, False, + 1.0, True, 0, None, @@ -98,11 +103,11 @@ True, ] -TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, False, True, 0, None, ["final_model.pt"]] +TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, False, 1.0, True, 0, None, ["final_model.pt"]] -TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, False, True, 0, None, ["model.pt"]] +TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, False, 1.0, True, 0, None, ["model.pt"]] -TEST_CASE_8 = [False, None, True, "val_loss", 1, "model.pt", False, True, True, 0, None, ["model.pt"]] +TEST_CASE_8 = [False, None, True, "val_loss", 1, "model.pt", False, True, 1.0, True, 0, None, ["model.pt"]] class TestHandlerCheckpointSaver(unittest.TestCase): @@ -128,6 +133,7 @@ def test_file( key_metric_filename, key_metric_save_state, key_metric_greater_or_equal, + key_metric_score_sign, epoch_level, save_interval, n_saved, @@ -162,6 +168,7 @@ def _train_func(engine, batch): key_metric_filename, key_metric_save_state, key_metric_greater_or_equal, + key_metric_score_sign, epoch_level, save_interval, n_saved, @@ -211,8 +218,9 @@ def _train_func(engine, batch): key_metric_name="val_loss", key_metric_n_saved=2, key_metric_save_state=True, + key_metric_score_sign=-1, ).attach(engine) - engine.run(range(3), max_epochs=2) + engine.run(range(3), max_epochs=3) saver = CheckpointSaver( save_dir=tempdir, @@ -220,15 +228,16 @@ def _train_func(engine, batch): save_key_metric=True, key_metric_name="val_loss", key_metric_n_saved=2, + key_metric_score_sign=-1, ) engine = Engine(_train_func) - CheckpointLoader(os.path.join(tempdir, "net_key_metric=6.pt"), {"checkpointer": saver}).attach(engine) + CheckpointLoader(os.path.join(tempdir, "net_key_metric=-6.pt"), {"checkpointer": saver}).attach(engine) engine.run(range(1), max_epochs=1) resumed = saver._key_metric_checkpoint._saved for i in range(2): - self.assertEqual(resumed[i].priority, 3 * (i + 1)) - self.assertEqual(resumed[i].filename, f"net_key_metric={3 * (i + 1)}.pt") + self.assertEqual(resumed[1 - i].priority, -3 * (i + 1)) + self.assertEqual(resumed[1 - i].filename, f"net_key_metric=-{3 * (i + 1)}.pt") if __name__ == "__main__": diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 36681838d1..8c4f978beb 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -145,6 +145,7 @@ def _forward_completed(self, engine): "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"])) }, additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, + metric_cmp_fn=lambda cur, prev: cur >= prev, # if greater or equal, treat as new best metric val_handlers=val_handlers, amp=True if amp else False, ) From fa4a53eb91ddc8e8dad44282ad969f212ff0b8a6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 30 Jun 2021 22:45:26 +0800 Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 15 +++++++++------ monai/engines/trainer.py | 10 ++++++---- monai/engines/workflow.py | 5 +++-- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 190ad98f7c..62b922ba7b 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -53,8 +53,9 @@ class Evaluator(Workflow): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. @@ -154,8 +155,9 @@ class SupervisedEvaluator(Evaluator): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. @@ -280,8 +282,9 @@ class EnsembleEvaluator(Evaluator): engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision evaluation, default is False. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c77a20d7c2..b504b57c73 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -85,8 +85,9 @@ class SupervisedTrainer(Trainer): engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training, default is False. @@ -242,8 +243,9 @@ class GanTrainer(Trainer): engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. decollate: whether to decollate the batch-first data to a list of data after model computation, diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 945e1f4afa..8801592d73 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -63,8 +63,9 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the checkpoint into files. additional_metrics: more Ignite metrics that also attach to Ignite Engine. - metric_cmp_fn: function to compare current key metric with previous best key metric value. - it must accept 2 args (current_metric, previous_best) and return a bool result. default to `greater than`. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. amp: whether to enable auto-mixed-precision training or inference, default is False. From 13ea681dbf25dbb2e61209bed740d0e095fab6fd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 1 Jul 2021 07:27:41 +0800 Subject: [PATCH 3/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/checkpoint_saver.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index e27a287db3..8a42eac6f0 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Union from monai.config import IgniteInfo from monai.utils import min_version, optional_import @@ -62,8 +62,9 @@ class CheckpointSaver: typically, it's used to resume training and compare current metric with previous N values. key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, save the the first equally scored model. default to `False`. - key_metric_score_sign: sign of the score: `1.0` or `-1.0`. For error-like metrics, e.g. smaller is better, - a negative score sign should be used (objects with larger score are retained). default to `1.0`. + key_metric_score_sign: sign of the score, available value: `[1.0, -1.0, 1, -1]`. For error-like metrics, + e.g. smaller is better, a negative score sign should be used (objects with larger score are retained). + default to `1.0`. epoch_level: save checkpoint during training for every N epochs or every N iterations. `True` is epoch level, `False` is iteration level. save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint. @@ -95,7 +96,7 @@ def __init__( key_metric_filename: Optional[str] = None, key_metric_save_state: bool = False, key_metric_greater_or_equal: bool = False, - key_metric_score_sign: float = 1.0, + key_metric_score_sign: Union[float, int] = 1.0, epoch_level: bool = True, save_interval: int = 0, n_saved: Optional[int] = None, @@ -158,8 +159,8 @@ def _score_func(engine: Engine): raise ValueError( f"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}." ) - if key_metric_score_sign not in (1.0, -1.0): - raise ValueError("Argument score_sign should be 1.0 or -1.0.") + if key_metric_score_sign not in (1.0, -1.0, 1, -1): + raise ValueError("available value of key_metric_score_sign: `[1.0, -1.0, 1, -1]`.") return key_metric_score_sign * engine.state.metrics[metric_name]