From 28c9b4bed57f87b02f4e43ac011d61491bef5435 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Jun 2022 13:35:13 +0800 Subject: [PATCH 1/3] [DLMED] update inverse for MetaTensor Signed-off-by: Nic Ma --- monai/transforms/inverse.py | 249 ++++++++++++++++++++++++++-------- monai/transforms/transform.py | 4 +- 2 files changed, 198 insertions(+), 55 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index bdaa2f9b40..0ec45ccfb9 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -8,8 +8,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os -from typing import Hashable, Mapping, Optional, Tuple +import warnings +from contextlib import contextmanager +from typing import Any, Hashable, Mapping, Optional, Tuple import torch @@ -26,13 +29,10 @@ class TraceableTransform(Transform): `trace_key: list of transforms` to each data dictionary. The ``__call__`` method of this transform class must be implemented so - that the transformation information for each key is stored when - ``__call__`` is called. If the transforms were applied to keys "image" and - "label", there will be two extra keys in the dictionary: "image_transforms" - and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list - contains a list of the transforms applied to that key. + that the transformation information for each key is stored in ``data.applied_operations`` + when ``__call__`` is called. - The information in ``data[key_transform]`` will be compatible with the + The information in ``data.applied_operations`` will be compatible with the default collate since it only stores strings, numbers and arrays. `tracing` could be enabled by `self.set_tracing` or setting @@ -52,41 +52,206 @@ def trace_key(key: Hashable = None): return TraceKeys.KEY_SUFFIX return str(key) + TraceKeys.KEY_SUFFIX - def push_transform( - self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None - ) -> None: - """Push to a stack of applied transforms for that key.""" + def get_transform_info( + self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> dict: + """ + Return a dictionary with the relevant information pertaining to an applied + transform. - if not self.tracing: - return + Args: + - data: input data. Can be dictionary or MetaTensor. We can use `shape` to + determine the original size of the object (unless that has been given + explicitly, see `orig_size`). + - key: if data is a dictionary, data[key] will be modified + - extra_info: if desired, any extra information pertaining to the applied + transform can be stored in this dictionary. These are often needed for + computing the inverse transformation. + - orig_size: sometimes during the inverse it is useful to know what the size + of the original image was, in which case it can be supplied here. + + Returns: + Dictionary of data pertaining to the applied transformation. + """ info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} if orig_size is not None: info[TraceKeys.ORIG_SIZE] = orig_size - elif key in data and hasattr(data[key], "shape"): + elif isinstance(data, Mapping) and key in data and hasattr(data[key], "shape"): info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] + elif hasattr(data, "shape"): + info[TraceKeys.ORIG_SIZE] = data.shape[1:] if extra_info is not None: info[TraceKeys.EXTRA_INFO] = extra_info # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) if hasattr(self, "_do_transform"): # RandomizableTransform info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore + return info - if key in data and isinstance(data[key], MetaTensor): - data[key].push_applied_operation(info) - else: - # If this is the first, create list - if self.trace_key(key) not in data: - if not isinstance(data, dict): - data = dict(data) - data[self.trace_key(key)] = [] - data[self.trace_key(key)].append(info) - - def pop_transform(self, data: Mapping, key: Hashable = None): - """Remove the most recent applied transform.""" + def push_transform( + self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> None: + """ + Push to a stack of applied transforms. + + Data can be one of two types: + 1. A `MetaTensor` + 2. A dictionary of data containing arrays/tensors and auxiliary data. In + this case, a key must be supplied (the dictionary-based approach is deprecated). + + If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms`. (This is deprecated.) + + Hopefully it is clear that there are three total possibilities: + 1. data is `MetaTensor` + 2. data is dictionary, data[key] is `MetaTensor` + 3. data is dictionary, data[key] is not `MetaTensor`. + + Args: + - data: dictionary of data or `MetaTensor` + - key: if data is a dictionary, data[key] will be modified + - extra_info: if desired, any extra information pertaining to the applied + transform can be stored in this dictionary. These are often needed for + computing the inverse transformation. + - orig_size: sometimes during the inverse it is useful to know what the size + of the original image was, in which case it can be supplied here. + + Returns: + None, but data has been updated to store the applied transformation. + """ if not self.tracing: return - if key in data and isinstance(data[key], MetaTensor): - return data[key].pop_applied_operation() - return data.get(self.trace_key(key), []).pop() + info = self.get_transform_info(data, key, extra_info, orig_size) + + if isinstance(data, MetaTensor): + data.push_applied_operation(info) + elif isinstance(data, Mapping): + if key in data and isinstance(data[key], MetaTensor): + data[key].push_applied_operation(info) + else: + # If this is the first, create list + if self.trace_key(key) not in data: + if not isinstance(data, dict): + data = dict(data) + data[self.trace_key(key)] = [] + data[self.trace_key(key)].append(info) + else: + warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.") + + def check_transforms_match(self, transform: Mapping) -> None: + """Check transforms are of same instance.""" + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): + return + # TraceKeys.NONE to skip the id check + if xform_id == TraceKeys.NONE: + return + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) + if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: + return + raise RuntimeError( + f"Error {self.__class__.__name__} getting the most recently " + f"applied invertible transform {xform_name} {xform_id} != {id(self)}." + ) + + def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): + """ + Get most recent transform. + + Data can be one of two things: + 1. A `MetaTensor` + 2. A dictionary of data containing arrays/tensors and auxiliary data. In + this case, a key must be supplied (the dictionary-based approach is deprecated). + + If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms`. (This is deprecated.) + + Hopefully it is clear that there are three total possibilities: + 1. data is `MetaTensor` + 2. data is dictionary, data[key] is `MetaTensor` + 3. data is dictionary, data[key] is not `MetaTensor`. + + Args: + - data: dictionary of data or `MetaTensor` + - key: if data is a dictionary, data[key] will be modified + - check: if true, check that `self` is the same type as the most recently-applied transform. + - pop: if true, remove the transform as it is returned. + + Returns: + Dictionary of most recently applied transform + + Raises: + - RuntimeError: data is neither `MetaTensor` nor dictionary + """ + if not self.tracing: + raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") + if isinstance(data, MetaTensor): + all_transforms = data.applied_operations + elif isinstance(data, Mapping): + if key in data and isinstance(data[key], MetaTensor): + all_transforms = data[key].applied_operations + else: + all_transforms = data[self.trace_key(key)] + else: + raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.") + if check: + self.check_transforms_match(all_transforms[-1]) + return all_transforms.pop() if pop else all_transforms[-1] + + def pop_transform(self, data, key: Hashable = None, check: bool = True): + """ + Return and pop the most recent transform. + + Data can be one of two things: + 1. A `MetaTensor` + 2. A dictionary of data containing arrays/tensors and auxilliary data. In + this case, a key must be supplied. + + If `data` is of type `MetaTensor`, then the applied transform will be added to + its internal list. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to + its internal list. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms`. + + Hopefully it is clear that there are three total possibilities: + 1. data is `MetaTensor` + 2. data is dictionary, data[key] is `MetaTensor` + 3. data is dictionary, data[key] is not `MetaTensor`. + + Args: + - data: dictionary of data or `MetaTensor` + - key: if data is a dictionary, data[key] will be modified + - check: if true, check that `self` is the same type as the most recently-applied transform. + + Returns: + Dictionary of most recently applied transform + + Raises: + - RuntimeError: data is neither `MetaTensor` nor dictionary + """ + return self.get_most_recent_transform(data, key, check, pop=True) + + @contextmanager + def trace_transform(self, to_trace: bool): + """Temporarily set the tracing status of a transform with a context manager.""" + prev = self.tracing + self.tracing = to_trace + yield + self.tracing = prev class InvertibleTransform(TraceableTransform): @@ -103,7 +268,7 @@ class InvertibleTransform(TraceableTransform): different parameters being passed to each label (e.g., different interpolation for image and label). - - the inverse transforms are applied in a last- in-first-out order. As + - the inverse transforms are applied in a last-in-first-out order. As the inverse is applied, its entry is removed from the list detailing the applied transformations. That is to say that during the forward pass, the list of applied transforms grows, and then during the @@ -126,29 +291,7 @@ class InvertibleTransform(TraceableTransform): """ - def check_transforms_match(self, transform: Mapping) -> None: - """Check transforms are of same instance.""" - xform_name = transform.get(TraceKeys.CLASS_NAME, "") - xform_id = transform.get(TraceKeys.ID, "") - if xform_id == id(self): - return - # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: - return - raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") - - def get_most_recent_transform(self, data: Mapping, key: Hashable = None): - """Get most recent transform.""" - if not self.tracing: - raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") - if isinstance(data[key], MetaTensor): - transform = data[key].applied_operations[-1] - else: - transform = data[self.trace_key(key)][-1] - self.check_transforms_match(transform) - return transform - - def inverse(self, data: dict) -> dict: + def inverse(self, data: Any) -> Any: """ Inverse of ``__call__``. diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 65bb13e6b8..5819d2971d 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union import numpy as np import torch @@ -348,7 +348,7 @@ def __call__(self, data): """ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator: + def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[Iterable]) -> Generator: """ Iterate across keys and optionally extra iterables. If key is missing, exception is raised if `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped. From b66a60f751160af90ec38e207d84e0732818bcb6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 22 Jun 2022 06:53:18 +0100 Subject: [PATCH 2/3] fixes unit test Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 0ec45ccfb9..b56bc0f5ea 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -49,8 +49,8 @@ def set_tracing(self, tracing: bool) -> None: def trace_key(key: Hashable = None): """The key to store the stack of applied transforms.""" if key is None: - return TraceKeys.KEY_SUFFIX - return str(key) + TraceKeys.KEY_SUFFIX + return f"{TraceKeys.KEY_SUFFIX}" + return f"{key}{TraceKeys.KEY_SUFFIX}" def get_transform_info( self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None @@ -201,7 +201,7 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr if key in data and isinstance(data[key], MetaTensor): all_transforms = data[key].applied_operations else: - all_transforms = data[self.trace_key(key)] + all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations()) else: raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.") if check: From d8177f005308999e1898ac0848e7725f3f84f770 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 22 Jun 2022 07:12:37 +0100 Subject: [PATCH 3/3] update docs Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 119 ++++++++++++------------------------ runtests.sh | 47 +++++++------- 2 files changed, 63 insertions(+), 103 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index b56bc0f5ea..41a02989db 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -25,15 +25,31 @@ class TraceableTransform(Transform): """ - Maintains a stack of applied transforms. The stack is inserted as pairs of - `trace_key: list of transforms` to each data dictionary. + Maintains a stack of applied transforms to data. + + Data can be one of two types: + 1. A `MetaTensor` (this is the preferred data type). + 2. A dictionary of data containing arrays/tensors and auxiliary metadata. In + this case, a key must be supplied (this dictionary-based approach is deprecated). + + If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``. + + If `data` is a dictionary, then one of two things can happen: + 1. If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``. + 2. Else, the applied transform will be appended to an adjacent list using + `trace_key`. If, for example, the key is `image`, then the transform + will be appended to `image_transforms` (this dictionary-based approach is deprecated). + + Hopefully it is clear that there are three total possibilities: + 1. data is `MetaTensor` + 2. data is dictionary, data[key] is `MetaTensor` + 3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach). The ``__call__`` method of this transform class must be implemented so - that the transformation information for each key is stored in ``data.applied_operations`` - when ``__call__`` is called. + that the transformation information is stored during the data transformation. - The information in ``data.applied_operations`` will be compatible with the - default collate since it only stores strings, numbers and arrays. + The information in the stack of applied transforms must be compatible with the + default collate, by only storing strings, numbers and arrays. `tracing` could be enabled by `self.set_tracing` or setting `MONAI_TRACE_TRANSFORM` when initializing the class. @@ -56,18 +72,17 @@ def get_transform_info( self, data, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None ) -> dict: """ - Return a dictionary with the relevant information pertaining to an applied - transform. + Return a dictionary with the relevant information pertaining to an applied transform. Args: - - data: input data. Can be dictionary or MetaTensor. We can use `shape` to + data: input data. Can be dictionary or MetaTensor. We can use `shape` to determine the original size of the object (unless that has been given explicitly, see `orig_size`). - - key: if data is a dictionary, data[key] will be modified - - extra_info: if desired, any extra information pertaining to the applied + key: if data is a dictionary, data[key] will be modified. + extra_info: if desired, any extra information pertaining to the applied transform can be stored in this dictionary. These are often needed for computing the inverse transformation. - - orig_size: sometimes during the inverse it is useful to know what the size + orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. Returns: @@ -93,31 +108,13 @@ def push_transform( """ Push to a stack of applied transforms. - Data can be one of two types: - 1. A `MetaTensor` - 2. A dictionary of data containing arrays/tensors and auxiliary data. In - this case, a key must be supplied (the dictionary-based approach is deprecated). - - If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. - - If `data` is a dictionary, then one of two things can happen: - 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. - 2. Else, the applied transform will be appended to an adjacent list using - `trace_key`. If, for example, the key is `image`, then the transform - will be appended to `image_transforms`. (This is deprecated.) - - Hopefully it is clear that there are three total possibilities: - 1. data is `MetaTensor` - 2. data is dictionary, data[key] is `MetaTensor` - 3. data is dictionary, data[key] is not `MetaTensor`. - Args: - - data: dictionary of data or `MetaTensor` - - key: if data is a dictionary, data[key] will be modified - - extra_info: if desired, any extra information pertaining to the applied + data: dictionary of data or `MetaTensor`. + key: if data is a dictionary, data[key] will be modified. + extra_info: if desired, any extra information pertaining to the applied transform can be stored in this dictionary. These are often needed for computing the inverse transformation. - - orig_size: sometimes during the inverse it is useful to know what the size + orig_size: sometimes during the inverse it is useful to know what the size of the original image was, in which case it can be supplied here. Returns: @@ -161,31 +158,13 @@ def check_transforms_match(self, transform: Mapping) -> None: def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): """ - Get most recent transform. - - Data can be one of two things: - 1. A `MetaTensor` - 2. A dictionary of data containing arrays/tensors and auxiliary data. In - this case, a key must be supplied (the dictionary-based approach is deprecated). - - If `data` is of type `MetaTensor`, then the applied transform will be added to its internal list. - - If `data` is a dictionary, then one of two things can happen: - 1. If data[key] is a `MetaTensor`, the applied transform will be added to its internal list. - 2. Else, the applied transform will be appended to an adjacent list using - `trace_key`. If, for example, the key is `image`, then the transform - will be appended to `image_transforms`. (This is deprecated.) - - Hopefully it is clear that there are three total possibilities: - 1. data is `MetaTensor` - 2. data is dictionary, data[key] is `MetaTensor` - 3. data is dictionary, data[key] is not `MetaTensor`. + Get most recent transform for the stack. Args: - - data: dictionary of data or `MetaTensor` - - key: if data is a dictionary, data[key] will be modified - - check: if true, check that `self` is the same type as the most recently-applied transform. - - pop: if true, remove the transform as it is returned. + data: dictionary of data or `MetaTensor`. + key: if data is a dictionary, data[key] will be modified. + check: if true, check that `self` is the same type as the most recently-applied transform. + pop: if true, remove the transform as it is returned. Returns: Dictionary of most recently applied transform @@ -212,30 +191,10 @@ def pop_transform(self, data, key: Hashable = None, check: bool = True): """ Return and pop the most recent transform. - Data can be one of two things: - 1. A `MetaTensor` - 2. A dictionary of data containing arrays/tensors and auxilliary data. In - this case, a key must be supplied. - - If `data` is of type `MetaTensor`, then the applied transform will be added to - its internal list. - - If `data` is a dictionary, then one of two things can happen: - 1. If data[key] is a `MetaTensor`, the applied transform will be added to - its internal list. - 2. Else, the applied transform will be appended to an adjacent list using - `trace_key`. If, for example, the key is `image`, then the transform - will be appended to `image_transforms`. - - Hopefully it is clear that there are three total possibilities: - 1. data is `MetaTensor` - 2. data is dictionary, data[key] is `MetaTensor` - 3. data is dictionary, data[key] is not `MetaTensor`. - Args: - - data: dictionary of data or `MetaTensor` - - key: if data is a dictionary, data[key] will be modified - - check: if true, check that `self` is the same type as the most recently-applied transform. + data: dictionary of data or `MetaTensor` + key: if data is a dictionary, data[key] will be modified + check: if true, check that `self` is the same type as the most recently-applied transform. Returns: Dictionary of most recently applied transform diff --git a/runtests.sh b/runtests.sh index a632e2664f..9e6ef3d0e1 100755 --- a/runtests.sh +++ b/runtests.sh @@ -403,6 +403,30 @@ then fi +if [ $doPrecommit = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pre-commit${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pre_commit + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files + + pre_commit_status=$? + if [ ${pre_commit_status} -ne 0 ] + then + print_style_fail_msg + exit ${pre_commit_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + + if [ $doIsortFormat = true ] then set +e # disable exit on failure so that diagnostics can be given on failure @@ -500,29 +524,6 @@ then set -e # enable exit on failure fi -if [ $doPrecommit = true ] -then - set +e # disable exit on failure so that diagnostics can be given on failure - echo "${separator}${blue}pre-commit${noColor}" - - # ensure that the necessary packages for code format testing are installed - if ! is_pip_installed pre_commit - then - install_deps - fi - ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files - - pre_commit_status=$? - if [ ${pre_commit_status} -ne 0 ] - then - print_style_fail_msg - exit ${pre_commit_status} - else - echo "${green}passed!${noColor}" - fi - set -e # enable exit on failure -fi - if [ $doPylintFormat = true ] then set +e # disable exit on failure so that diagnostics can be given on failure