Skip to content
23 changes: 23 additions & 0 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple

import numpy as np
import torch

from monai import transforms
from monai.config import KeysCollection
from monai.utils import MAX_SEED, ensure_tuple

Expand Down Expand Up @@ -45,6 +47,27 @@ def apply_transform(transform: Callable, data, map_items: bool = True):
return [transform(item) for item in data]
return transform(data)
except Exception as e:

if not isinstance(transform, transforms.compose.Compose):
# log the input data information of exact transform in the transform chain
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
datastats._logger.info("input data information of the runtime error transform:")
if isinstance(data, (list, tuple)):
data = data[0]

def _log_stats(data, prefix: Optional[str] = "Data"):
    """
    Log a summary of ``data`` through the enclosing ``datastats`` transform.

    Args:
        data: the object to describe; NumPy arrays and torch tensors get an
            array-style report, anything else is treated as meta data.
        prefix: label forwarded to ``datastats`` for the report header.
    """
    if not isinstance(data, (np.ndarray, torch.Tensor)):
        # meta data: report the type and the raw value
        datastats(img=data, data_value=True, prefix=prefix)
        return
    # array-like input: report the type, shape and value range
    datastats(img=data, data_shape=True, value_range=True, prefix=prefix)  # type: ignore

if isinstance(data, dict):
for k, v in data.items():
_log_stats(data=v, prefix=k)
else:
_log_stats(data=data)
raise RuntimeError(f"applying transform {transform}") from e


