Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions monai/handlers/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Optional
import warnings
from typing import TYPE_CHECKING, Dict, List, Optional

import torch

Expand Down Expand Up @@ -46,12 +47,21 @@ class CheckpointLoader:
first load the module to CPU and then copy each parameter to where it was
saved, which would result in all processes on the same machine using the
same set of devices.
strict: whether to strictly enforce that the keys in `state_dict` match the keys
returned by `torch.nn.Module.state_dict` function. default to `True`.
strict: whether to strictly enforce that the keys and data shape in the `state_dict` of every item
of `load_dict` match the `state_dict` of the corresponding items of checkpoint, default to `True`.
strict_shape: whether to enforce the data shape of the matched layers in the checkpoint,
if `False`, it will skip the layers that have a different data shape from the checkpoint content.
This can be a useful advanced feature for transfer learning. Users should fully
understand which layers will have a different shape. default to `True`.
if `False`, it will skip the layers that have a different data shape from the checkpoint content,
and ignore the `strict` arg. This can be a useful advanced feature for transfer learning.
Users should fully understand which layers will have a different shape. default to `True`.

Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other
Comment thread
Nic-Ma marked this conversation as resolved.
items in the `load_dict`. For example, if the shape of some layers in the current model doesn't
match the checkpoint, the `parameter_group` of the current optimizer may not match the
checkpoint either, so loading the checkpoint for the optimizer is skipped.

For more details about loading checkpoint, please refer to:
https://github.com/pytorch/ignite/blob/v0.4.5/ignite/handlers/checkpoint.py#L499.
https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1354.

"""

Expand All @@ -73,6 +83,9 @@ def __init__(
self.load_dict = load_dict
self._name = name
self.map_location = map_location
if strict and not strict_shape:
warnings.warn("as `strict_shape` is already False, change `strict` to False.")
strict = False
self.strict = strict
self.strict_shape = strict_shape

Expand All @@ -92,15 +105,22 @@ def __call__(self, engine: Engine) -> None:
"""
checkpoint = torch.load(self.load_path, map_location=self.map_location)

if not self.strict_shape:
k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
if len(self.load_dict) == 1 and k not in checkpoint:
checkpoint = {k: checkpoint}
k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
if len(self.load_dict) == 1 and k not in checkpoint:
checkpoint = {k: checkpoint}

# skip items that don't match data shape
if not self.strict_shape:
pop_items: List[str] = []
for k, obj in self.load_dict.items():
checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]
if isinstance(obj, torch.nn.Module):
# skip items that don't match key name or data shape
checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]
Comment thread
Nic-Ma marked this conversation as resolved.
else:
warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.")
Comment thread
Nic-Ma marked this conversation as resolved.
pop_items.append(k)
for i in pop_items:
self.load_dict.pop(i)

# save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
prior_max_epochs = engine.state.max_epochs
Expand Down
18 changes: 15 additions & 3 deletions tests/test_handler_checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,22 +153,34 @@ def test_strict_shape(self):
data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5])
data1["new"] = torch.tensor(0.1)
net1.load_state_dict(data1, strict=False)
opt1 = optim.SGD(net1.parameters(), lr=0.02)

net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
data2 = net2.state_dict()
data2["0.weight"] = torch.tensor([0.2])
data2["1.weight"] = torch.tensor([0.3])
net2.load_state_dict(data2)
opt2 = optim.SGD(net2.parameters(), lr=0.02)

with tempfile.TemporaryDirectory() as tempdir:
engine = Engine(lambda e, b: None)
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
CheckpointSaver(save_dir=tempdir, save_dict={"net": net1, "opt": opt1}, save_final=True).attach(engine)
engine.run([0] * 8, max_epochs=5)
path = tempdir + "/net_final_iteration=40.pt"
path = tempdir + "/checkpoint_final_iteration=40.pt"
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine)
CheckpointLoader(
load_path=path,
# expect to print a warning because it loads not only `net` but also `opt` with `strict_shape=False`
load_dict={"net": net2, "opt": opt2},
strict=False,
strict_shape=False,
).attach(engine)
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2]))
# test whether `opt2` had been skipped when loading with `strict_shape=False`,
# it should have 2 items in `params`(0.weight and 1.weight) while the checkpoint has 1 item(0.weight)
self.assertEqual(len(opt1.state_dict()["param_groups"][0]["params"]), 1)
self.assertEqual(len(opt2.state_dict()["param_groups"][0]["params"]), 2)


if __name__ == "__main__":
Expand Down