Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import os
import pprint
import re
import warnings
from logging.config import fileConfig
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch.cuda import is_available
Expand All @@ -26,7 +27,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state
from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
from monai.utils.misc import ensure_tuple

Expand Down Expand Up @@ -151,7 +152,7 @@ def download(
name: Optional[str] = None,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
Comment thread
yiheng-wang-nv marked this conversation as resolved.
url: Optional[str] = None,
progress: bool = True,
args_file: Optional[str] = None,
Expand Down Expand Up @@ -182,12 +183,11 @@ def download(
Args:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: repo name. If `None` and `url` is `None`, it must be provided in `args_file`.
Comment thread
yiheng-wang-nv marked this conversation as resolved.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
"github" is currently the only supported value.
repo: repo name. This argument is used when `url` is `None`.
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
Expand All @@ -201,8 +201,8 @@ def download(
)

_log_input_summary(tag="download", args=_args)
name_, bundle_dir_, source_, repo_, url_, progress_ = _pop_args(
_args, name=None, bundle_dir=None, source="github", repo=None, url=None, progress=True
source_, repo_, progress_, name_, bundle_dir_, url_ = _pop_args(
_args, "source", "repo", "progress", name=None, bundle_dir=None, url=None
)

bundle_dir_ = _process_bundle_dir(bundle_dir_)
Expand All @@ -215,10 +215,8 @@ def download(
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
elif source_ == "github":
if name_ is None or repo_ is None:
raise ValueError(
f"To download from source: Github, `name` and `repo` must be provided, got {name_} and {repo_}."
)
if name_ is None:
raise ValueError(f"To download from source: Github, `name` must be provided, got {name_}.")
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
else:
raise NotImplementedError(
Expand All @@ -232,9 +230,10 @@ def load(
load_ts_module: bool = False,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: Optional[str] = None,
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
progress: bool = True,
device: Optional[str] = None,
key_in_ckpt: Optional[str] = None,
config_files: Sequence[str] = (),
net_name: Optional[str] = None,
**net_kwargs,
Expand All @@ -247,15 +246,16 @@ def load(
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
bundle_dir: the directory the weights/TorchScript module will be loaded from.
Default is `bundle` subfolder under`torch.hub get_dir()`.
source: the place that saved the bundle.
If `source` is `github`, the bundle should be within the releases.
repo: the repo name. If the weights file does not exist locally and `url` is `None`, it must be provided.
If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
bundle_dir: directory the weights/TorchScript module will be loaded from.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `model_file` is not existing locally and need to be
downloaded first. "github" is currently the only supported value.
repo: repo name. This argument is used when `model_file` is not existing locally and need to be
downloaded first. If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
key_in_ckpt: for nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, specify the key of model
weights. if not nested checkpoint, no need to set.
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
see `_extra_files` in `torch.jit.load` for more details.
net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
Expand Down Expand Up @@ -286,14 +286,17 @@ def load(
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))
if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
model_dict = get_state_dict(model_dict)

if net_name is None:
return model_dict
net_kwargs["_target_"] = net_name
configer = ConfigComponent(config=net_kwargs)
model = configer.instantiate()
model.to(device) # type: ignore
model.load_state_dict(model_dict) # type: ignore
copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt]) # type: ignore
Comment thread
wyli marked this conversation as resolved.
return model


Expand Down