Skip to content
6 changes: 4 additions & 2 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,14 @@ N-Dim Fourier Transform
Meta Object
-----------
.. automodule:: monai.data.meta_obj
:members:
:members:

MetaTensor
----------
.. autoclass:: monai.data.MetaTensor
:members:
:members:
:show-inheritance:
:inherited-members: MetaObj



Expand Down
6 changes: 3 additions & 3 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class MetaObj:

* For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
first instance of `MetaObj` if `a.is_batch` is False
(For batched data, the metdata will be shallow copied for efficiency purposes).
(For batched data, the metadata will be shallow copied for efficiency purposes).

"""

Expand Down Expand Up @@ -185,7 +185,7 @@ def __repr__(self) -> str:

@property
def meta(self) -> dict:
"""Get the meta."""
"""Get the meta. Defaults to ``{}``."""
return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta()

@meta.setter
Expand All @@ -197,7 +197,7 @@ def meta(self, d) -> None:

@property
def applied_operations(self) -> list[dict]:
"""Get the applied operations."""
"""Get the applied operations. Defaults to ``[]``."""
if hasattr(self, "_applied_operations"):
return self._applied_operations
return MetaObj.get_default_applied_operations()
Expand Down
109 changes: 91 additions & 18 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,21 @@ def __init__(
**_kwargs,
) -> None:
"""
If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it.
Else, use the default value. Similar for the affine, except this could come from
four places.
Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
Args:
x: initial array for the MetaTensor. Can be a list, tuple, NumPy ndarray, scalar, and other types.
affine: optional 4x4 array.
meta: dictionary of metadata.
applied_operations: list of previously applied operations on the MetaTensor,
the list is typically maintained by `monai.transforms.TraceableTransform`.
See also: :py:class:`monai.transforms.TraceableTransform`
_args: additional args (currently not in use in this constructor).
_kwargs: additional kwargs (currently not in use in this constructor).

Note:
If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.
Else, use the default value. Similar for the affine, except this could come from
four places, priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.

"""
super().__init__()
# set meta
Expand Down Expand Up @@ -177,7 +188,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
the input type was not `MetaTensor`, then no modifications will have been
made. If global parameters have been set to false (e.g.,
`not get_track_meta()`), then any `MetaTensor` will be converted to
`torch.Tensor`. Else, metadata will be propogated as necessary (see
`torch.Tensor`. Else, metadata will be propagated as necessary (see
:py:func:`MetaTensor._copy_meta`).
"""
out = []
Expand Down Expand Up @@ -328,34 +339,88 @@ def as_tensor(self) -> torch.Tensor:
"""
return self.as_subclass(torch.Tensor) # type: ignore

def as_dict(self, key: str) -> dict:
def get_array(self, output_type=np.ndarray, dtype=None, *_args, **_kwargs):
Comment thread
wyli marked this conversation as resolved.
"""
Returns a new array in `output_type`, the array shares the same underlying storage when the output is a
numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.

Args:
output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.
_args: currently unused parameters.
_kwargs: currently unused parameters.
"""
return convert_data_type(self, output_type=output_type, dtype=dtype, wrap_sequence=True)[0]

def set_array(self, src, non_blocking=False, *_args, **_kwargs):
"""
Copies the elements from src into self tensor and returns self.
The src tensor must be broadcastable with the self tensor.
It may be of a different data type or reside on a different device.

