diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py
index bdaa2f9b40..41a02989db 100644
--- a/monai/transforms/inverse.py
+++ b/monai/transforms/inverse.py
@@ -8,8 +8,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
-from typing import Hashable, Mapping, Optional, Tuple
+import warnings
+from contextlib import contextmanager
+from typing import Any, Hashable, Mapping, Optional, Tuple
 
 import torch
 
@@ -22,18 +25,31 @@
 
 class TraceableTransform(Transform):
     """
-    Maintains a stack of applied transforms. The stack is inserted as pairs of
-    `trace_key: list of transforms` to each data dictionary.
+    Maintains a stack of applied transforms to data.
+
+    Data can be one of two types:
+        1. A `MetaTensor` (this is the preferred data type).
+        2. A dictionary of data containing arrays/tensors and auxiliary metadata. In
+            this case, a key must be supplied (this dictionary-based approach is deprecated).
+
+    If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``.
+
+    If `data` is a dictionary, then one of two things can happen:
+        1. If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``.
+        2. Else, the applied transform will be appended to an adjacent list using
+            `trace_key`. If, for example, the key is `image`, then the transform
+            will be appended to `image_transforms` (this dictionary-based approach is deprecated).
+
+    Hopefully it is clear that there are three total possibilities:
+        1. data is `MetaTensor`
+        2. data is dictionary, data[key] is `MetaTensor`
+        3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach).
 
     The ``__call__`` method of this transform class must be implemented so
-    that the transformation information for each key is stored when
-    ``__call__`` is called. If the transforms were applied to keys "image" and
-    "label", there will be two extra keys in the dictionary: "image_transforms"
-    and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list
-    contains a list of the transforms applied to that key.
+    that the transformation information is stored during the data transformation.
 
-    The information in ``data[key_transform]`` will be compatible with the
-    default collate since it only stores strings, numbers and arrays.
+    The information in the stack of applied transforms must be compatible with the
+    default collate, by only storing strings, numbers and arrays.
 
     `tracing` could be enabled by `self.set_tracing` or setting
     `MONAI_TRACE_TRANSFORM` when initializing the class.
@@ -49,44 +65,159 @@ def set_tracing(self, tracing: bool) -> None:
     def trace_key(key: Hashable = None):
         """The key to store the stack of applied transforms."""
         if key is None:
-            return TraceKeys.KEY_SUFFIX
-        return str(key) + TraceKeys.KEY_SUFFIX
+            return f"{TraceKeys.KEY_SUFFIX}"
+        return f"{key}{TraceKeys.KEY_SUFFIX}"
 
-    def push_transform(
-        self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
-    ) -> None:
-        """Push to a stack of applied transforms for that key."""
+    def get_transform_info(
+        self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
+    ) -> dict:
+        """
+        Return a dictionary with the relevant information pertaining to an applied transform.
 
