diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 783a2b8570..a8d5a509df 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -10,6 +10,7 @@ # limitations under the License. import warnings +from copy import deepcopy from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union from torch.utils.data import DataLoader as TorchDataLoader @@ -31,7 +32,7 @@ class TransformInverter: """ Ignite handler to automatically invert `transforms`. It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. - The outputs are stored in `engine.state.output` with the `output_keys`. + The outputs are stored in `engine.state.output` with key: "{output_key}_{postfix}". """ def __init__( @@ -42,7 +43,7 @@ def __init__( batch_keys: Union[str, Sequence[str]] = CommonKeys.IMAGE, meta_key_postfix: str = "meta_dict", collate_fn: Optional[Callable] = no_collation, - postfix: str = "_inverted", + postfix: str = "inverted", nearest_interp: Union[bool, Sequence[bool]] = True, num_workers: Optional[int] = 0, ) -> None: @@ -61,7 +62,7 @@ def __init__( metadata `image_meta_dict` dictionary's `affine` field. collate_fn: how to collate data after inverse transformations. default won't do any collation, so the output will be a list of size batch size. - postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}{postfix}`. + postfix: will save the inverted result into `ignite.engine.output` with key `{output_key}_{postfix}`. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. it also can be a list of bool, each matches to the `output_keys` data. @@ -104,7 +105,11 @@ def __call__(self, engine: Engine) -> None: transform_info = engine.state.batch[transform_key] if nearest_interp: - convert_inverse_interp_mode(trans_info=transform_info, mode="nearest", align_corners=None) + transform_info = convert_inverse_interp_mode( + trans_info=deepcopy(transform_info), + mode="nearest", + align_corners=None, + ) segs_dict = { batch_key: engine.state.output[output_key].detach().cpu(), @@ -115,5 +120,5 @@ def __call__(self, engine: Engine) -> None: segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key] with allow_missing_keys_mode(self.transform): # type: ignore - inverted_key = f"{output_key}{self.postfix}" + inverted_key = f"{output_key}_{self.postfix}" engine.state.output[inverted_key] = [self._totensor(i[batch_key]) for i in self.inverter(segs_dict)] diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index c08b786e98..3e73beb305 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -800,3 +800,4 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] else: item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ + return trans_info