Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 160 additions & 58 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Hashable, Mapping, Optional, Tuple
import warnings
from contextlib import contextmanager
from typing import Any, Hashable, Mapping, Optional, Tuple

import torch

Expand All @@ -22,18 +25,31 @@

class TraceableTransform(Transform):
"""
Maintains a stack of applied transforms. The stack is inserted as pairs of
`trace_key: list of transforms` to each data dictionary.
Maintains a stack of applied transforms to data.

Data can be one of two types:
1. A `MetaTensor` (this is the preferred data type).
2. A dictionary of data containing arrays/tensors and auxiliary metadata. In
this case, a key must be supplied (this dictionary-based approach is deprecated).

If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``.

If `data` is a dictionary, then one of two things can happen:
1. If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``.
2. Else, the applied transform will be appended to an adjacent list using
`trace_key`. If, for example, the key is `image`, then the transform
will be appended to `image_transforms` (this dictionary-based approach is deprecated).

Hopefully it is clear that there are three total possibilities:
1. data is `MetaTensor`
2. data is dictionary, data[key] is `MetaTensor`
3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach).

The ``__call__`` method of this transform class must be implemented so
that the transformation information for each key is stored when
``__call__`` is called. If the transforms were applied to keys "image" and
"label", there will be two extra keys in the dictionary: "image_transforms"
and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list
contains a list of the transforms applied to that key.
that the transformation information is stored during the data transformation.

The information in ``data[key_transform]`` will be compatible with the
default collate since it only stores strings, numbers and arrays.
The information in the stack of applied transforms must be compatible with the
default collate, by only storing strings, numbers and arrays.

`tracing` could be enabled by `self.set_tracing` or setting
`MONAI_TRACE_TRANSFORM` when initializing the class.
Expand All @@ -49,44 +65,152 @@ def set_tracing(self, tracing: bool) -> None:
def trace_key(key: Hashable = None):
"""The key to store the stack of applied transforms."""
if key is None:
return TraceKeys.KEY_SUFFIX
return str(key) + TraceKeys.KEY_SUFFIX
return f"{TraceKeys.KEY_SUFFIX}"
return f"{key}{TraceKeys.KEY_SUFFIX}"

def push_transform(
self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
) -> None:
"""Push to a stack of applied transforms for that key."""
def get_transform_info(
self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.

if not self.tracing:
return
Args:
data: input data. Can be dictionary or MetaTensor. We can use `shape` to
determine the original size of the object (unless that has been given
explicitly, see `orig_size`).
key: if data is a dictionary, data[key] will be modified.
extra_info: if desired, any extra information pertaining to the applied
transform can be stored in this dictionary. These are often needed for
computing the inverse transformation.
orig_size: sometimes during the inverse it is useful to know what the size
of the original image was, in which case it can be supplied here.

Returns:
Dictionary of data pertaining to the applied transformation.
"""
info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
if orig_size is not None:
info[TraceKeys.ORIG_SIZE] = orig_size
elif key in data and hasattr(data[key], "shape"):
elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"):
info[TraceKeys.ORIG_SIZE] = data[key].shape[1:]
elif hasattr(data, "shape"):
info[TraceKeys.ORIG_SIZE] = data.shape[1:]
if extra_info is not None:
info[TraceKeys.EXTRA_INFO] = extra_info
# If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
if hasattr(self, "_do_transform"): # RandomizableTransform
info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore
return info

if key in data and isinstance(data[key], MetaTensor):
data[key].push_applied_operation(info)
else:
# If this is the first, create list
if self.trace_key(key) not in data:
if not isinstance(data, dict):
data = dict(data)
data[self.trace_key(key)] = []
data[self.trace_key(key)].append(info)

def pop_transform(self, data: Mapping, key: Hashable = None):
"""Remove the most recent applied transform."""
def push_transform(
self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
) -> None:
"""
Push to a stack of applied transforms.

Args:
data: dictionary of data or `MetaTensor`.
key: if data is a dictionary, data[key] will be modified.
extra_info: if desired, any extra information pertaining to the applied
transform can be stored in this dictionary. These are often needed for
computing the inverse transformation.
orig_size: sometimes during the inverse it is useful to know what the size
of the original image was, in which case it can be supplied here.

Returns:
None, but data has been updated to store the applied transformation.
"""
if not self.tracing:
return
if key in data and isinstance(data[key], MetaTensor):
return data[key].pop_applied_operation()
return data.get(self.trace_key(key), []).pop()
info = self.get_transform_info(data, key, extra_info, orig_size)

if isinstance(data, MetaTensor):
data.push_applied_operation(info)
elif isinstance(data, Mapping):
if key in data and isinstance(data[key], MetaTensor):
data[key].push_applied_operation(info)
else:
# If this is the first, create list
if self.trace_key(key) not in data:
if not isinstance(data, dict):
data = dict(data)
data[self.trace_key(key)] = []
data[self.trace_key(key)].append(info)
else:
warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")

