diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b64f26a091..9ae0fae6c6 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -14,9 +14,10 @@ import os import pprint import re +import warnings from logging.config import fileConfig from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Mapping, Optional, Sequence, Tuple, Union import torch from torch.cuda import is_available @@ -26,7 +27,7 @@ from monai.bundle.config_parser import ConfigParser from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata -from monai.networks import convert_to_torchscript, copy_model_state +from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import from monai.utils.misc import ensure_tuple @@ -151,7 +152,7 @@ def download( name: Optional[str] = None, bundle_dir: Optional[PathLike] = None, source: str = "github", - repo: Optional[str] = None, + repo: str = "Project-MONAI/model-zoo/hosting_storage_v1", url: Optional[str] = None, progress: bool = True, args_file: Optional[str] = None, @@ -182,12 +183,11 @@ def download( Args: name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`. bundle_dir: target directory to store the downloaded data. - Default is `bundle` subfolder under`torch.hub get_dir()`. - source: place that saved the bundle. - If `source` is `github`, the bundle should be within the releases. - repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`. - If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. - For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + Default is `bundle` subfolder under `torch.hub.get_dir()`. + source: storage location name. This argument is used when `url` is `None`. + "github" is currently the only supported value. + repo: repo name. This argument is used when `url` is `None`. + If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". url: url to download the data. If not `None`, data will be downloaded directly and `source` will not be checked. If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`. @@ -201,8 +201,8 @@ def download( ) _log_input_summary(tag="download", args=_args) - name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args( - _args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True + source_, repo_, progress_, name_, bundle_dir_, url_ = _pop_args( + _args, "source", "repo", "progress", name=None, bundle_dir=None, url=None ) bundle_dir_ = _process_bundle_dir(bundle_dir_) @@ -215,10 +215,8 @@ def download( download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_) extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True) elif source_ == "github": - if name_ is None or repo_ is None: - raise ValueError( - f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}." - ) + if name_ is None: + raise ValueError(f"To download from source: Github, `name` must be provided, got {name_}.") _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) else: raise NotImplementedError( @@ -232,9 +230,10 @@ def load( load_ts_module: bool = False, bundle_dir: Optional[PathLike] = None, source: str = "github", - repo: Optional[str] = None, + repo: str = "Project-MONAI/model-zoo/hosting_storage_v1", progress: bool = True, device: Optional[str] = None, + key_in_ckpt: Optional[str] = None, config_files: Sequence[str] = (), net_name: Optional[str] = None, **net_kwargs, @@ -247,15 +246,16 @@ def load( model_file: the relative path of the model weights or TorchScript module within bundle. If `None`, "models/model.pt" or "models/model.ts" will be used. load_ts_module: a flag to specify if loading the TorchScript module. - bundle_dir: the directory the weights/TorchScript module will be loaded from. - Default is `bundle` subfolder under`torch.hub get_dir()`. - source: the place that saved the bundle. - If `source` is `github`, the bundle should be within the releases. - repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided. - If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`. - For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`. + bundle_dir: directory the weights/TorchScript module will be loaded from. + Default is `bundle` subfolder under `torch.hub.get_dir()`. + source: storage location name. This argument is used when `model_file` is not existing locally and need to be + downloaded first. "github" is currently the only supported value. + repo: repo name. This argument is used when `model_file` is not existing locally and need to be + downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag". progress: whether to display a progress bar when downloading. device: target device of returned weights or module, if `None`, prefer to "cuda" if existing. + key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model + weights. if not nested checkpoint, no need to set. config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module, see `_extra_files` in `torch.jit.load` for more details. net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights. @@ -286,6 +286,9 @@ def load( return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) # loading with `torch.load` model_dict = torch.load(full_path, map_location=torch.device(device)) + if not isinstance(model_dict, Mapping): + warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") + model_dict = get_state_dict(model_dict) if net_name is None: return model_dict @@ -293,7 +296,7 @@ def load( configer = ConfigComponent(config=net_kwargs) model = configer.instantiate() model.to(device) # type: ignore - model.load_state_dict(model_dict) # type: ignore + copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt]) # type: ignore return model