From 5dcb672d2ea6a9e3534829bf6e2d2557e94743ff Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 11:50:09 +0800 Subject: [PATCH 01/10] [DLMED] fix strict_shape Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 26 +++++++++++++++++-------- tests/test_handler_checkpoint_loader.py | 13 ++++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index ffc59e7732..347ea4a48e 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -10,6 +10,7 @@ # limitations under the License. import logging +import warnings from typing import TYPE_CHECKING, Dict, Optional import torch @@ -53,6 +54,9 @@ class CheckpointLoader: This can be useful advanced feature for transfer learning. users should totally understand which layers will have different shape. default to `True`. + Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other + items in the `load_dict`. + """ def __init__( @@ -92,15 +96,21 @@ def __call__(self, engine: Engine) -> None: """ checkpoint = torch.load(self.load_path, map_location=self.map_location) + k, _ = list(self.load_dict.items())[0] + # single object and checkpoint is directly a state_dict + if len(self.load_dict) == 1 and k not in checkpoint: + checkpoint = {k: checkpoint} + + # skip items that don't match data shape if not self.strict_shape: - k, _ = list(self.load_dict.items())[0] - # single object and checkpoint is directly a state_dict - if len(self.load_dict) == 1 and k not in checkpoint: - checkpoint = {k: checkpoint} - - # skip items that don't match data shape - for k, obj in self.load_dict.items(): - checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + items = self.load_dict.keys() + for k in items: + obj = self.load_dict[k] + if isinstance(obj, (torch.nn.Module)): + checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + else: + warnings.warn("if `strict_shape` is False, only load checkpoint for model and skip others.") + self.load_dict.pop(k) # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index a69193c98c..6aa8637eec 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -153,20 +153,27 @@ def test_strict_shape(self): data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) data1["new"] = torch.tensor(0.1) net1.load_state_dict(data1, strict=False) + opt1 = optim.SGD(net1.parameters(), lr=0.02) net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data2 = net2.state_dict() data2["0.weight"] = torch.tensor([0.2]) data2["1.weight"] = torch.tensor([0.3]) net2.load_state_dict(data2) + opt2 = optim.SGD(net2.parameters(), lr=0.02) with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) - CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1, "opt": opt1}, save_final=True).attach(engine) engine.run([0] * 8, max_epochs=5) - path = tempdir + "/net_final_iteration=40.pt" + path = tempdir + "/checkpoint_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine) + CheckpointLoader( + load_path=path, + load_dict={"net": net2, "opt": opt2}, + strict=False, + strict_shape=False, + ).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) From 01b64251c2a41c20cf158ccbe8a8471ba2965753 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 13:42:04 +0800 Subject: [PATCH 02/10] [DLMED] add warning message Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 347ea4a48e..d109e346aa 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Dict, Optional +from typing import List, TYPE_CHECKING, Dict, Optional import torch @@ -55,7 +55,9 @@ class CheckpointLoader: understand which layers will have different shape. default to `True`. Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other - items in the `load_dict`. + items in the `load_dict`. For example, if the shape of some layers in current model can't + match the checkpoint, the `parameter_group` of current optimizer may also can't match the + checkpoint, so skip loading checkpoint for optimizer. """ @@ -103,14 +105,15 @@ def __call__(self, engine: Engine) -> None: # skip items that don't match data shape if not self.strict_shape: - items = self.load_dict.keys() - for k in items: - obj = self.load_dict[k] + pop_items: List[str] = [] + for k, obj in self.load_dict.items(): if isinstance(obj, (torch.nn.Module)): checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] else: - warnings.warn("if `strict_shape` is False, only load checkpoint for model and skip others.") - self.load_dict.pop(k) + warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.") + pop_items.append(k) + for i in pop_items: + self.load_dict.pop(i) # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs From fdaa2dcb8acaf9f0f09a6e0a44d723a409469045 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 2 Jul 2021 05:46:23 +0000 Subject: [PATCH 03/10] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/checkpoint_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index d109e346aa..8d6ea8cf91 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -11,7 +11,7 @@ import logging import warnings -from typing import List, TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional import torch From 86f5c97f4335e6187761b03b0698e41624f4d754 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 18:09:37 +0800 Subject: [PATCH 04/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- tests/test_handler_checkpoint_loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 6aa8637eec..ccbcee290a 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -13,6 +13,7 @@ import sys import tempfile import unittest +from copy import deepcopy import torch import torch.optim as optim @@ -161,6 +162,7 @@ def test_strict_shape(self): data2["1.weight"] = torch.tensor([0.3]) net2.load_state_dict(data2) opt2 = optim.SGD(net2.parameters(), lr=0.02) + opt2_backup = deepcopy(opt2) with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) @@ -170,12 +172,15 @@ def test_strict_shape(self): engine = Engine(lambda e, b: None) CheckpointLoader( load_path=path, + # expect to print a warning because it loads not only `net` but also `opt` with `strict_shape=False` load_dict={"net": net2, "opt": opt2}, strict=False, strict_shape=False, ).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) + # test whether `op2` had been skipped when loading with `strict_shape=False` + self.assertDictEqual(opt2.state_dict(), opt2_backup.state_dict()) if __name__ == "__main__": From 4b0b212d0a55258cc54d92cc2f4980de4da36344 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 18:42:31 +0800 Subject: [PATCH 05/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 8d6ea8cf91..11a0f99a42 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -47,7 +47,7 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. - strict: whether to strictly enforce that the keys in `state_dict` match the keys + strict: whether to strictly enforce that the keys and data shape in `state_dict` match the keys returned by `torch.nn.Module.state_dict` function. default to `True`. strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, `if `False`, it will skip the layers that have different data shape with checkpoint content. @@ -79,6 +79,9 @@ def __init__( self.load_dict = load_dict self._name = name self.map_location = map_location + if strict and not strict_shape: + warnings.warn("as `strict_shape` is already False, change `strict` to False.") + strict = False self.strict = strict self.strict_shape = strict_shape @@ -103,11 +106,11 @@ def __call__(self, engine: Engine) -> None: if len(self.load_dict) == 1 and k not in checkpoint: checkpoint = {k: checkpoint} - # skip items that don't match data shape if not self.strict_shape: pop_items: List[str] = [] for k, obj in self.load_dict.items(): if isinstance(obj, (torch.nn.Module)): + # skip items that don't match key name or data shape checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] else: warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.") From 845cdddc65570c3bda915e6f37d47af682a8aac0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 18:49:41 +0800 Subject: [PATCH 06/10] [DLMED] add links Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 11a0f99a42..0990814319 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -59,6 +59,10 @@ class CheckpointLoader: match the checkpoint, the `parameter_group` of current optimizer may also can't match the checkpoint, so skip loading checkpoint for optimizer. + For more details about loading checkpoint, please refer to: + https://github.com/pytorch/ignite/blob/v0.4.5/ignite/handlers/checkpoint.py#L499. + https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1354. + """ def __init__( From 004e442308885e9642ee853340ce98cb9980adda Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 21:52:48 +0800 Subject: [PATCH 07/10] [DLMED] Enhance doc-strings Signed-off-by: Nic Ma --- monai/handlers/checkpoint_loader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 0990814319..0b1cc02799 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -47,12 +47,12 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. - strict: whether to strictly enforce that the keys and data shape in `state_dict` match the keys - returned by `torch.nn.Module.state_dict` function. default to `True`. + strict: whether to strictly enforce that the keys and data shape in the `state_dict` of every item + of `load_dict` match the `state_dict` of the corresponding items of checkpoint, default to `True`. strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, - `if `False`, it will skip the layers that have different data shape with checkpoint content. - This can be useful advanced feature for transfer learning. users should totally - understand which layers will have different shape. default to `True`. + `if `False`, it will skip the layers that have different data shape with checkpoint content, + and ignore the `strict` arg. this can be useful advanced feature for transfer learning. + users should totally understand which layers will have different shape. default to `True`. Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other items in the `load_dict`. For example, if the shape of some layers in current model can't @@ -113,7 +113,7 @@ def __call__(self, engine: Engine) -> None: if not self.strict_shape: pop_items: List[str] = [] for k, obj in self.load_dict.items(): - if isinstance(obj, (torch.nn.Module)): + if isinstance(obj, torch.nn.Module): # skip items that don't match key name or data shape checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] else: From f8b81d4030e2fe369315964c25e577e00b283abc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 23:19:07 +0800 Subject: [PATCH 08/10] [DLMED] enhance CI tests Signed-off-by: Nic Ma --- tests/test_handler_checkpoint_loader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index ccbcee290a..9a1908ac5f 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -162,7 +162,6 @@ def test_strict_shape(self): data2["1.weight"] = torch.tensor([0.3]) net2.load_state_dict(data2) opt2 = optim.SGD(net2.parameters(), lr=0.02) - opt2_backup = deepcopy(opt2) with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) @@ -179,8 +178,11 @@ def test_strict_shape(self): ).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) - # test whether `op2` had been skipped when loading with `strict_shape=False` - self.assertDictEqual(opt2.state_dict(), opt2_backup.state_dict()) + # test whether `opt2` had been skipped when loading with `strict_shape=False`, + # it should have 2 items in `params`(0.weight and 1.weight) while the checkpoint has 1 item(0.weight) + self.assertEqual(len(opt1.state_dict()["param_groups"][0]["params"]), 1) + self.assertEqual(len(opt2.state_dict()["param_groups"][0]["params"]), 2) + if __name__ == "__main__": From d7d209982868d97ece7f451b659c7b6028a03253 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 2 Jul 2021 23:24:17 +0800 Subject: [PATCH 09/10] [DLMED] fix flake8 Signed-off-by: Nic Ma --- tests/test_handler_checkpoint_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index 9a1908ac5f..e95012bd00 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -13,7 +13,6 @@ import sys import tempfile import unittest -from copy import deepcopy import torch import torch.optim as optim From 0a84cb99f3ec68b37f829441972a1a976a99724f Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 2 Jul 2021 15:28:22 +0000 Subject: [PATCH 10/10] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_handler_checkpoint_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index e95012bd00..81a3cdc96d 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -183,6 +183,5 @@ def test_strict_shape(self): self.assertEqual(len(opt2.state_dict()["param_groups"][0]["params"]), 2) - if __name__ == "__main__": unittest.main()