-        if not self.tracing:
-            return
+        Args:
+            data: input data. Can be dictionary or MetaTensor. We can use `shape` to
+                determine the original size of the object (unless that has been given
+                explicitly, see `orig_size`).
+            key: if data is a dictionary, data[key] will be modified.
+            extra_info: if desired, any extra information pertaining to the applied
+                transform can be stored in this dictionary. These are often needed for
+                computing the inverse transformation.
+            orig_size: sometimes during the inverse it is useful to know what the size
+                of the original image was, in which case it can be supplied here.
+
+        Returns:
+            Dictionary of data pertaining to the applied transformation.
+        """
         info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)}
         if orig_size is not None:
             info[TraceKeys.ORIG_SIZE] = orig_size
-        elif key in data and hasattr(data[key], "shape"):
+        elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"):
             info[TraceKeys.ORIG_SIZE] = data[key].shape[1:]
+        elif hasattr(data, "shape"):
+            info[TraceKeys.ORIG_SIZE] = data.shape[1:]
         if extra_info is not None:
             info[TraceKeys.EXTRA_INFO] = extra_info
         # If class is randomizable transform, store whether the transform was actually performed (based on `prob`)
         if hasattr(self, "_do_transform"):  # RandomizableTransform
             info[TraceKeys.DO_TRANSFORM] = self._do_transform  # type: ignore
+        return info
 
-        if key in data and isinstance(data[key], MetaTensor):
-            data[key].push_applied_operation(info)
-        else:
-            # If this is the first, create list
-            if self.trace_key(key) not in data:
-                if not isinstance(data, dict):
-                    data = dict(data)
-                data[self.trace_key(key)] = []
-            data[self.trace_key(key)].append(info)
-
-    def pop_transform(self, data: Mapping, key: Hashable = None):
-        """Remove the most recent applied transform."""
+    def push_transform(
+        self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None
+    ) -> None:
+        """
+        Push to a stack of applied transforms.
+
+        Args:
+            data: dictionary of data or `MetaTensor`.
+            key: if data is a dictionary, data[key] will be modified.
+            extra_info: if desired, any extra information pertaining to the applied
+                transform can be stored in this dictionary. These are often needed for
+                computing the inverse transformation.
+            orig_size: sometimes during the inverse it is useful to know what the size
+                of the original image was, in which case it can be supplied here.
+
+        Returns:
+            None, but data has been updated to store the applied transformation.
+        """
         if not self.tracing:
             return
-        if key in data and isinstance(data[key], MetaTensor):
-            return data[key].pop_applied_operation()
-        return data.get(self.trace_key(key), []).pop()
+        info = self.get_transform_info(data, key, extra_info, orig_size)
+
+        if isinstance(data, MetaTensor):
+            data.push_applied_operation(info)
+        elif isinstance(data, Mapping):
+            if key in data and isinstance(data[key], MetaTensor):
+                data[key].push_applied_operation(info)
+            else:
+                # If this is the first, create list
+                if self.trace_key(key) not in data:
+                    if not isinstance(data, dict):
+                        data = dict(data)
+                    data[self.trace_key(key)] = []
+                data[self.trace_key(key)].append(info)
+        else:
+            warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.")
+
+    def check_transforms_match(self, transform: Mapping) -> None:
+        """Check transforms are of same instance."""
+        xform_id = transform.get(TraceKeys.ID, "")
+        if xform_id == id(self):
+            return
+        # TraceKeys.NONE to skip the id check
+        if xform_id == TraceKeys.NONE:
+            return
+        xform_name = transform.get(TraceKeys.CLASS_NAME, "")
+        # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
+        if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
+            return
+        raise RuntimeError(
+            f"Error {self.__class__.__name__} getting the most recently "
+            f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
+        )
+
+    def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
+        """
+        Get most recent transform for the stack.
+
+        Args:
+            data: dictionary of data or `MetaTensor`.
+            key: if data is a dictionary, data[key] will be modified.
+            check: if true, check that `self` is the same type as the most recently-applied transform.
+            pop: if true, remove the transform as it is returned.
+
+        Returns:
+            Dictionary of most recently applied transform
+
+        Raises:
+            - RuntimeError: tracing is disabled, or `check` is true and the most recent
+              transform does not match `self`.
+            - ValueError: data is neither `MetaTensor` nor dictionary.
+        """
+        if not self.tracing:
+            raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
+        if isinstance(data, MetaTensor):
+            all_transforms = data.applied_operations
+        elif isinstance(data, Mapping):
+            if key in data and isinstance(data[key], MetaTensor):
+                all_transforms = data[key].applied_operations
+            else:
+                all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
+        else:
+            raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
+        if check:
+            self.check_transforms_match(all_transforms[-1])
+        return all_transforms.pop() if pop else all_transforms[-1]
+
+    def pop_transform(self, data, key: Hashable = None, check: bool = True):
+        """
+        Return and pop the most recent transform.
+
+        Args:
+            data: dictionary of data or `MetaTensor`
+            key: if data is a dictionary, data[key] will be modified
+            check: if true, check that `self` is the same type as the most recently-applied transform.
+
+        Returns:
+            Dictionary of most recently applied transform
+
+        Raises:
+            - RuntimeError: tracing is disabled, or `check` is true and the most recent
+              transform does not match `self`.
+            - ValueError: data is neither `MetaTensor` nor dictionary.
+        """
+        return self.get_most_recent_transform(data, key, check, pop=True)
+
+    @contextmanager
+    def trace_transform(self, to_trace: bool):
+        """Temporarily set the tracing status of a transform with a context manager."""
+        prev = self.tracing
+        self.tracing = to_trace
+        try:
+            yield
+        finally:
+            # restore the previous status even if the body of the `with` block raises
+            self.tracing = prev
 
 
 class InvertibleTransform(TraceableTransform):
@@ -103,7 +234,7 @@ class InvertibleTransform(TraceableTransform):
           different parameters being passed to each label (e.g., different
           interpolation for image and label).
 
-        - the inverse transforms are applied in a last- in-first-out order. As
+        - the inverse transforms are applied in a last-in-first-out order. As
           the inverse is applied, its entry is removed from the list detailing
           the applied transformations. That is to say that during the forward
           pass, the list of applied transforms grows, and then during the
@@ -126,29 +257,7 @@ class InvertibleTransform(TraceableTransform):
 
     """
 