Expand Down
6 changes: 6 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ class DataStats(Transform):
def __init__(
self,
prefix: str = "Data",
data_type: bool = True,
data_shape: bool = True,
value_range: bool = True,
data_value: bool = False,
Expand All @@ -403,6 +404,7 @@ def __init__(
"""
Args:
prefix: will be printed in format: "{prefix} statistics".
data_type: whether to show the type of input data.
data_shape: whether to show the shape of input data.
value_range: whether to show the value range of input data.
data_value: whether to show the raw value of input data.
Expand All @@ -419,6 +421,7 @@ def __init__(
if not isinstance(prefix, str):
raise AssertionError("prefix must be a string.")
self.prefix = prefix
self.data_type = data_type
self.data_shape = data_shape
self.value_range = value_range
self.data_value = data_value
Expand All @@ -438,6 +441,7 @@ def __call__(
self,
img: NdarrayTensor,
prefix: Optional[str] = None,
data_type: Optional[bool] = None,
data_shape: Optional[bool] = None,
value_range: Optional[bool] = None,
data_value: Optional[bool] = None,
Expand All @@ -448,6 +452,8 @@ def __call__(
"""
lines = [f"{prefix or self.prefix} statistics:"]

if self.data_type if data_type is None else data_type:
lines.append(f"Type: {type(img)}")
if self.data_shape if data_shape is None else data_shape:
lines.append(f"Shape: {img.shape}")
if self.value_range if value_range is None else value_range:
Expand Down
9 changes: 7 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def __init__(
self,
keys: KeysCollection,
prefix: Union[Sequence[str], str] = "Data",
data_type: Union[Sequence[bool], bool] = True,
data_shape: Union[Sequence[bool], bool] = True,
value_range: Union[Sequence[bool], bool] = True,
data_value: Union[Sequence[bool], bool] = False,
Expand All @@ -520,6 +521,8 @@ def __init__(
See also: :py:class:`monai.transforms.compose.MapTransform`
prefix: will be printed in format: "{prefix} statistics".
it also can be a sequence of string, each element corresponds to a key in ``keys``.
data_type: whether to show the type of input data.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
data_shape: whether to show the shape of input data.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
value_range: whether to show the value range of input data.
Expand All @@ -538,6 +541,7 @@ def __init__(
"""
super().__init__(keys, allow_missing_keys)
self.prefix = ensure_tuple_rep(prefix, len(self.keys))
self.data_type = ensure_tuple_rep(data_type, len(self.keys))
self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
self.value_range = ensure_tuple_rep(value_range, len(self.keys))
self.data_value = ensure_tuple_rep(data_value, len(self.keys))
Expand All @@ -547,12 +551,13 @@ def __init__(

def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
d = dict(data)
for key, prefix, data_shape, value_range, data_value, additional_info in self.key_iterator(
d, self.prefix, self.data_shape, self.value_range, self.data_value, self.additional_info
for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator(
d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info
):
d[key] = self.printer(
d[key],
prefix,
data_type,
data_shape,
value_range,
data_value,
Expand Down
51 changes: 38 additions & 13 deletions tests/test_data_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TEST_CASE_1 = [
{
"prefix": "test data",
"data_type": False,
"data_shape": False,
"value_range": False,
"data_value": False,
Expand All @@ -36,58 +37,80 @@
TEST_CASE_2 = [
{
"prefix": "test data",
"data_shape": True,
"data_type": True,
"data_shape": False,
"value_range": False,
"data_value": False,
"additional_info": None,
"logger_handler": None,
},
np.array([[0, 1], [1, 2]]),
"test data statistics:\nShape: (2, 2)",
"test data statistics:\nType: <class 'numpy.ndarray'>",
]

TEST_CASE_3 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"value_range": False,
"data_value": False,
"additional_info": None,
"logger_handler": None,
},
np.array([[0, 1], [1, 2]]),
"test data statistics:\nShape: (2, 2)\nValue range: (0, 2)",
"test data statistics:\nType: <class 'numpy.ndarray'>\nShape: (2, 2)",
]

TEST_CASE_4 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
"data_value": False,
"additional_info": None,
"logger_handler": None,
},
np.array([[0, 1], [1, 2]]),
"test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]",
"test data statistics:\nType: <class 'numpy.ndarray'>\nShape: (2, 2)\nValue range: (0, 2)",
]

TEST_CASE_5 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
"additional_info": np.mean,
"additional_info": None,
"logger_handler": None,
},
np.array([[0, 1], [1, 2]]),
"test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0",
"test data statistics:\nType: <class 'numpy.ndarray'>\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]",
]

# Every statistics flag is enabled and ``np.mean`` is supplied as ``additional_info`` for a
# 2x2 NumPy array; the expected text lists type, shape, value range, raw value and the
# additional-info result, in that order.
TEST_CASE_6 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
"additional_info": np.mean,
"logger_handler": None,
},
np.array([[0, 1], [1, 2]]),
(
"test data statistics:\nType: <class 'numpy.ndarray'>\nShape: (2, 2)\n"
"Value range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0"
),
]

TEST_CASE_7 = [
{
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
Expand All @@ -96,32 +119,34 @@
},
torch.tensor([[0, 1], [1, 2]]),
(
"test data statistics:\nShape: torch.Size([2, 2])\nValue range: (0, 2)\n"
"test data statistics:\nType: <class 'torch.Tensor'>\nShape: torch.Size([2, 2])\nValue range: (0, 2)\n"
"Value: tensor([[0, 1],\n [1, 2]])\nAdditional info: 1.0"
),
]

TEST_CASE_7 = [
TEST_CASE_8 = [
np.array([[0, 1], [1, 2]]),
"test data statistics:\nShape: (2, 2)\nValue range: (0, 2)\nValue: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
"test data statistics:\nType: <class 'numpy.ndarray'>\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
]


class TestDataStats(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_value(self, input_param, input_data, expected_print):
# Build a DataStats transform from the parameterized kwargs, apply it to the input,
# and compare the text stored on ``transform.output`` with the expected string.
transform = DataStats(**input_param)
_ = transform(input_data)
self.assertEqual(transform.output, expected_print)

@parameterized.expand([TEST_CASE_7])
@parameterized.expand([TEST_CASE_8])
def test_file(self, input_data, expected_print):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_data_stats.log")
handler = logging.FileHandler(filename, mode="w")
handler.setLevel(logging.INFO)
input_param = {
"prefix": "test data",
"data_type": True,
"data_shape": True,
"value_range": True,
"data_value": True,
Expand Down
Loading