Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions monai/apps/mmars/mmars.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html
"""

import json
import os
import warnings
from typing import Mapping
Expand Down Expand Up @@ -89,7 +90,15 @@ def download_mmar(item, mmar_dir=None, progress: bool = True):
return model_dir


def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None, pretrained=True, weights_only=False):
def load_from_mmar(
item,
mmar_dir=None,
progress: bool = True,
map_location=None,
pretrained=True,
weights_only=False,
model_key: str = "model",
):
"""
Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.

Expand All @@ -100,6 +109,9 @@ def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None
map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
pretrained: whether to load the pretrained weights after initializing a network module.
weights_only: whether to load only the weights instead of initializing the network module and assign weights.
model_key: a key to search in the model file or config file for the model dictionary.
Currently this function assumes that the model dictionary has
`{"[name|path]": "test.module", "args": {'kw': 'test'}}`.

Examples::
>>> from monai.apps import load_from_mmar
Expand All @@ -126,27 +138,69 @@ def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None
# loading with `torch.load`
model_dict = torch.load(model_file, map_location=map_location)
if weights_only:
return model_dict["model"]

# TODO: search for the module based on model name?
if not model_dict.get("train_conf", ""):
raise ValueError("The MMAR configuration does not have a 'train_conf' section.")
model_config = model_dict["train_conf"]["train"]["model"]
if model_config.get("name", ""): # model config section is a "name"
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly

# 1. search `model_dict['train_config]` for model config spec.
model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
if not model_config:
# 2. search json CONFIG_FILE for model config spec.
json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json"))
with open(json_path) as f:
conf_dict = json.load(f)
conf_dict = dict(conf_dict)
model_config = _get_val(conf_dict, key=model_key, default={})
if not model_config:
# 3. search `model_dict` for model config spec.
model_config = _get_val(dict(model_dict), key=model_key, default={})
Comment thread
Nic-Ma marked this conversation as resolved.

if not (model_config and isinstance(model_config, Mapping)):
raise ValueError(
f"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, "
f"or from model file: {item.get(Keys.MODEL_FILE)}."
)

# parse `model_config` for model class and model parameters
if model_config.get("name"): # model config section is a "name"
model_name = model_config["name"]
model_cls = monai_nets.__dict__[model_name]
else: # model config section is a "path"
elif model_config.get("path"): # model config section is a "path"
# https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html
model_module, model_name = model_config.get("path", "").rsplit(".", 1)
model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
Comment thread
wyli marked this conversation as resolved.
model_cls, has_cls = optional_import(module=model_module, name=model_name)
if not has_cls:
raise ValueError(f"Could not load model config {model_config.get('path', '')}.")
model_kwargs = model_config["args"]
model_inst = model_cls(**model_kwargs)
raise ValueError(
f"Could not load MMAR model config {model_config.get('path', '')}, "
f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH."
"See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
)
else:
raise ValueError(f"Could not load model config {model_config}.")

print(f"*** Model: {model_cls}")
print(f"*** Model param: {model_kwargs}")
model_kwargs = model_config.get("args", None)
if model_kwargs:
model_inst = model_cls(**model_kwargs)
print(f"*** Model params: {model_kwargs}")
else:
model_inst = model_cls()
if pretrained:
model_inst.load_state_dict(model_dict["model"])
model_inst.load_state_dict(model_dict.get(model_key, model_dict))
print("\n---")
print(f"For more information, please visit {item['doc']}\n")
print(f"For more information, please visit {item[Keys.DOC]}\n")
return model_inst


def _get_val(input_dict: Mapping, key="model", default=None):
"""
Search for the item with `key` in `config_dict`.
Returns: the first occurrence of `key` in a breadth first search.
"""
if key in input_dict:
return input_dict[key]
for sub_dict in input_dict:
val = input_dict[sub_dict]
if isinstance(val, Mapping):
found_val = _get_val(val, key=key, default=None)
if found_val is not None:
return found_val
return default
13 changes: 13 additions & 0 deletions monai/apps/mmars/model_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RemoteMMARKeys:
HASH_TYPE = "hash_type" # hashing method for the compressed MMAR
HASH_VAL = "hash_val" # hashing value for the compressed MMAR
MODEL_FILE = "model_file" # within an MMAR folder, the relative path to the model file
CONFIG_FILE = "config_file" # within an MMAR folder, the relative path to the config file (for model config)


MODEL_DESC = (
Expand Down Expand Up @@ -72,4 +73,16 @@ class RemoteMMARKeys:
RemoteMMARKeys.HASH_VAL: None,
RemoteMMARKeys.MODEL_FILE: os.path.join("models", "server", "best_FL_global_model.pt"),
},
{
RemoteMMARKeys.ID: "clara_pt_pathology_metastasis_detection_1",
RemoteMMARKeys.NAME: "clara_pt_pathology_metastasis_detection",
RemoteMMARKeys.URL: "https://api.ngc.nvidia.com/v2/models/nvidia/"
"med/clara_pt_pathology_metastasis_detection/versions/1/zip",
RemoteMMARKeys.DOC: "https://ngc.nvidia.com/catalog/models/nvidia:med:clara_pt_pathology_metastasis_detection",
RemoteMMARKeys.FILE_TYPE: "zip",
RemoteMMARKeys.HASH_TYPE: "md5",
RemoteMMARKeys.HASH_VAL: None,
RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
},
)
25 changes: 25 additions & 0 deletions tests/test_mmar_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.apps import download_mmar, load_from_mmar
from monai.apps.mmars import MODEL_DESC
from monai.apps.mmars.mmars import _get_val
from tests.utils import SkipIfAtLeastPyTorchVersion, SkipIfBeforePyTorchVersion, skip_if_quick

TEST_CASES = [["clara_pt_prostate_mri_segmentation_1"], ["clara_pt_covid19_ct_lesion_segmentation_1"]]
Expand Down Expand Up @@ -78,6 +79,24 @@
]
),
),
(
{
"item": "clara_pt_pathology_metastasis_detection_1",
"map_location": "cuda" if torch.cuda.is_available() else "cpu",
},
"TorchVisionFullyConvModel",
np.array(
[
[-0.00693138, -0.00441378, -0.01057985, 0.05604396, 0.03526996, -0.00399302, -0.0267504],
[0.00805358, 0.01016939, -0.10749951, -0.28787708, -0.27905375, -0.13328083, -0.00882593],
[-0.01909848, 0.04871106, 0.2957697, 0.60376877, 0.53552634, 0.24821444, 0.03773781],
[0.02449462, -0.07471243, -0.30943492, -0.43987238, -0.26549947, -0.00698426, 0.04395606],
[-0.03124012, 0.00807883, 0.06797771, -0.04612541, -0.30266526, -0.39722857, -0.25109962],
[0.02480375, 0.03378576, 0.06519791, 0.24546203, 0.41867673, 0.393786, 0.16055048],
[-0.01529332, -0.00062494, -0.016658, -0.06313603, -0.1508078, -0.09107386, -0.01239121],
]
),
),
]


Expand Down Expand Up @@ -124,6 +143,12 @@ def test_no_default(self):
with self.assertRaises(ValueError):
download_mmar(0)

def test_search(self):
self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2)
self.assertEqual(_get_val({"a": {"c": {"c": 4}}, "b": {"c": 2}}, key="b"), {"c": 2})
self.assertEqual(_get_val({"a": {"c": 4}, "b": {"c": 2}}, key="c"), 4)
self.assertEqual(_get_val({"a": {"c": None}, "b": {"c": 2}}, key="c"), 2)


if __name__ == "__main__":
unittest.main()