diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 55b57dab6d..5f97729ee8 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -16,6 +16,7 @@ - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html """ +import json import os import warnings from typing import Mapping @@ -89,7 +90,15 @@ def download_mmar(item, mmar_dir=None, progress: bool = True): return model_dir -def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None, pretrained=True, weights_only=False): +def load_from_mmar( + item, + mmar_dir=None, + progress: bool = True, + map_location=None, + pretrained=True, + weights_only=False, + model_key: str = "model", +): """ Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train. @@ -100,6 +109,9 @@ def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None map_location: pytorch API parameter for `torch.load` or `torch.jit.load`. pretrained: whether to load the pretrained weights after initializing a network module. weights_only: whether to load only the weights instead of initializing the network module and assign weights. + model_key: a key to search in the model file or config file for the model dictionary. + Currently this function assumes that the model dictionary has + `{"[name|path]": "test.module", "args": {'kw': 'test'}}`. Examples:: >>> from monai.apps import load_from_mmar @@ -126,27 +138,69 @@ def load_from_mmar(item, mmar_dir=None, progress: bool = True, map_location=None # loading with `torch.load` model_dict = torch.load(model_file, map_location=map_location) if weights_only: - return model_dict["model"] - - # TODO: search for the module based on model name? 
- if not model_dict.get("train_conf", ""): - raise ValueError("The MMAR configuration does not have a 'train_conf' section.") - model_config = model_dict["train_conf"]["train"]["model"] - if model_config.get("name", ""): # model config section is a "name" + return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly + + # 1. search `model_dict['train_config]` for model config spec. + model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={}) + if not model_config: + # 2. search json CONFIG_FILE for model config spec. + json_path = os.path.join(model_dir, item.get(Keys.CONFIG_FILE, "config_train.json")) + with open(json_path) as f: + conf_dict = json.load(f) + conf_dict = dict(conf_dict) + model_config = _get_val(conf_dict, key=model_key, default={}) + if not model_config: + # 3. search `model_dict` for model config spec. + model_config = _get_val(dict(model_dict), key=model_key, default={}) + + if not (model_config and isinstance(model_config, Mapping)): + raise ValueError( + f"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, " + f"or from model file: {item.get(Keys.MODEL_FILE)}." 
+ ) + + # parse `model_config` for model class and model parameters + if model_config.get("name"): # model config section is a "name" model_name = model_config["name"] model_cls = monai_nets.__dict__[model_name] - else: # model config section is a "path" + elif model_config.get("path"): # model config section is a "path" # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html - model_module, model_name = model_config.get("path", "").rsplit(".", 1) + model_module, model_name = model_config.get("path", ".").rsplit(".", 1) model_cls, has_cls = optional_import(module=model_module, name=model_name) if not has_cls: - raise ValueError(f"Could not load model config {model_config.get('path', '')}.") - model_kwargs = model_config["args"] - model_inst = model_cls(**model_kwargs) + raise ValueError( + f"Could not load MMAR model config {model_config.get('path', '')}, " + f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH." + "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html" + ) + else: + raise ValueError(f"Could not load model config {model_config}.") + print(f"*** Model: {model_cls}") - print(f"*** Model param: {model_kwargs}") + model_kwargs = model_config.get("args", None) + if model_kwargs: + model_inst = model_cls(**model_kwargs) + print(f"*** Model params: {model_kwargs}") + else: + model_inst = model_cls() if pretrained: - model_inst.load_state_dict(model_dict["model"]) + model_inst.load_state_dict(model_dict.get(model_key, model_dict)) print("\n---") - print(f"For more information, please visit {item['doc']}\n") + print(f"For more information, please visit {item[Keys.DOC]}\n") return model_inst + + +def _get_val(input_dict: Mapping, key="model", default=None): + """ + Search for the item with `key` in `config_dict`. + Returns: the first occurrence of `key` in a breadth first search. 
+ """ + if key in input_dict: + return input_dict[key] + for sub_dict in input_dict: + val = input_dict[sub_dict] + if isinstance(val, Mapping): + found_val = _get_val(val, key=key, default=None) + if found_val is not None: + return found_val + return default diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py index 5a9c824e82..c05f8cb51a 100644 --- a/monai/apps/mmars/model_desc.py +++ b/monai/apps/mmars/model_desc.py @@ -35,6 +35,7 @@ class RemoteMMARKeys: HASH_TYPE = "hash_type" # hashing method for the compressed MMAR HASH_VAL = "hash_val" # hashing value for the compressed MMAR MODEL_FILE = "model_file" # within an MMAR folder, the relative path to the model file + CONFIG_FILE = "config_file" # within an MMAR folder, the relative path to the config file (for model config) MODEL_DESC = ( @@ -72,4 +73,16 @@ class RemoteMMARKeys: RemoteMMARKeys.HASH_VAL: None, RemoteMMARKeys.MODEL_FILE: os.path.join("models", "server", "best_FL_global_model.pt"), }, + { + RemoteMMARKeys.ID: "clara_pt_pathology_metastasis_detection_1", + RemoteMMARKeys.NAME: "clara_pt_pathology_metastasis_detection", + RemoteMMARKeys.URL: "https://api.ngc.nvidia.com/v2/models/nvidia/" + "med/clara_pt_pathology_metastasis_detection/versions/1/zip", + RemoteMMARKeys.DOC: "https://ngc.nvidia.com/catalog/models/nvidia:med:clara_pt_pathology_metastasis_detection", + RemoteMMARKeys.FILE_TYPE: "zip", + RemoteMMARKeys.HASH_TYPE: "md5", + RemoteMMARKeys.HASH_VAL: None, + RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"), + RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"), + }, ) diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 31c73b8b8f..ee1981d6d6 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -20,6 +20,7 @@ from monai.apps import download_mmar, load_from_mmar from monai.apps.mmars import MODEL_DESC +from monai.apps.mmars.mmars import _get_val from tests.utils import 
def test_search(self):
    """Check `_get_val` lookups on flat and nested dictionaries."""
    # each entry: (dictionary to search, key to look for, expected value)
    cases = (
        ({"a": 1, "b": 2}, "b", 2),
        ({"a": {"c": {"c": 4}}, "b": {"c": 2}}, "b", {"c": 2}),
        ({"a": {"c": 4}, "b": {"c": 2}}, "c", 4),
        # a nested `None` value must be skipped in favour of a later real value
        ({"a": {"c": None}, "b": {"c": 2}}, "c", 2),
    )
    for nested, wanted_key, expected in cases:
        self.assertEqual(_get_val(nested, key=wanted_key), expected)