diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst
index 0283cb909e..87c4bf36d2 100644
--- a/docs/source/bundle.rst
+++ b/docs/source/bundle.rst
@@ -37,3 +37,4 @@ Model Bundle
---------
.. autofunction:: run
.. autofunction:: verify_metadata
+.. autofunction:: verify_net_in_out
diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst
index 53976f70db..d383dd7d8e 100644
--- a/docs/source/mb_specification.rst
+++ b/docs/source/mb_specification.rst
@@ -102,7 +102,6 @@ An example JSON metadata file:
"copyright": "Copyright (c) MONAI Consortium",
"data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
"data_type": "dicom",
- "dataset_dir": "/workspace/data/Task09_Spleen",
"image_classes": "single channel data, intensity scaled to [0, 1]",
"label_classes": "single channel data, 1 is spleen, 0 is everything else",
"pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index 6f84800208..72c8805e9f 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -12,5 +12,5 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
-from .scripts import run, verify_metadata
+from .scripts import run, verify_metadata, verify_net_in_out
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py
index 45cd89bfdd..0ff0a476ef 100644
--- a/monai/bundle/__main__.py
+++ b/monai/bundle/__main__.py
@@ -10,7 +10,7 @@
# limitations under the License.
-from monai.bundle.scripts import run, verify_metadata
+from monai.bundle.scripts import run, verify_metadata, verify_net_in_out
if __name__ == "__main__":
from monai.utils import optional_import
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 1f3165dee3..890d703090 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -9,14 +9,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ast
import pprint
import re
-from typing import Dict, Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
+
+import torch
from monai.apps.utils import download_url, get_logger
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
-from monai.utils import check_parent_dir, optional_import
+from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import
validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
@@ -51,13 +54,54 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr
return args_
-def _log_input_summary(tag: str, args: Dict):
- logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---")
+def _log_input_summary(tag, args: Dict):
+ logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
for name, val in args.items():
logger.info(f"> {name}: {pprint.pformat(val)}")
logger.info("---\n\n")
+def _get_var_names(expr: str):
+ """
+ Parse the expression and discover what variables are present in it based on ast module.
+
+ Args:
+ expr: source expression to parse.
+
+ """
+ tree = ast.parse(expr)
+ return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)]
+
+
+def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple:
+ """
+ Get spatial shape for fake data according to the specified shape pattern.
+ It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n".
+
+ Args:
+ shape: specified pattern for the spatial shape.
+ p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
+ p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
+ any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
+
+ """
+ ret = []
+ for i in shape:
+ if isinstance(i, int):
+ ret.append(i)
+ elif isinstance(i, str):
+ if i == "*":
+ ret.append(any)
+ else:
+ for c in _get_var_names(i):
+ if c not in ["p", "n"]:
+ raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.")
+ ret.append(eval(i, {"p": p, "n": n}))
+ else:
+ raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.")
+ return tuple(ret)
+
+
def run(
runner_id: Optional[str] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
@@ -94,8 +138,8 @@ def run(
if it is a list of file paths, the content of them will be merged.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
- args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
- `runner_id` and override pairs. so that the command line inputs can be simplified.
+ args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
+ `config_file`, and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
@@ -172,3 +216,95 @@ def verify_metadata(
logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.")
return
logger.info("metadata is verified with no error.")
+
+
+def verify_net_in_out(
+ net_id: Optional[str] = None,
+ meta_file: Optional[Union[str, Sequence[str]]] = None,
+ config_file: Optional[Union[str, Sequence[str]]] = None,
+ device: Optional[str] = None,
+ p: Optional[int] = None,
+ n: Optional[int] = None,
+ any: Optional[int] = None,
+ args_file: Optional[str] = None,
+ **override,
+):
+ """
+ Verify the input and output data shape and data type of network defined in the metadata.
+ Will test with fake Tensor data according to the required data shape in `metadata`.
+
+ Typical usage examples:
+
+ .. code-block:: bash
+
+ python -m monai.bundle verify_net_in_out network --meta_file --config_file
+
+ Args:
+ net_id: ID name of the network component to verify, it must be `torch.nn.Module`.
+ meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`.
+ if it is a list of file paths, the content of them will be merged.
+ config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`.
+ if it is a list of file paths, the content of them will be merged.
+ device: target device to run the network forward computation, if None, prefer to "cuda" if existing.
+ p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
+ p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
+ any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
+ args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
+ `net_id` and override pairs. so that the command line inputs can be simplified.
+ override: id-value pairs to override or add the corresponding config content.
+ e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.
+
+ """
+
+ _args = _update_args(
+ args=args_file,
+ net_id=net_id,
+ meta_file=meta_file,
+ config_file=config_file,
+ device=device,
+ p=p,
+ n=n,
+ any=any,
+ **override,
+ )
+ _log_input_summary(tag="verify_net_in_out", args=_args)
+
+ parser = ConfigParser()
+ parser.read_config(f=_args.pop("config_file"))
+ parser.read_meta(f=_args.pop("meta_file"))
+ id = _args.pop("net_id", "")
+ device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu"))
+ p = _args.pop("p", 1)
+ n = _args.pop("n", 1)
+ any = _args.pop("any", 1)
+
+ # the rest key-values in the _args are to override config content
+ for k, v in _args.items():
+ parser[k] = v
+
+ try:
+ key: str = id # mark the full id when KeyError
+ net = parser.get_parsed_content(key).to(device_)
+ key = "_meta_#network_data_format#inputs#image#num_channels"
+ input_channels = parser[key]
+ key = "_meta_#network_data_format#inputs#image#spatial_shape"
+ input_spatial_shape = tuple(parser[key])
+ key = "_meta_#network_data_format#inputs#image#dtype"
+ input_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
+ key = "_meta_#network_data_format#outputs#pred#num_channels"
+ output_channels = parser[key]
+ key = "_meta_#network_data_format#outputs#pred#dtype"
+ output_dtype = get_equivalent_dtype(parser[key], torch.Tensor)
+ except KeyError as e:
+ raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e
+
+ net.eval()
+ with torch.no_grad():
+ spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore
+ test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)
+ output = net(test_data)
+ if output.shape[1] != output_channels:
+ raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.")
+ if output.dtype != output_dtype:
+ raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.")
+ logger.info("data shape of network is verified with no error.")
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 7d3b35be47..c0d4f36430 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -161,6 +161,7 @@ def run_testsuit():
"test_prepare_batch_default_dist",
"test_parallel_execution_dist",
"test_bundle_verify_metadata",
+ "test_bundle_verify_net",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py
index 7e2bd02209..bf96096c64 100644
--- a/tests/test_bundle_verify_metadata.py
+++ b/tests/test_bundle_verify_metadata.py
@@ -38,7 +38,7 @@ def test_verify(self, meta_file, schema_file):
def_args_file = os.path.join(tempdir, "def_args.json")
ConfigParser.export_config_file(config=def_args, filepath=def_args_file)
- hash_val = "b11acc946148c0186924f8234562b947"
+ hash_val = "e3a7e23d1113a1f3e6c69f09b6f9ce2c"
cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
cmd += ["--filepath", schema_file, "--hash_val", hash_val, "--args_file", def_args_file]
@@ -54,7 +54,7 @@ def test_verify_error(self):
json.dump(
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/"
- "download/0.8.1/meta_schema_202203130950.json",
+ "download/0.8.1/meta_schema_202203171008.json",
"wrong_meta": "wrong content",
},
f,
diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py
new file mode 100644
index 0000000000..c6aa6d61fb
--- /dev/null
+++ b/tests/test_bundle_verify_net.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+from parameterized import parameterized
+
+from monai.bundle import ConfigParser
+from tests.utils import skip_if_windows
+
+TEST_CASE_1 = [
+ os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"),
+ os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"),
+]
+
+
+@skip_if_windows
+class TestVerifyNetwork(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1])
+ def test_verify(self, meta_file, config_file):
+ with tempfile.TemporaryDirectory() as tempdir:
+ def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2}
+ def_args_file = os.path.join(tempdir, "def_args.json")
+ ConfigParser.export_config_file(config=def_args, filepath=def_args_file)
+
+ cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file]
+ cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file]
+ cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"]
+ ret = subprocess.check_call(cmd)
+ self.assertEqual(ret, 0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json
index ce229023fb..c97fe9172e 100644
--- a/tests/testing_data/inference.json
+++ b/tests/testing_data/inference.json
@@ -1,4 +1,5 @@
{
+ "dataset_dir": "/workspace/data/Task09_Spleen",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"network_def": {
"_target_": "UNet",
diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml
index 0a3383adbb..c1fcd66a1c 100644
--- a/tests/testing_data/inference.yaml
+++ b/tests/testing_data/inference.yaml
@@ -1,4 +1,5 @@
---
+dataset_dir: "/workspace/data/Task09_Spleen"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
network_def:
_target_: UNet
diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json
index 97bc218f5e..42a55b114c 100644
--- a/tests/testing_data/metadata.json
+++ b/tests/testing_data/metadata.json
@@ -1,5 +1,5 @@
{
- "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203130950.json",
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json",
"version": "0.1.0",
"changelog": {
"0.1.0": "complete the model package",
@@ -17,7 +17,6 @@
"copyright": "Copyright (c) MONAI Consortium",
"data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
"data_type": "dicom",
- "dataset_dir": "/workspace/data/Task09_Spleen",
"image_classes": "single channel data, intensity scaled to [0, 1]",
"label_classes": "single channel data, 1 is spleen, 0 is everything else",
"pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",