See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`

Args:
src: the source tensor to copy from.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur
asynchronously with respect to the host. For other cases, this argument has no effect.
_args: currently unused parameters.
_kwargs: currently unused parameters.
"""
src: torch.Tensor = convert_to_tensor(src, track_meta=False, wrap_sequence=True)
return self.copy_(src, non_blocking=non_blocking)

@property
def array(self):
"""
Returns a numpy array of ``self``. The array and ``self`` shares the same underlying storage if self is on cpu.
Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.
If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.

:getter: see also: :py:func:`MetaTensor.get_array()`
:setter: see also: :py:func:`MetaTensor.set_array()`
"""
return self.get_array()

@array.setter
def array(self, src) -> None:
"""A default setter using ``self.set_array()``"""
self.set_array(src)

def as_dict(self, key: str, output_type=torch.Tensor, dtype=None) -> dict:
"""
Get the object as a dictionary for backwards compatibility.
This method makes a copy of the objects.
This method does not make a deep copy of the objects.

Args:
key: Base key to store main data. The key for the metadata will be
determined using `PostFix.meta`.
key: Base key to store main data. The key for the metadata will be determined using `PostFix`.
output_type: `torch.Tensor` or `np.ndarray` for the main data.
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.

Return:
A dictionary consisting of two keys, the main data (stored under `key`) and
the metadata.
A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.
"""
if output_type not in (torch.Tensor, np.ndarray):
raise ValueError(f"output_type must be torch.Tensor or np.ndarray, got {output_type}.")
return {
key: self.as_tensor().clone().detach(),
PostFix.meta(key): deepcopy(self.meta),
PostFix.transforms(key): deepcopy(self.applied_operations),
key: self.get_array(output_type=output_type, dtype=dtype),
PostFix.meta(key): self.meta,
PostFix.transforms(key): self.applied_operations,
}

def astype(self, dtype, device=None, *unused_args, **unused_kwargs):
def astype(self, dtype, device=None, *_args, **_kwargs):
"""
Cast to ``dtype``, sharing data whenever possible.

Args:
dtype: dtypes such as np.float32, torch.float, "np.float32", float.
device: the device if `dtype` is a torch data type.
unused_args: additional args (currently unused).
unused_kwargs: additional kwargs (currently unused).
_args: additional args (currently unused).
_kwargs: additional kwargs (currently unused).

Returns:
data array instance
Expand All @@ -376,7 +441,7 @@ def astype(self, dtype, device=None, *unused_args, **unused_kwargs):

@property
def affine(self) -> torch.Tensor:
"""Get the affine."""
"""Get the affine. Defaults to ``torch.eye(4, dtype=torch.float64)``"""
return self.meta.get("affine", self.get_default_affine())

@affine.setter
Expand All @@ -400,6 +465,13 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False):
self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
)

def clone(self):
if self.data_ptr() == 0:
new_inst = MetaTensor(self.as_tensor().clone())
new_inst.__dict__ = deepcopy(self.__dict__)
return new_inst
return super().clone()

@staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
"""
Expand All @@ -409,6 +481,7 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool
Args:
im: Input image (`np.ndarray` or `torch.Tensor`)
meta: Metadata dictionary.
simple_keys: whether to keep only a simple subset of metadata keys.

Returns:
By default, a `MetaTensor` is returned.
Expand Down
25 changes: 21 additions & 4 deletions monai/transforms/meta_utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from typing import Dict, Hashable, Mapping
from typing import Dict, Hashable, Mapping, Sequence, Union

from monai.config.type_definitions import NdarrayOrTensor
import numpy as np
import torch

from monai.config.type_definitions import KeysCollection, NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform
from monai.utils.enums import PostFix, TransformBackends
from monai.utils.misc import ensure_tuple_rep

__all__ = [
"FromMetaTensord",
Expand All @@ -43,11 +47,24 @@ class FromMetaTensord(MapTransform, InvertibleTransform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, keys: KeysCollection, data_type: Union[Sequence[str], str] = "tensor", allow_missing_keys: bool = False
):
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
data_type: target data type to convert, should be "tensor" or "numpy".
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.as_tensor_output = tuple(d == "tensor" for d in ensure_tuple_rep(data_type, len(self.keys)))

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
for key, t in self.key_iterator(d, self.as_tensor_output):
im: MetaTensor = d[key] # type: ignore
d.update(im.as_dict(key))
d.update(im.as_dict(key, output_type=torch.Tensor if t else np.ndarray))
self.push_transform(d, key)
return d

Expand Down
6 changes: 4 additions & 2 deletions monai/visualize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def blend_images(
if image.shape[1:] != label.shape[1:]:
raise ValueError("image and label should have matching spatial sizes.")
if isinstance(alpha, (np.ndarray, torch.Tensor)):
if image.shape[1:] != alpha.shape[1:]: # type: ignore
if image.shape[1:] != alpha.shape[1:]: # pytype: disable=attribute-error,invalid-directive

raise ValueError("if alpha is image, size should match input image and label.")

# rescale arrays to [0, 1] if desired
Expand Down Expand Up @@ -220,6 +221,7 @@ def get_label_rgb(cmap: str, label: NdarrayOrTensor):
w_label = np.full_like(label, alpha)
if transparent_background:
# where label == 0 (background), set label alpha to 0
w_label[label == 0] = 0 # type: ignore
w_label[label == 0] = 0 # pytype: disable=unsupported-operands

w_image = 1 - w_label
return w_image * image + w_label * label_rgb
2 changes: 1 addition & 1 deletion tests/test_highresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_script(self):
input_param, input_shape, expected_shape = TEST_CASE_1
net = HighResNet(**input_param)
test_data = torch.randn(input_shape)
test_script_save(net, test_data)
test_script_save(net, test_data, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
Expand Down
19 changes: 19 additions & 0 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def test_copy(self, device, dtype):
# clone
a = m.clone()
self.check(a, m, ids=False)
a = MetaTensor([[]], device=device, dtype=dtype)
self.check(a, deepcopy(a), ids=False)

@parameterized.expand(TESTS)
def test_add(self, device, dtype):
Expand Down Expand Up @@ -536,6 +538,23 @@ def test_array_function(self, device="cpu", dtype=float):
c > torch.as_tensor([1.0, 1.0, 1.0], device=device), torch.as_tensor([False, True, True], device=device)
)

@parameterized.expand(TESTS)
def test_numpy(self, device=None, dtype=None):
"""device, dtype"""
t = MetaTensor([0.0], device=device, dtype=dtype)
self.assertIsInstance(t, MetaTensor)
assert_allclose(t.array, np.asarray([0.0]))
t.array = np.asarray([1.0])
self.check_meta(t, MetaTensor([1.0]))
assert_allclose(t.as_tensor(), torch.as_tensor([1.0]))
t.array = [2.0]
self.check_meta(t, MetaTensor([2.0]))
assert_allclose(t.as_tensor(), torch.as_tensor([2.0]))
if not t.is_cuda:
t.array[0] = torch.as_tensor(3.0, device=device, dtype=dtype)
self.check_meta(t, MetaTensor([3.0]))
assert_allclose(t.as_tensor(), torch.as_tensor([3.0]))


if __name__ == "__main__":
unittest.main()
17 changes: 12 additions & 5 deletions tests/test_to_from_meta_tensord.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from copy import deepcopy
from typing import Optional, Union

import numpy as np
import torch
from parameterized import parameterized

Expand All @@ -28,7 +29,8 @@
TESTS = []
for _device in TEST_DEVICES:
for _dtype in DTYPES:
TESTS.append((*_device, *_dtype))
for _data_type in ("tensor", "numpy"):
TESTS.append((*_device, *_dtype, _data_type))


def rand_string(min_len=5, max_len=10):
Expand Down Expand Up @@ -67,7 +69,7 @@ def check(
ids: bool = True,
device: Optional[Union[str, torch.device]] = None,
meta: bool = True,
check_ids: bool = True,
check_ids: bool = False,
**kwargs,
):
if device is None:
Expand Down Expand Up @@ -97,7 +99,7 @@ def check(
self.check_ids(out.meta, orig.meta, ids)

@parameterized.expand(TESTS)
def test_from_to_meta_tensord(self, device, dtype):
def test_from_to_meta_tensord(self, device, dtype, data_type="tensor"):
m1 = self.get_im(device=device, dtype=dtype)
m2 = self.get_im(device=device, dtype=dtype)
m3 = self.get_im(device=device, dtype=dtype)
Expand All @@ -106,7 +108,7 @@ def test_from_to_meta_tensord(self, device, dtype):
m1_aff = m1.affine

# FROM -> forward
t_from_meta = FromMetaTensord(["m1", "m2"])
t_from_meta = FromMetaTensord(["m1", "m2"], data_type=data_type)
d_dict = t_from_meta(d_metas)

self.assertEqual(
Expand All @@ -122,7 +124,10 @@ def test_from_to_meta_tensord(self, device, dtype):
],
)
self.check(d_dict["m3"], m3, ids=True) # unchanged
self.check(d_dict["m1"], m1.as_tensor(), ids=False)
if data_type == "tensor":
self.check(d_dict["m1"], m1.as_tensor(), ids=False)
else:
self.assertIsInstance(d_dict["m1"], np.ndarray)
meta_out = {k: v for k, v in d_dict["m1_meta_dict"].items() if k != "affine"}
aff_out = d_dict["m1_meta_dict"]["affine"]
self.check(aff_out, m1_aff, ids=False)
Expand All @@ -131,6 +136,8 @@ def test_from_to_meta_tensord(self, device, dtype):
# FROM -> inverse
d_meta_dict_meta = t_from_meta.inverse(d_dict)
self.assertEqual(sorted(d_meta_dict_meta.keys()), ["m1", "m2", "m3"])
if data_type == "numpy":
m1, m1_aff = m1.cpu(), m1_aff.cpu()
self.check(d_meta_dict_meta["m1"], m1, ids=False)
meta_out = {k: v for k, v in d_meta_dict_meta["m1"].meta.items() if k != "affine"}
aff_out = d_meta_dict_meta["m1"].affine
Expand Down