-    def check_transforms_match(self, transform: Mapping) -> None:
-        """Check transforms are of same instance."""
-        xform_name = transform.get(TraceKeys.CLASS_NAME, "")
-        xform_id = transform.get(TraceKeys.ID, "")
-        if xform_id == id(self):
-            return
-        # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
-        if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
-            return
-        raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.")
-
-    def get_most_recent_transform(self, data: Mapping, key: Hashable = None):
-        """Get most recent transform."""
-        if not self.tracing:
-            raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.")
-        if isinstance(data[key], MetaTensor):
-            transform = data[key].applied_operations[-1]
-        else:
-            transform = data[self.trace_key(key)][-1]
-        self.check_transforms_match(transform)
-        return transform
-
-    def inverse(self, data: dict) -> dict:
+    def inverse(self, data: Any) -> Any:
         """
         Inverse of ``__call__``.
 
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 65bb13e6b8..5819d2971d 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -14,7 +14,7 @@
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import numpy as np
 import torch
@@ -348,7 +348,7 @@ def __call__(self, data):
         """
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
 
-    def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator:
+    def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator:
         """
         Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
         `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
diff --git a/runtests.sh b/runtests.sh
index a632e2664f..9e6ef3d0e1 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -403,6 +403,30 @@ then
 fi
 
 
+if [ $doPrecommit = true ]
+then
+    set +e  # disable exit on failure so that diagnostics can be given on failure
+    echo "${separator}${blue}pre-commit${noColor}"
+
+    # ensure that the necessary packages for code format testing are installed
+    if ! is_pip_installed pre_commit
+    then
+        install_deps
+    fi
+    ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files
+
+    pre_commit_status=$?
+    if [ ${pre_commit_status} -ne 0 ]
+    then
+        print_style_fail_msg
+        exit ${pre_commit_status}
+    else
+        echo "${green}passed!${noColor}"
+    fi
+    set -e # enable exit on failure
+fi
+
+
 if [ $doIsortFormat = true ]
 then
     set +e  # disable exit on failure so that diagnostics can be given on failure
@@ -500,29 +524,6 @@ then
     set -e # enable exit on failure
 fi
 
-if [ $doPrecommit = true ]
-then
-    set +e  # disable exit on failure so that diagnostics can be given on failure
-    echo "${separator}${blue}pre-commit${noColor}"
-
-    # ensure that the necessary packages for code format testing are installed
-    if ! is_pip_installed pre_commit
-    then
-        install_deps
-    fi
-    ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files
-
-    pre_commit_status=$?
-    if [ ${pre_commit_status} -ne 0 ]
-    then
-        print_style_fail_msg
-        exit ${pre_commit_status}
-    else
-        echo "${green}passed!${noColor}"
-    fi
-    set -e # enable exit on failure
-fi
-
 if [ $doPylintFormat = true ]
 then
     set +e  # disable exit on failure so that diagnostics can be given on failure