diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index fea46aafa3..6a22db1076 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -16,7 +16,9 @@ from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple import numpy as np +import torch +from monai import transforms from monai.config import KeysCollection from monai.utils import MAX_SEED, ensure_tuple @@ -45,6 +47,27 @@ def apply_transform(transform: Callable, data, map_items: bool = True): return [transform(item) for item in data] return transform(data) except Exception as e: + + if not isinstance(transform, transforms.compose.Compose): + # log the input data information of exact transform in the transform chain + datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False) + datastats._logger.info("input data information of the runtime error transform:") + if isinstance(data, (list, tuple)): + data = data[0] + + def _log_stats(data, prefix: Optional[str] = "Data"): + if isinstance(data, (np.ndarray, torch.Tensor)): + # log data type, shape, range for array + datastats(img=data, data_shape=True, value_range=True, prefix=prefix) # type: ignore + else: + # log data type and value for other meta data + datastats(img=data, data_value=True, prefix=prefix) + + if isinstance(data, dict): + for k, v in data.items(): + _log_stats(data=v, prefix=k) + else: + _log_stats(data=data) raise RuntimeError(f"applying transform {transform}") from e diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 41804d5c1d..f169002596 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -394,6 +394,7 @@ class DataStats(Transform): def __init__( self, prefix: str = "Data", + data_type: bool = True, data_shape: bool = True, value_range: bool = True, data_value: bool = False, @@ -403,6 +404,7 @@ def __init__( """ Args: prefix: will be printed in format: "{prefix} statistics". + data_type: whether to show the type of input data. data_shape: whether to show the shape of input data. value_range: whether to show the value range of input data. data_value: whether to show the raw value of input data. @@ -419,6 +421,7 @@ def __init__( if not isinstance(prefix, str): raise AssertionError("prefix must be a string.") self.prefix = prefix + self.data_type = data_type self.data_shape = data_shape self.value_range = value_range self.data_value = data_value @@ -438,6 +441,7 @@ def __call__( self, img: NdarrayTensor, prefix: Optional[str] = None, + data_type: Optional[bool] = None, data_shape: Optional[bool] = None, value_range: Optional[bool] = None, data_value: Optional[bool] = None, @@ -448,6 +452,8 @@ def __call__( """ lines = [f"{prefix or self.prefix} statistics:"] + if self.data_type if data_type is None else data_type: + lines.append(f"Type: {type(img)}") if self.data_shape if data_shape is None else data_shape: lines.append(f"Shape: {img.shape}") if self.value_range if value_range is None else value_range: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index a05a5fc904..324835a874 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -507,6 +507,7 @@ def __init__( self, keys: KeysCollection, prefix: Union[Sequence[str], str] = "Data", + data_type: Union[Sequence[bool], bool] = True, data_shape: Union[Sequence[bool], bool] = True, value_range: Union[Sequence[bool], bool] = True, data_value: Union[Sequence[bool], bool] = False, @@ -520,6 +521,8 @@ def __init__( See also: :py:class:`monai.transforms.compose.MapTransform` prefix: will be printed in format: "{prefix} statistics". it also can be a sequence of string, each element corresponds to a key in ``keys``. + data_type: whether to show the type of input data. + it also can be a sequence of bool, each element corresponds to a key in ``keys``. data_shape: whether to show the shape of input data. it also can be a sequence of bool, each element corresponds to a key in ``keys``. value_range: whether to show the value range of input data. @@ -538,6 +541,7 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.prefix = ensure_tuple_rep(prefix, len(self.keys)) + self.data_type = ensure_tuple_rep(data_type, len(self.keys)) self.data_shape = ensure_tuple_rep(data_shape, len(self.keys)) self.value_range = ensure_tuple_rep(value_range, len(self.keys)) self.data_value = ensure_tuple_rep(data_value, len(self.keys)) @@ -547,12 +551,13 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: d = dict(data) - for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator( - d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info + for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator( + d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info ): d[key] = self.printer( d[key], prefix, + data_type, data_shape, value_range, data_value, diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 877da52263..073620ad43 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -23,6 +23,7 @@ TEST_CASE_1 = [ { "prefix": "test data", + "data_type": False, "data_shape": False, "value_range": False, "data_value": False, @@ -36,58 +37,80 @@ TEST_CASE_2 = [ { "prefix": "test data", - "data_shape": True, + "data_type": True, + "data_shape": False, "value_range": False, "data_value": False, "additional_info": None, "logger_handler": None, }, np.array([[0, 1], [1, 2]]), - "test data statistics:\nShape: (2, 2)", + "test data statistics:\nType: ", ] TEST_CASE_3 = [ { "prefix": "test data", + "data_type": True, "data_shape": True, - "value_range": True, + "value_range": False, "data_value": False, "additional_info": None, "logger_handler": None, }, np.array([[0, 1], [1, 2]]), - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)", + "test data statistics:\nType: \nShape: (2, 2)", ] TEST_CASE_4 = [ { "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, - "data_value": True, + "data_value": False, "additional_info": None, "logger_handler": None, }, np.array([[0, 1], [1, 2]]), - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", ] TEST_CASE_5 = [ { "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, "data_value": True, - "additional_info": np.mean, + "additional_info": None, "logger_handler": None, }, np.array([[0, 1], [1, 2]]), - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", ] TEST_CASE_6 = [ { "prefix": "test data", + "data_type": True, + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": np.mean, + "logger_handler": None, + }, + np.array([[0, 1], [1, 2]]), + ( + "test data statistics:\nType: \nShape: (2, 2)\n" + "Value range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0" + ), +] + +TEST_CASE_7 = [ + { + "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, "data_value": True, @@ -96,25 +119,26 @@ }, torch.tensor([[0, 1], [1, 2]]), ( - "test data statistics:\nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" ), ] -TEST_CASE_7 = [ +TEST_CASE_8 = [ np.array([[0, 1], [1, 2]]), - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] class TestDataStats(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, input_param, input_data, expected_print): transform = DataStats(**input_param) _ = transform(input_data) self.assertEqual(transform.output, expected_print) - @parameterized.expand([TEST_CASE_7]) + @parameterized.expand([TEST_CASE_8]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_data_stats.log") @@ -122,6 +146,7 @@ def test_file(self, input_data, expected_print): handler.setLevel(logging.INFO) input_param = { "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, "data_value": True, diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index bacd70194a..7ac346b275 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -24,10 +24,12 @@ { "keys": "img", "prefix": "test data", + "data_type": False, "data_shape": False, "value_range": False, "data_value": False, "additional_info": None, + "logger_handler": None, }, {"img": np.array([[0, 1], [1, 2]])}, "test data statistics:", @@ -37,97 +39,138 @@ { "keys": "img", "prefix": "test data", - "data_shape": True, + "data_type": True, + "data_shape": False, "value_range": False, "data_value": False, "additional_info": None, + "logger_handler": None, }, {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nShape: (2, 2)", + "test data statistics:\nType: ", ] TEST_CASE_3 = [ { "keys": "img", "prefix": "test data", + "data_type": True, "data_shape": True, - "value_range": True, + "value_range": False, "data_value": False, "additional_info": None, + "logger_handler": None, }, {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)", + "test data statistics:\nType: \nShape: (2, 2)", ] TEST_CASE_4 = [ { "keys": "img", "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, - "data_value": True, + "data_value": False, "additional_info": None, + "logger_handler": None, }, {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)", ] TEST_CASE_5 = [ { "keys": "img", "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, "data_value": True, - "additional_info": np.mean, + "additional_info": None, + "logger_handler": None, }, {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]", ] TEST_CASE_6 = [ { "keys": "img", "prefix": "test data", + "data_type": True, + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": np.mean, + "logger_handler": None, + }, + {"img": np.array([[0, 1], [1, 2]])}, + ( + "test data statistics:\nType: \nShape: (2, 2)\n" + "Value range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0" + ), +] + +TEST_CASE_7 = [ + { + "keys": "img", + "prefix": "test data", + "data_type": True, "data_shape": True, "value_range": True, "data_value": True, "additional_info": lambda x: torch.mean(x.float()), + "logger_handler": None, }, {"img": torch.tensor([[0, 1], [1, 2]])}, ( - "test data statistics:\nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" + "test data statistics:\nType: \nShape: torch.Size([2, 2])\nValue range: (0, 2)\n" "Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0" ), ] -TEST_CASE_7 = [ +TEST_CASE_8 = [ { "keys": ("img", "affine"), "prefix": ("image", "affine"), + "data_type": True, "data_shape": True, "value_range": (True, False), "data_value": (False, True), "additional_info": (np.mean, None), }, {"img": np.array([[0, 1], [1, 2]]), "affine": np.eye(2, 2)}, - "affine statistics:\nShape: (2, 2)\nValue: [[1. 0.]\n [0. 1.]]", + "affine statistics:\nType: \nShape: (2, 2)\nValue: [[1. 0.]\n [0. 1.]]", ] -TEST_CASE_8 = [ +TEST_CASE_9 = [ {"img": np.array([[0, 1], [1, 2]])}, - "test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", + "test data statistics:\nType: \nShape: (2, 2)\nValue range: (0, 2)\n" + "Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n", ] class TestDataStatsd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand( + [ + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + ] + ) def test_value(self, input_param, input_data, expected_print): transform = DataStatsd(**input_param) _ = transform(input_data) self.assertEqual(transform.printer.output, expected_print) - @parameterized.expand([TEST_CASE_8]) + @parameterized.expand([TEST_CASE_9]) def test_file(self, input_data, expected_print): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, "test_stats.log")