def check_transforms_match(self, transform: Mapping) -> None:
"""Check transforms are of same instance."""
xform_id = transform.get(TraceKeys.ID, "")
if xform_id == id(self):
return
# TraceKeys.NONE to skip the id check
if xform_id == TraceKeys.NONE:
return
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
return
raise RuntimeError(
f"Error {self.__class__.__name__} getting the most recently "
f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
)

def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
"""
Get most recent transform for the stack.

Args:
data: dictionary of data or `MetaTensor`.
key: if data is a dictionary, data[key] will be modified.
check: if true, check that `self` is the same type as the most recently-applied transform.
pop: if true, remove the transform as it is returned.

Returns:
Dictionary of most recently applied transform

Raises:
- RuntimeError: data is neither `MetaTensor` nor dictionary
"""
if not self.tracing:
raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
if isinstance(data, MetaTensor):
all_transforms = data.applied_operations
elif isinstance(data, Mapping):
if key in data and isinstance(data[key], MetaTensor):
all_transforms = data[key].applied_operations
else:
all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
else:
raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
if check:
self.check_transforms_match(all_transforms[-1])
return all_transforms.pop() if pop else all_transforms[-1]

def pop_transform(self, data, key: Hashable = None, check: bool = True):
"""
Return and pop the most recent transform.

Args:
data: dictionary of data or `MetaTensor`
key: if data is a dictionary, data[key] will be modified
check: if true, check that `self` is the same type as the most recently-applied transform.

Returns:
Dictionary of most recently applied transform

Raises:
- RuntimeError: data is neither `MetaTensor` nor dictionary
"""
return self.get_most_recent_transform(data, key, check, pop=True)

@contextmanager
def trace_transform(self, to_trace: bool):
"""Temporarily set the tracing status of a transform with a context manager."""
prev = self.tracing
self.tracing = to_trace
yield
self.tracing = prev


class InvertibleTransform(TraceableTransform):
Expand All @@ -103,7 +227,7 @@ class InvertibleTransform(TraceableTransform):
different parameters being passed to each label (e.g., different
interpolation for image and label).

- the inverse transforms are applied in a last- in-first-out order. As
- the inverse transforms are applied in a last-in-first-out order. As
the inverse is applied, its entry is removed from the list detailing
the applied transformations. That is to say that during the forward
pass, the list of applied transforms grows, and then during the
Expand All @@ -126,29 +250,7 @@ class InvertibleTransform(TraceableTransform):

"""

def check_transforms_match(self, transform: Mapping) -> None:
"""Check transforms are of same instance."""
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
xform_id = transform.get(TraceKeys.ID, "")
if xform_id == id(self):
return
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
return
raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.")

def get_most_recent_transform(self, data: Mapping, key: Hashable = None):
"""Get most recent transform."""
if not self.tracing:
raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
if isinstance(data[key], MetaTensor):
transform = data[key].applied_operations[-1]
else:
transform = data[self.trace_key(key)][-1]
self.check_transforms_match(transform)
return transform

def inverse(self, data: dict) -> dict:
def inverse(self, data: Any) -> Any:
"""
Inverse of ``__call__``.

Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
Expand Down Expand Up @@ -348,7 +348,7 @@ def __call__(self, data):
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator:
def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator:
"""
Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
`allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
Expand Down
47 changes: 24 additions & 23 deletions runtests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,30 @@ then
fi


if [ $doPrecommit = true ]
then
set +e # disable exit on failure so that diagnostics can be given on failure
echo "${separator}${blue}pre-commit${noColor}"

# ensure that the necessary packages for code format testing are installed
if ! is_pip_installed pre_commit
then
install_deps
fi
${cmdPrefix}${PY_EXE} -m pre_commit run --all-files

pre_commit_status=$?
if [ ${pre_commit_status} -ne 0 ]
then
print_style_fail_msg
exit ${pre_commit_status}
else
echo "${green}passed!${noColor}"
fi
set -e # enable exit on failure
fi


if [ $doIsortFormat = true ]
then
set +e # disable exit on failure so that diagnostics can be given on failure
Expand Down Expand Up @@ -500,29 +524,6 @@ then
set -e # enable exit on failure
fi

if [ $doPrecommit = true ]
then
set +e # disable exit on failure so that diagnostics can be given on failure
echo "${separator}${blue}pre-commit${noColor}"

# ensure that the necessary packages for code format testing are installed
if ! is_pip_installed pre_commit
then
install_deps
fi
${cmdPrefix}${PY_EXE} -m pre_commit run --all-files

pre_commit_status=$?
if [ ${pre_commit_status} -ne 0 ]
then
print_style_fail_msg
exit ${pre_commit_status}
else
echo "${green}passed!${noColor}"
fi
set -e # enable exit on failure
fi

if [ $doPylintFormat = true ]
then
set +e # disable exit on failure so that diagnostics can be given on failure
Expand Down