Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions monai/apps/mmars/mmars.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import monai.networks.nets as monai_nets
from monai.apps.utils import download_and_extract, logger
from monai.config.type_definitions import PathLike
from monai.networks.utils import copy_model_state
from monai.utils.module import optional_import

from .model_desc import MODEL_DESC
Expand Down Expand Up @@ -243,7 +244,7 @@ def load_from_mmar(

# 1. search `model_dict['train_config]` for model config spec.
model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
if not model_config:
if not model_config or not isinstance(model_config, Mapping):
# 2. search json CONFIG_FILE for model config spec.
json_path = model_dir / item.get(Keys.CONFIG_FILE, os.path.join("config", "config_train.json"))
with open(json_path) as f:
Expand Down Expand Up @@ -285,7 +286,9 @@ def load_from_mmar(
else:
model_inst = model_cls()
if pretrained:
model_inst.load_state_dict(model_dict.get(model_key, model_dict))
_, changed, unchanged = copy_model_state(model_inst, model_dict.get(model_key, model_dict), inplace=True)
if not (changed and not unchanged): # not all model_inst varaibles are changed
logger.warning(f"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.")
Comment thread
wyli marked this conversation as resolved.
logger.info("\n---")
doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
logger.info(f"For more information, please visit {doc_url}\n")
Expand Down
33 changes: 32 additions & 1 deletion monai/apps/mmars/model_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import os
from typing import Any, Dict, Tuple

__all__ = ["MODEL_DESC", "RemoteMMARKeys"]

Expand All @@ -39,7 +40,7 @@ class RemoteMMARKeys:
VERSION = "version" # version of the MMAR


MODEL_DESC = (
MODEL_DESC: Tuple[Dict[Any, Any], ...] = (
{
RemoteMMARKeys.ID: "clara_pt_spleen_ct_segmentation_1",
RemoteMMARKeys.NAME: "clara_pt_spleen_ct_segmentation",
Expand Down Expand Up @@ -194,4 +195,34 @@ class RemoteMMARKeys:
RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
RemoteMMARKeys.VERSION: 1,
},
{
RemoteMMARKeys.ID: "clara_pt_unetr_ct_btcv_segmentation",
RemoteMMARKeys.NAME: "clara_pt_unetr_ct_btcv_segmentation",
RemoteMMARKeys.FILE_TYPE: "zip",
RemoteMMARKeys.HASH_TYPE: "md5",
RemoteMMARKeys.HASH_VAL: None,
RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
RemoteMMARKeys.VERSION: 4.1,
},
{
RemoteMMARKeys.ID: "clara_pt_chest_xray_classification",
RemoteMMARKeys.NAME: "clara_pt_chest_xray_classification",
RemoteMMARKeys.FILE_TYPE: "zip",
RemoteMMARKeys.HASH_TYPE: "md5",
RemoteMMARKeys.HASH_VAL: None,
RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
RemoteMMARKeys.VERSION: 4.1,
},
{
RemoteMMARKeys.ID: "clara_pt_self_supervised_learning_segmentation",
RemoteMMARKeys.NAME: "clara_pt_self_supervised_learning_segmentation",
RemoteMMARKeys.FILE_TYPE: "zip",
RemoteMMARKeys.HASH_TYPE: "md5",
RemoteMMARKeys.HASH_VAL: None,
RemoteMMARKeys.MODEL_FILE: os.path.join("models_2gpu", "best_metric_model.pt"),
RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
RemoteMMARKeys.VERSION: 4.1,
},
)
4 changes: 2 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ flake8-bugbear
flake8-comprehensions
flake8-executable
flake8-pyi
pylint
pylint!=2.13 # https://github.com/PyCQA/pylint/issues/5969
mccabe
pep8-naming
pycodestyle
Expand All @@ -32,7 +32,7 @@ Sphinx==3.5.3
recommonmark==0.6.0
sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.2
cucim>=21.8.2; platform_system == "Linux"
cucim==22.2.0; platform_system == "Linux"
openslide-python==1.1.2
imagecodecs; platform_system == "Linux"
tifffile; platform_system == "Linux"
Expand Down
2 changes: 2 additions & 0 deletions tests/ngc_mmar_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def setUp(self):
def test_loading_mmar(self, item):
if item["name"] == "clara_pt_fed_learning_brain_tumor_mri_segmentation":
default_model_file = os.path.join("models", "server", "best_FL_global_model.pt")
elif item["name"] == "clara_pt_self_supervised_learning_segmentation":
default_model_file = os.path.join("models_2gpu", "best_metric_model.pt")
else:
default_model_file = None
pretrained_model = load_from_mmar(
Expand Down
36 changes: 24 additions & 12 deletions tests/test_mmar_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@
np.array(
[
[
[-0.21147135, 0.10815059, -0.04733997],
[-0.3425553, 0.03304602, 0.113512],
[0.1278807, 0.26298857, -0.0583012],
[0.01671106, 0.08502351, -0.1766469],
[-0.13039736, -0.06137804, 0.03924942],
[0.02268324, 0.159056, -0.03485069],
],
[
[-0.3658006, -0.14725913, 0.01149207],
[-0.5453718, -0.12894264, -0.05492746],
[0.16887102, 0.17586298, 0.03977356],
[0.04788467, -0.09365353, -0.05802464],
[-0.19500689, -0.13514304, -0.08191573],
[0.0238207, 0.08029253, 0.10818923],
],
[
[-0.12767333, -0.07876065, 0.03136465],
[0.26057404, -0.03538669, 0.07552322],
[0.23879515, 0.04919613, 0.01725162],
[-0.11541673, -0.10622888, 0.039689],
[0.18462701, -0.0499289, 0.14309818],
[0.00528282, 0.02152331, 0.1698219],
],
]
),
Expand All @@ -71,9 +71,21 @@
"SegResNet",
np.array(
[
[[-0.0839, 0.0715, -0.0760], [0.0645, 0.1186, 0.0218], [0.0303, 0.0631, -0.0648]],
[[0.0128, 0.1440, 0.0213], [0.1658, 0.1813, 0.0541], [-0.0627, 0.0839, 0.0660]],
[[-0.1207, 0.0138, -0.0808], [0.0277, 0.0416, 0.0597], [0.0455, -0.0134, -0.0949]],
[
[0.01874463, 0.12237817, 0.09269974],
[0.07691482, 0.00621202, -0.06682577],
[-0.07718472, 0.08637864, -0.03222707],
],
[
[0.05117761, 0.07428649, -0.03053505],
[0.11045473, 0.07083791, 0.06547518],
[0.09555705, -0.03950734, -0.00819483],
],
[
[0.03704128, 0.062543, 0.0380853],
[-0.02814676, -0.03078287, -0.01383446],
[-0.08137762, 0.01385882, 0.01229484],
],
]
),
),
Expand Down