diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 0283cb909e..87c4bf36d2 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -37,3 +37,4 @@ Model Bundle --------- .. autofunction:: run .. autofunction:: verify_metadata +.. autofunction:: verify_net_in_out diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst index 53976f70db..d383dd7d8e 100644 --- a/docs/source/mb_specification.rst +++ b/docs/source/mb_specification.rst @@ -102,7 +102,6 @@ An example JSON metadata file: "copyright": "Copyright (c) MONAI Consortium", "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", "data_type": "dicom", - "dataset_dir": "/workspace/data/Task09_Spleen", "image_classes": "single channel data, intensity scaled to [0, 1]", "label_classes": "single channel data, 1 is spleen, 0 is everything else", "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background", diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 6f84800208..72c8805e9f 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import run, verify_metadata +from .scripts import run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 45cd89bfdd..0ff0a476ef 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. 
-from monai.bundle.scripts import run, verify_metadata +from monai.bundle.scripts import run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1f3165dee3..890d703090 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -9,14 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import pprint import re -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union + +import torch from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser from monai.config import PathLike -from monai.utils import check_parent_dir, optional_import +from monai.utils import check_parent_dir, get_equivalent_dtype, optional_import validate, _ = optional_import("jsonschema", name="validate") ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError") @@ -51,13 +54,54 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr return args_ -def _log_input_summary(tag: str, args: Dict): - logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---") +def _log_input_summary(tag, args: Dict): + logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): logger.info(f"> {name}: {pprint.pformat(val)}") logger.info("---\n\n") +def _get_var_names(expr: str): + """ + Parse the expression and discover what variables are present in it based on ast module. + + Args: + expr: source expression to parse. + + """ + tree = ast.parse(expr) + return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)] + + +def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple: + """ + Get spatial shape for fake data according to the specified shape pattern. 
+ It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p * n". + + Args: + shape: specified pattern for the spatial shape. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + + """ + ret = [] + for i in shape: + if isinstance(i, int): + ret.append(i) + elif isinstance(i, str): + if i == "*": + ret.append(any) + else: + for c in _get_var_names(i): + if c not in ["p", "n"]: + raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.") + ret.append(eval(i, {"p": p, "n": n})) + else: + raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") + return tuple(ret) + + def run( runner_id: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, @@ -94,8 +138,8 @@ def run( if it is a list of file paths, the content of them will be merged. config_file: filepath of the config file, if `None`, must be provided in `args_file`. if it is a list of file paths, the content of them will be merged. - args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, - `runner_id` and override pairs. so that the command line inputs can be simplified. + args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`, + `config_file`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``. 
@@ -172,3 +216,95 @@ def verify_metadata( logger.info(re.compile(r".*Failed validating", re.S).findall(str(e))[0] + f" against schema `{url}`.") return logger.info("metadata is verified with no error.") + + +def verify_net_in_out( + net_id: Optional[str] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + device: Optional[str] = None, + p: Optional[int] = None, + n: Optional[int] = None, + any: Optional[int] = None, + args_file: Optional[str] = None, + **override, +): + """ + Verify the input and output data shape and data type of network defined in the metadata. + Will test with fake Tensor data according to the required data shape in `metadata`. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle verify_net_in_out network --meta_file /path/to/metadata.json --config_file /path/to/inference.json + + Args: + net_id: ID name of the network component to verify, it must be `torch.nn.Module`. + meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + device: target device to run the network forward computation, if None, prefer to "cuda" if existing. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, + `net_id` and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. 
``--_meta_#network_data_format#inputs#image#num_channels 3``. + + """ + + _args = _update_args( + args=args_file, + net_id=net_id, + meta_file=meta_file, + config_file=config_file, + device=device, + p=p, + n=n, + any=any, + **override, + ) + _log_input_summary(tag="verify_net_in_out", args=_args) + + parser = ConfigParser() + parser.read_config(f=_args.pop("config_file")) + parser.read_meta(f=_args.pop("meta_file")) + id = _args.pop("net_id", "") + device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu")) + p = _args.pop("p", 1) + n = _args.pop("n", 1) + any = _args.pop("any", 1) + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + try: + key: str = id # mark the full id when KeyError + net = parser.get_parsed_content(key).to(device_) + key = "_meta_#network_data_format#inputs#image#num_channels" + input_channels = parser[key] + key = "_meta_#network_data_format#inputs#image#spatial_shape" + input_spatial_shape = tuple(parser[key]) + key = "_meta_#network_data_format#inputs#image#dtype" + input_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + key = "_meta_#network_data_format#outputs#pred#num_channels" + output_channels = parser[key] + key = "_meta_#network_data_format#outputs#pred#dtype" + output_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + except KeyError as e: + raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e + + net.eval() + with torch.no_grad(): + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore + test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_) + output = net(test_data) + if output.shape[1] != output_channels: + raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.") + if output.dtype != output_dtype: + raise ValueError(f"dtype of output data `{output.dtype}` 
doesn't match: {output_dtype}.") + logger.info("data shape of network is verified with no error.") diff --git a/tests/min_tests.py b/tests/min_tests.py index 7d3b35be47..c0d4f36430 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -161,6 +161,7 @@ def run_testsuit(): "test_prepare_batch_default_dist", "test_parallel_execution_dist", "test_bundle_verify_metadata", + "test_bundle_verify_net", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index 7e2bd02209..bf96096c64 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -38,7 +38,7 @@ def test_verify(self, meta_file, schema_file): def_args_file = os.path.join(tempdir, "def_args.json") ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - hash_val = "b11acc946148c0186924f8234562b947" + hash_val = "e3a7e23d1113a1f3e6c69f09b6f9ce2c" cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] cmd += ["--filepath", schema_file, "--hash_val", hash_val, "--args_file", def_args_file] @@ -54,7 +54,7 @@ def test_verify_error(self): json.dump( { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/" - "download/0.8.1/meta_schema_202203130950.json", + "download/0.8.1/meta_schema_202203171008.json", "wrong_meta": "wrong content", }, f, diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py new file mode 100644 index 0000000000..c6aa6d61fb --- /dev/null +++ b/tests/test_bundle_verify_net.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from tests.utils import skip_if_windows + +TEST_CASE_1 = [ + os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), + os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), +] + + +@skip_if_windows +class TestVerifyNetwork(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_verify(self, meta_file, config_file): + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2} + def_args_file = os.path.join(tempdir, "def_args.json") + ConfigParser.export_config_file(config=def_args, filepath=def_args_file) + + cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] + cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] + cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index ce229023fb..c97fe9172e 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,4 +1,5 @@ { + "dataset_dir": "/workspace/data/Task09_Spleen", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "UNet", diff --git 
a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 0a3383adbb..c1fcd66a1c 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,4 +1,5 @@ --- +dataset_dir: "/workspace/data/Task09_Spleen" device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" network_def: _target_: UNet diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json index 97bc218f5e..42a55b114c 100644 --- a/tests/testing_data/metadata.json +++ b/tests/testing_data/metadata.json @@ -1,5 +1,5 @@ { - "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203130950.json", + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json", "version": "0.1.0", "changelog": { "0.1.0": "complete the model package", @@ -17,7 +17,6 @@ "copyright": "Copyright (c) MONAI Consortium", "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", "data_type": "dicom", - "dataset_dir": "/workspace/data/Task09_Spleen", "image_classes": "single channel data, intensity scaled to [0, 1]", "label_classes": "single channel data, 1 is spleen, 0 is everything else", "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",