From a9712a6b349aacc56ca4d73f19ea2b464b70f3fa Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Mar 2022 12:55:48 +0800 Subject: [PATCH 01/13] [DLMED] add verify script Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 160 +++++++++++++++++++++++++++++++++------- 1 file changed, 133 insertions(+), 27 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index ebfd3e54ac..5d2a24bd84 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -10,18 +10,24 @@ # limitations under the License. import pprint -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union +import torch +from monai.apps.utils import get_logger from monai.bundle.config_parser import ConfigParser +from monai.utils.type_conversion import get_equivalent_dtype +logger = get_logger(module_name=__name__) -def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> Dict: + +def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = True, **kwargs) -> Dict: """ Update the `args` with the input `kwargs`. For dict data, recursively update the content based on the keys. Args: args: source args to update. + ignore_none: whether to ignore input args with None value, default to `True`. kwargs: destination args to update. """ @@ -32,14 +38,38 @@ def _update_default_args(args: Optional[Union[str, Dict]] = None, **kwargs) -> D # recursively update the default args with new args for k, v in kwargs.items(): - args_[k] = _update_default_args(args_[k], **v) if isinstance(v, dict) and isinstance(args_.get(k), dict) else v + if ignore_none and v is None: + continue + if isinstance(v, dict) and isinstance(args_.get(k), dict): + args_[k] = _update_args(args_[k], ignore_none, **v) + else: + args_[k] = v return args_ +def _log_input_summary(tag: str, args: Dict): + logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---") + for name, val in args.items(): + logger.info(f"> {name}: {pprint.pformat(val)}") + logger.info("---\n\n") + + +def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1): + ret = [] + for i in shape: + if isinstance(i, int): + ret.append(i) + elif isinstance(i, str): + ret.append(any if i == "*" else eval(i, {"p": p, "n": n})) + else: + raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") + return tuple(ret) + + def run( + runner_id: Optional[str] = None, meta_file: Optional[Union[str, Sequence[str]]] = None, config_file: Optional[Union[str, Sequence[str]]] = None, - target_id: Optional[str] = None, args_file: Optional[str] = None, **override, ): @@ -51,58 +81,46 @@ def run( .. code-block:: bash # Execute this module as a CLI entry: - python -m monai.bundle run --meta_file --config_file --target_id trainer + python -m monai.bundle run trainer --meta_file --config_file # Override config values at runtime by specifying the component id and its new value: - python -m monai.bundle run --net#input_chns 1 ... + python -m monai.bundle run trainer --net#input_chns 1 ... # Override config values with another config file `/path/to/another.json`: - python -m monai.bundle run --net %/path/to/another.json ... + python -m monai.bundle run evaluator --net %/path/to/another.json ... # Override config values with part content of another config file: - python -m monai.bundle run --net %/data/other.json#net_arg ... + python -m monai.bundle run trainer --net %/data/other.json#net_arg ... # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line. # Other args still can override the default args at runtime: python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file Args: + runner_id: ID name of the runner component or workflow, it must have a `run` method. meta_file: filepath of the metadata file, if `None`, must be provided in `args_file`. if it is a list of file paths, the content of them will be merged. config_file: filepath of the config file, if `None`, must be provided in `args_file`. if it is a list of file paths, the content of them will be merged. - target_id: ID name of the target component or workflow, it must have a `run` method. - args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, - `target_id` and override pairs. so that the command line inputs can be simplified. + args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`, + `config_file`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``. """ - k_v = zip(["meta_file", "config_file", "target_id"], [meta_file, config_file, target_id]) - for k, v in k_v: - if v is not None: - override[k] = v - - full_kv = zip( - ("meta_file", "config_file", "target_id", "args_file", "override"), - (meta_file, config_file, target_id, args_file, override), - ) - print("\n--- input summary of monai.bundle.scripts.run ---") - for name, val in full_kv: - print(f"> {name}: {pprint.pformat(val)}") - print("---\n\n") - _args = _update_default_args(args=args_file, **override) + _args = _update_args(args=args_file, runner_id=runner_id, meta_file=meta_file, config_file=config_file, **override) for k in ("meta_file", "config_file"): if k not in _args: raise ValueError(f"{k} is required for 'monai.bundle run'.\n{run.__doc__}") + _log_input_summary(tag="run", args=_args) parser = ConfigParser() parser.read_config(f=_args.pop("config_file")) parser.read_meta(f=_args.pop("meta_file")) - id = _args.pop("target_id", "") + id = _args.pop("runner_id", "") - # the rest key-values in the args are to override config content + # the rest key-values in the _args are to override config content for k, v in _args.items(): parser[k] = v @@ -110,3 +128,91 @@ def run( if not hasattr(workflow, "run"): raise ValueError(f"The parsed workflow {type(workflow)} does not have a `run` method.\n{run.__doc__}") workflow.run() + + +def verify_net_in_out( + net_id: Optional[str] = None, + meta_file: Optional[Union[str, Sequence[str]]] = None, + config_file: Optional[Union[str, Sequence[str]]] = None, + device: Optional[str] = None, + p: Optional[int] = None, + n: Optional[int] = None, + any: Optional[int] = None, + args_file: Optional[str] = None, + **override, +): + """ + Verify the input and output shape of network defined in the metadata. + Will test with fake Tensor data according to the network args with id: + input data: + "_meta_#network_data_format#inputs#image#num_channels" + "_meta_#network_data_format#inputs#image#spatial_shape" + "_meta_#network_data_format#inputs#image#dtype" + "_meta_#network_data_format#inputs#image#value_range" + output data: + "_meta_#network_data_format#outputs#pred#num_channels" + "_meta_#network_data_format#outputs#pred#dtype" + "_meta_#network_data_format#outputs#pred#value_range" + + Args: + net_id: ID name of the network component to verify, it must be `torch.nn.Module`. + meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`. + if it is a list of file paths, the content of them will be merged. + device: target device to run the network forward computation, if None, prefer to "cuda" if existing. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, + `net_id` and override pairs. so that the command line inputs can be simplified. + override: id-value pairs to override or add the corresponding config content. + e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``. + + """ + + _args = _update_args( + args=args_file, + net_id=net_id, + meta_file=meta_file, + config_file=config_file, + device=device, + p=p, + n=n, + any=any, + **override, + ) + _log_input_summary(tag="verify_net_in_out", args=_args) + + parser = ConfigParser() + parser.read_config(f=_args.pop("config_file")) + parser.read_meta(f=_args.pop("meta_file")) + id = _args.pop("net_id", "") + device = torch.device(_args.pop("device", "cuda" if torch.cuda.is_available() else "cpu")) + p = _args.pop("p", 1) + n = _args.pop("n", 1) + any = _args.pop("any", 1) + + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + net = parser[id].to(device) + input_channels = parser["_meta_#network_data_format#inputs#image#num_channels"] + input_spatial_shape = tuple(parser["_meta_#network_data_format#inputs#image#spatial_shape"]) + input_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#inputs#image#dtype"], data_type=torch.Tensor) + input_value_range = tuple(parser["_meta_#network_data_format#inputs#image#value_range"]) + + output_channels = parser["_meta_#network_data_format#outputs#pred#num_channels"] + output_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#output#pred#dtype"], data_type=torch.Tensor) + output_value_range = tuple(parser["_meta_#network_data_format#output#pred#value_range"]) + + net.eval() + with torch.no_grad(): + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) + test_data = torch.rand(*(input_channels, *spatial_shape), dtype=input_dtype, device=device) + output = net(test_data) + if output.shape[0] != output_channels: + raise ValueError(f"output channel number `{output.shape[0]}` doesn't match: `{output_channels}`.") + if output.dtype != output_dtype: + raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.") From b2dcac77defbcfee8f0fb9b88ede25d0b943fb59 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 14 Mar 2022 13:57:35 +0800 Subject: [PATCH 02/13] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5d2a24bd84..a4f604bbc2 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -54,7 +54,18 @@ def _log_input_summary(tag: str, args: Dict): logger.info("---\n\n") -def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1): +def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple[int]: + """ + Get spatial shape for fake data according to the specified shape pattern. + It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n". + + Args: + shape: specified pattern for the spatial shape. + p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. + p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. + + """ ret = [] for i in shape: if isinstance(i, int): @@ -142,17 +153,15 @@ def verify_net_in_out( **override, ): """ - Verify the input and output shape of network defined in the metadata. + Verify the input and output data shape and data type of network defined in the metadata. Will test with fake Tensor data according to the network args with id: input data: "_meta_#network_data_format#inputs#image#num_channels" "_meta_#network_data_format#inputs#image#spatial_shape" "_meta_#network_data_format#inputs#image#dtype" - "_meta_#network_data_format#inputs#image#value_range" output data: "_meta_#network_data_format#outputs#pred#num_channels" "_meta_#network_data_format#outputs#pred#dtype" - "_meta_#network_data_format#outputs#pred#value_range" Args: net_id: ID name of the network component to verify, it must be `torch.nn.Module`. @@ -201,11 +210,9 @@ def verify_net_in_out( input_channels = parser["_meta_#network_data_format#inputs#image#num_channels"] input_spatial_shape = tuple(parser["_meta_#network_data_format#inputs#image#spatial_shape"]) input_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#inputs#image#dtype"], data_type=torch.Tensor) - input_value_range = tuple(parser["_meta_#network_data_format#inputs#image#value_range"]) output_channels = parser["_meta_#network_data_format#outputs#pred#num_channels"] output_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#output#pred#dtype"], data_type=torch.Tensor) - output_value_range = tuple(parser["_meta_#network_data_format#output#pred#value_range"]) net.eval() with torch.no_grad(): From 2aa1b3f80e3b87af3a67ebc32acf488ecb10564f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 15 Mar 2022 18:57:07 +0800 Subject: [PATCH 03/13] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index cec05bedcb..921f45e800 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import pprint import re from typing import Dict, Optional, Sequence, Tuple, Union @@ -53,12 +54,32 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr def _log_input_summary(tag: str, args: Dict): + """ + Log the arguments of bundle scripts. + + Args: + tag: tag to identify the script in the log. + args: arguments of the script to log. + + """ logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): logger.info(f"> {name}: {pprint.pformat(val)}") logger.info("---\n\n") +def _get_var_names(expr: str): + """ + Parse the expression and discover what variables are present in it based on ast module. + + Args: + expr: source expression to parse. + + """ + tree = ast.parse(expr) + return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)] + + def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple[int]: """ Get spatial shape for fake data according to the specified shape pattern. @@ -76,7 +97,13 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int if isinstance(i, int): ret.append(i) elif isinstance(i, str): - ret.append(any if i == "*" else eval(i, {"p": p, "n": n})) + if i == "*": + ret.append(any) + else: + for c in _get_var_names(i): + if c not in ["p", "n"]: + raise ValueError(f"only support variables 'm' and 'p' so far, but got: {c}.") + eval(i, {"p": p, "n": n}) else: raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") return tuple(ret) @@ -254,7 +281,7 @@ def verify_net_in_out( parser.read_config(f=_args.pop("config_file")) parser.read_meta(f=_args.pop("meta_file")) id = _args.pop("net_id", "") - device = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu:0")) + device = torch.device(_args.pop("device", "cuda" if torch.cuda.is_available() else "cpu")) p = _args.pop("p", 1) n = _args.pop("n", 1) any = _args.pop("any", 1) From 6292b524aadc1a84a2f4c90c85bcf5c685fe014c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 15 Mar 2022 20:20:16 +0800 Subject: [PATCH 04/13] [DLMED] add unit tests and doc Signed-off-by: Nic Ma --- docs/source/bundle.rst | 1 + monai/bundle/__init__.py | 2 +- monai/bundle/__main__.py | 2 +- monai/bundle/scripts.py | 14 +++++---- tests/test_bundle_verify_net.py | 48 +++++++++++++++++++++++++++++++ tests/testing_data/inference.json | 2 +- tests/testing_data/inference.yaml | 2 +- 7 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 tests/test_bundle_verify_net.py diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst index 0283cb909e..87c4bf36d2 100644 --- a/docs/source/bundle.rst +++ b/docs/source/bundle.rst @@ -37,3 +37,4 @@ Model Bundle --------- .. autofunction:: run .. autofunction:: verify_metadata +.. autofunction:: verify_net_in_out diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index 6f84800208..72c8805e9f 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -12,5 +12,5 @@ from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable from .config_parser import ConfigParser from .reference_resolver import ReferenceResolver -from .scripts import run, verify_metadata +from .scripts import run, verify_metadata, verify_net_in_out from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index 45cd89bfdd..0ff0a476ef 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -10,7 +10,7 @@ # limitations under the License. -from monai.bundle.scripts import run, verify_metadata +from monai.bundle.scripts import run, verify_metadata, verify_net_in_out if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 921f45e800..1f12ec40c4 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -15,6 +15,7 @@ from typing import Dict, Optional, Sequence, Tuple, Union import torch + from monai.apps.utils import download_url, get_logger from monai.bundle.config_parser import ConfigParser from monai.config import PathLike @@ -103,7 +104,7 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int for c in _get_var_names(i): if c not in ["p", "n"]: raise ValueError(f"only support variables 'm' and 'p' so far, but got: {c}.") - eval(i, {"p": p, "n": n}) + ret.append(eval(i, {"p": p, "n": n})) else: raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") return tuple(ret) @@ -290,20 +291,21 @@ def verify_net_in_out( for k, v in _args.items(): parser[k] = v - net = parser[id].to(device) + net = parser.get_parsed_content(id).to(device) input_channels = parser["_meta_#network_data_format#inputs#image#num_channels"] input_spatial_shape = tuple(parser["_meta_#network_data_format#inputs#image#spatial_shape"]) input_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#inputs#image#dtype"], data_type=torch.Tensor) output_channels = parser["_meta_#network_data_format#outputs#pred#num_channels"] - output_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#output#pred#dtype"], data_type=torch.Tensor) + output_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#outputs#pred#dtype"], data_type=torch.Tensor) net.eval() with torch.no_grad(): spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) - test_data = torch.rand(*(input_channels, *spatial_shape), dtype=input_dtype, device=device) + test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device) output = net(test_data) - if output.shape[0] != output_channels: - raise ValueError(f"output channel number `{output.shape[0]}` doesn't match: `{output_channels}`.") + if output.shape[1] != output_channels: + raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.") if output.dtype != output_dtype: raise ValueError(f"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.") + logger.info("data shape of network is verified with no error.") diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py new file mode 100644 index 0000000000..15ee0e2827 --- /dev/null +++ b/tests/test_bundle_verify_net.py @@ -0,0 +1,48 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import subprocess +import sys +import tempfile +import unittest + +from parameterized import parameterized + +from monai.bundle import ConfigParser +from tests.utils import skip_if_windows + +TEST_CASE_1 = [ + os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"), + os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), +] + + +@skip_if_windows +class TestVerifyNetwork(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_verify(self, meta_file, config_file): + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + with tempfile.TemporaryDirectory() as tempdir: + def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2} + def_args_file = os.path.join(tempdir, "def_args.json") + ConfigParser.export_config_file(config=def_args, filepath=def_args_file) + + cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] + cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] + cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] + ret = subprocess.check_call(cmd) + self.assertEqual(ret, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 6cc6de88ef..b96968496d 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,5 +1,5 @@ { - "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "device": "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "UNet", "spatial_dims": 3, diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index eb2870ee03..58eeca8191 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,5 +1,5 @@ --- -device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')" network_def: _target_: UNet spatial_dims: 3 From 6ceb6a02fb4c804107a77ff290c5b4d7723bb681 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 15 Mar 2022 20:41:59 +0800 Subject: [PATCH 05/13] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 1f12ec40c4..3705d731c3 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -81,7 +81,7 @@ def _get_var_names(expr: str): return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)] -def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple[int]: +def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int = 1, any: int = 1) -> Tuple: """ Get spatial shape for fake data according to the specified shape pattern. It supports `int` number and `string` with formats like: "32", "32 * n", "32 ** p", "32 ** p *n". @@ -282,7 +282,7 @@ def verify_net_in_out( parser.read_config(f=_args.pop("config_file")) parser.read_meta(f=_args.pop("meta_file")) id = _args.pop("net_id", "") - device = torch.device(_args.pop("device", "cuda" if torch.cuda.is_available() else "cpu")) + device_ = torch.device(_args.pop("device", "cuda" if torch.cuda.is_available() else "cpu")) p = _args.pop("p", 1) n = _args.pop("n", 1) any = _args.pop("any", 1) @@ -291,7 +291,7 @@ def verify_net_in_out( for k, v in _args.items(): parser[k] = v - net = parser.get_parsed_content(id).to(device) + net = parser.get_parsed_content(id).to(device_) input_channels = parser["_meta_#network_data_format#inputs#image#num_channels"] input_spatial_shape = tuple(parser["_meta_#network_data_format#inputs#image#spatial_shape"]) input_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#inputs#image#dtype"], data_type=torch.Tensor) @@ -301,8 +301,8 @@ def verify_net_in_out( net.eval() with torch.no_grad(): - spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) - test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device) + spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p, n=n, any=any) # type: ignore + test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_) output = net(test_data) if output.shape[1] != output_channels: raise ValueError(f"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.") From 939b081079736771a392ce789a6bfcddb635157b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 15 Mar 2022 21:20:50 +0800 Subject: [PATCH 06/13] [DLMED] skip min tests Signed-off-by: Nic Ma --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index bb47403090..6a86565aa6 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -161,6 +161,7 @@ def run_testsuit(): "test_parallel_execution_dist", "test_bundle_run", "test_bundle_verify_metadata", + "test_bundle_verify_net", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 3e8b4c32114da9c159d9bbe30be86226cc9a729e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 17:53:17 +0800 Subject: [PATCH 07/13] [DLMED] remove doc-string Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3705d731c3..3813c7c47f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -54,15 +54,7 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr return args_ -def _log_input_summary(tag: str, args: Dict): - """ - Log the arguments of bundle scripts. - - Args: - tag: tag to identify the script in the log. - args: arguments of the script to log. - - """ +def _log_input_summary(tag, args: Dict): logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): logger.info(f"> {name}: {pprint.pformat(val)}") From 2084db107c283b0932feb49bffb1b747e9f81041 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 17:58:18 +0800 Subject: [PATCH 08/13] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 3813c7c47f..71252b6e05 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -55,7 +55,7 @@ def _update_args(args: Optional[Union[str, Dict]] = None, ignore_none: bool = Tr def _log_input_summary(tag, args: Dict): - logger.info(f"\n--- input summary of monai.bundle.scripts.{tag} ---") + logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---") for name, val in args.items(): logger.info(f"> {name}: {pprint.pformat(val)}") logger.info("---\n\n") @@ -95,7 +95,7 @@ def _get_fake_spatial_shape(shape: Sequence[Union[str, int]], p: int = 1, n: int else: for c in _get_var_names(i): if c not in ["p", "n"]: - raise ValueError(f"only support variables 'm' and 'p' so far, but got: {c}.") + raise ValueError(f"only support variables 'p' and 'n' so far, but got: {c}.") ret.append(eval(i, {"p": p, "n": n})) else: raise ValueError(f"spatial shape items must be int or string, but got: {type(i)} {i}.") From 0f250bc9dd8ff0ee211eb9ad356376d0ab1b4ec2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 18:02:14 +0800 Subject: [PATCH 09/13] [DLMED] update device names Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 2 +- tests/testing_data/inference.json | 2 +- tests/testing_data/inference.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 71252b6e05..0dc8d0a1de 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -274,7 +274,7 @@ def verify_net_in_out( parser.read_config(f=_args.pop("config_file")) parser.read_meta(f=_args.pop("meta_file")) id = _args.pop("net_id", "") - device_ = torch.device(_args.pop("device", "cuda" if torch.cuda.is_available() else "cpu")) + device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu:0")) p = _args.pop("p", 1) n = _args.pop("n", 1) any = _args.pop("any", 1) diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index e0c823ec8f..f4230a4645 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,5 +1,5 @@ { - "device": "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')", + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')", "network_def": { "_target_": "UNet", "spatial_dims": 3, diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 7ede23c2a8..8bae488762 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,5 +1,5 @@ --- -device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')" +device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')" network_def: _target_: UNet spatial_dims: 3 From 6ddc1b047285a0280943d3820235ecddc280209f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 18:35:12 +0800 Subject: [PATCH 10/13] [DLMED] update doc-string examples Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 0dc8d0a1de..a33522b37a 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -231,14 +231,13 @@ def verify_net_in_out( ): """ Verify the input and output data shape and data type of network defined in the metadata. - Will test with fake Tensor data according to the network args with id: - input data: - "_meta_#network_data_format#inputs#image#num_channels" - "_meta_#network_data_format#inputs#image#spatial_shape" - "_meta_#network_data_format#inputs#image#dtype" - output data: - "_meta_#network_data_format#outputs#pred#num_channels" - "_meta_#network_data_format#outputs#pred#dtype" + Will test with fake Tensor data according to the required data shape in `metadata`. + + Typical usage examples: + + .. code-block:: bash + + python -m monai.bundle verify_net_in_out network --meta_file --config_file Args: net_id: ID name of the network component to verify, it must be `torch.nn.Module`. From 793655912c857acce039b6dd3d5a5ee9817fddd0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 20:24:30 +0800 Subject: [PATCH 11/13] [DLMED] enhance error message Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 22 +++++++++++++++------- tests/test_bundle_verify_net.py | 2 -- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index a33522b37a..834cb5f1ba 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -282,13 +282,21 @@ def verify_net_in_out( for k, v in _args.items(): parser[k] = v - net = parser.get_parsed_content(id).to(device_) - input_channels = parser["_meta_#network_data_format#inputs#image#num_channels"] - input_spatial_shape = tuple(parser["_meta_#network_data_format#inputs#image#spatial_shape"]) - input_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#inputs#image#dtype"], data_type=torch.Tensor) - - output_channels = parser["_meta_#network_data_format#outputs#pred#num_channels"] - output_dtype = get_equivalent_dtype(parser["_meta_#network_data_format#outputs#pred#dtype"], data_type=torch.Tensor) + try: + key: str = id # mark the full id when KeyError + net = parser.get_parsed_content(key).to(device_) + key = "_meta_#network_data_format#inputs#image#num_channels" + input_channels = parser[key] + key = "_meta_#network_data_format#inputs#image#spatial_shape" + input_spatial_shape = tuple(parser[key]) + key = "_meta_#network_data_format#inputs#image#dtype" + input_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + key = "_meta_#network_data_format#outputs#pred#num_channels" + output_channels = parser[key] + key = "_meta_#network_data_format#outputs#pred#dtype" + output_dtype = get_equivalent_dtype(parser[key], torch.Tensor) + except KeyError as e: + raise KeyError(f"Failed to verify due to missing expected key in the config: {key}.") from e net.eval() with torch.no_grad(): diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index 15ee0e2827..c6aa6d61fb 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import subprocess import sys @@ -31,7 +30,6 @@ class TestVerifyNetwork(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_verify(self, meta_file, config_file): - logging.basicConfig(stream=sys.stdout, level=logging.INFO) with tempfile.TemporaryDirectory() as tempdir: def_args = {"meta_file": "will be replaced by `meta_file` arg", "p": 2} def_args_file = os.path.join(tempdir, "def_args.json") From b2947eadd399770e5c988704554821bc47eb8ff9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 16 Mar 2022 23:32:15 +0800 Subject: [PATCH 12/13] [DLMED] cpu:0 to cpu Signed-off-by: Nic Ma --- monai/bundle/scripts.py | 2 +- tests/testing_data/inference.json | 2 +- tests/testing_data/inference.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 834cb5f1ba..890d703090 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -273,7 +273,7 @@ def verify_net_in_out( parser.read_config(f=_args.pop("config_file")) parser.read_meta(f=_args.pop("meta_file")) id = _args.pop("net_id", "") - device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu:0")) + device_ = torch.device(_args.pop("device", "cuda:0" if torch.cuda.is_available() else "cpu")) p = _args.pop("p", 1) n = _args.pop("n", 1) any = _args.pop("any", 1) diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index f4230a4645..ce229023fb 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,5 +1,5 @@ { - "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')", + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "UNet", "spatial_dims": 3, diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 8bae488762..0a3383adbb 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,5 +1,5 @@ --- -device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')" +device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" network_def: _target_: UNet spatial_dims: 3 From 7f4055a0c23a0159e9e01a76c576a7591b81b070 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 17 Mar 2022 11:30:33 +0800 Subject: [PATCH 13/13] [DLMED] adjust "dataset_dir" Signed-off-by: Nic Ma --- docs/source/mb_specification.rst | 1 - tests/test_bundle_verify_metadata.py | 4 ++-- tests/testing_data/inference.json | 1 + tests/testing_data/inference.yaml | 1 + tests/testing_data/metadata.json | 3 +-- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst index 53976f70db..d383dd7d8e 100644 --- a/docs/source/mb_specification.rst +++ b/docs/source/mb_specification.rst @@ -102,7 +102,6 @@ An example JSON metadata file: "copyright": "Copyright (c) MONAI Consortium", "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", "data_type": "dicom", - "dataset_dir": "/workspace/data/Task09_Spleen", "image_classes": "single channel data, intensity scaled to [0, 1]", "label_classes": "single channel data, 1 is spleen, 0 is everything else", "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background", diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py index 7e2bd02209..bf96096c64 100644 --- a/tests/test_bundle_verify_metadata.py +++ b/tests/test_bundle_verify_metadata.py @@ -38,7 +38,7 @@ def test_verify(self, meta_file, schema_file): def_args_file = os.path.join(tempdir, "def_args.json") ConfigParser.export_config_file(config=def_args, filepath=def_args_file) - hash_val = "b11acc946148c0186924f8234562b947" + hash_val = "e3a7e23d1113a1f3e6c69f09b6f9ce2c" cmd = [sys.executable, "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file] cmd += ["--filepath", schema_file, "--hash_val", hash_val, "--args_file", def_args_file] @@ -54,7 +54,7 @@ def test_verify_error(self): json.dump( { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/" - "download/0.8.1/meta_schema_202203130950.json", + "download/0.8.1/meta_schema_202203171008.json", "wrong_meta": "wrong content", }, f, diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index ce229023fb..c97fe9172e 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,4 +1,5 @@ { + "dataset_dir": "/workspace/data/Task09_Spleen", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "network_def": { "_target_": "UNet", diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index 0a3383adbb..c1fcd66a1c 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,4 +1,5 @@ --- +dataset_dir: "/workspace/data/Task09_Spleen" device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" network_def: _target_: UNet diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json index 97bc218f5e..42a55b114c 100644 --- a/tests/testing_data/metadata.json +++ b/tests/testing_data/metadata.json @@ -1,5 +1,5 @@ { - "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203130950.json", + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_202203171008.json", "version": "0.1.0", "changelog": { "0.1.0": "complete the model package", @@ -17,7 +17,6 @@ "copyright": "Copyright (c) MONAI Consortium", "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", "data_type": "dicom", - "dataset_dir": "/workspace/data/Task09_Spleen", "image_classes": "single channel data, intensity scaled to [0, 1]", "label_classes": "single channel data, 1 is spleen, 0 is everything else", "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",