diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py index 0c65b8cd4b..9c67992b36 100644 --- a/monai/handlers/checkpoint_saver.py +++ b/monai/handlers/checkpoint_saver.py @@ -60,6 +60,8 @@ class CheckpointSaver: if `True`, then will save an object in the checkpoint file with key `checkpointer` to be consistent with ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py#L99. typically, it's used to resume training and compare current metric with previous N values. + key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise, + save the the first equally scored model. default to `False`. epoch_level: save checkpoint during training for every N epochs or every N iterations. `True` is epoch level, `False` is iteration level. save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint. @@ -90,6 +92,7 @@ def __init__( key_metric_n_saved: int = 1, key_metric_filename: Optional[str] = None, key_metric_save_state: bool = False, + key_metric_greater_or_equal: bool = False, epoch_level: bool = True, save_interval: int = 0, n_saved: Optional[int] = None, @@ -163,6 +166,7 @@ def _score_func(engine: Engine): score_name="key_metric", n_saved=key_metric_n_saved, include_self=key_metric_save_state, + greater_or_equal=key_metric_greater_or_equal, ) if save_interval > 0: diff --git a/tests/test_handler_checkpoint_saver.py b/tests/test_handler_checkpoint_saver.py index 5c2b750a57..14474054df 100644 --- a/tests/test_handler_checkpoint_saver.py +++ b/tests/test_handler_checkpoint_saver.py @@ -22,7 +22,20 @@ from monai.handlers import CheckpointLoader, CheckpointSaver -TEST_CASE_1 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"]] +TEST_CASE_1 = [ + True, + None, + False, + None, + 1, + None, + False, + False, + True, + 0, + None, + ["test_checkpoint_final_iteration=40.pt"], +] TEST_CASE_2 = [ False, @@ -33,6 +46,7 @@ None, False, True, + False, 0, None, ["test_checkpoint_key_metric=32.pt", "test_checkpoint_key_metric=40.pt"], @@ -47,6 +61,7 @@ None, False, True, + True, 2, 2, ["test_checkpoint_epoch=2.pt", "test_checkpoint_epoch=4.pt"], @@ -61,20 +76,48 @@ None, False, False, + False, 10, 2, ["test_checkpoint_iteration=30.pt", "test_checkpoint_iteration=40.pt"], ] -TEST_CASE_5 = [True, None, False, None, 1, None, False, True, 0, None, ["test_checkpoint_final_iteration=40.pt"], True] +TEST_CASE_5 = [ + True, + None, + False, + None, + 1, + None, + False, + False, + True, + 0, + None, + ["test_checkpoint_final_iteration=40.pt"], + True, +] + +TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, False, True, 0, None, ["final_model.pt"]] -TEST_CASE_6 = [True, "final_model.pt", False, None, 1, None, False, True, 0, None, ["final_model.pt"]] +TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, False, True, 0, None, ["model.pt"]] -TEST_CASE_7 = [False, None, True, "val_loss", 1, "model.pt", False, True, 0, None, ["model.pt"]] +TEST_CASE_8 = [False, None, True, "val_loss", 1, "model.pt", False, True, True, 0, None, ["model.pt"]] class TestHandlerCheckpointSaver(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + ] + ) def test_file( self, save_final, @@ -84,6 +127,7 @@ def test_file( key_metric_n_saved, key_metric_filename, key_metric_save_state, + key_metric_greater_or_equal, epoch_level, save_interval, n_saved, @@ -117,6 +161,7 @@ def _train_func(engine, batch): key_metric_n_saved, key_metric_filename, key_metric_save_state, + key_metric_greater_or_equal, epoch_level, save_interval, n_saved,