diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index ffc59e7732..0b1cc02799 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -10,7 +10,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Optional +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional import torch @@ -46,12 +47,21 @@ class CheckpointLoader: first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. - strict: whether to strictly enforce that the keys in `state_dict` match the keys - returned by `torch.nn.Module.state_dict` function. default to `True`. + strict: whether to strictly enforce that the keys and data shape in the `state_dict` of every item + of `load_dict` match the `state_dict` of the corresponding items of the checkpoint, default to `True`. strict_shape: whether to enforce the data shape of the matched layers in the checkpoint, - `if `False`, it will skip the layers that have different data shape with checkpoint content. - This can be useful advanced feature for transfer learning. users should totally - understand which layers will have different shape. default to `True`. + if `False`, it will skip the layers whose data shape differs from the checkpoint content, + and ignore the `strict` arg. This can be a useful advanced feature for transfer learning. + Users should fully understand which layers will have a different shape. default to `True`. + + Note: if `strict_shape=False`, it will only load the checkpoint for `torch.nn.Module` and skip other + items in the `load_dict`. For example, if the shape of some layers in the current model can't + match the checkpoint, the `parameter_group` of the current optimizer may also fail to match the + checkpoint, so loading the checkpoint for the optimizer is skipped. 
+ + For more details about loading checkpoint, please refer to: + https://github.com/pytorch/ignite/blob/v0.4.5/ignite/handlers/checkpoint.py#L499. + https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1354. """ @@ -73,6 +83,9 @@ def __init__( self.load_dict = load_dict self._name = name self.map_location = map_location + if strict and not strict_shape: + warnings.warn("as `strict_shape` is already False, change `strict` to False.") + strict = False self.strict = strict self.strict_shape = strict_shape @@ -92,15 +105,22 @@ def __call__(self, engine: Engine) -> None: """ checkpoint = torch.load(self.load_path, map_location=self.map_location) - if not self.strict_shape: - k, _ = list(self.load_dict.items())[0] - # single object and checkpoint is directly a state_dict - if len(self.load_dict) == 1 and k not in checkpoint: - checkpoint = {k: checkpoint} + k, _ = list(self.load_dict.items())[0] + # single object and checkpoint is directly a state_dict + if len(self.load_dict) == 1 and k not in checkpoint: + checkpoint = {k: checkpoint} - # skip items that don't match data shape + if not self.strict_shape: + pop_items: List[str] = [] for k, obj in self.load_dict.items(): - checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + if isinstance(obj, torch.nn.Module): + # skip items that don't match key name or data shape + checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0] + else: + warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.") + pop_items.append(k) + for i in pop_items: + self.load_dict.pop(i) # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint prior_max_epochs = engine.state.max_epochs diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index a69193c98c..81a3cdc96d 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ 
-153,22 +153,34 @@ def test_strict_shape(self): data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5]) data1["new"] = torch.tensor(0.1) net1.load_state_dict(data1, strict=False) + opt1 = optim.SGD(net1.parameters(), lr=0.02) net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) data2 = net2.state_dict() data2["0.weight"] = torch.tensor([0.2]) data2["1.weight"] = torch.tensor([0.3]) net2.load_state_dict(data2) + opt2 = optim.SGD(net2.parameters(), lr=0.02) with tempfile.TemporaryDirectory() as tempdir: engine = Engine(lambda e, b: None) - CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1, "opt": opt1}, save_final=True).attach(engine) engine.run([0] * 8, max_epochs=5) - path = tempdir + "/net_final_iteration=40.pt" + path = tempdir + "/checkpoint_final_iteration=40.pt" engine = Engine(lambda e, b: None) - CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine) + CheckpointLoader( + load_path=path, + # expect to print a warning because it loads not only `net` but also `opt` with `strict_shape=False` + load_dict={"net": net2, "opt": opt2}, + strict=False, + strict_shape=False, + ).attach(engine) engine.run([0] * 8, max_epochs=1) torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2])) + # test whether `opt2` had been skipped when loading with `strict_shape=False`, + # it should have 2 items in `params`(0.weight and 1.weight) while the checkpoint has 1 item(0.weight) + self.assertEqual(len(opt1.state_dict()["param_groups"][0]["params"]), 1) + self.assertEqual(len(opt2.state_dict()["param_groups"][0]["params"]), 2) if __name__ == "__main__":