diff --git a/docs/source/data.rst b/docs/source/data.rst index 66fadd549b..7d0ffbd7b1 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -185,10 +185,6 @@ ThreadBuffer .. autoclass:: monai.data.ThreadBuffer -BatchInverseTransform -~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: monai.data.BatchInverseTransform - TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.TestTimeAugmentation diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index ccdb5898a0..64697b9c1f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -43,6 +43,11 @@ Generic Interfaces .. autoclass:: InvertibleTransform :members: +`BatchInverseTransform` +^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: BatchInverseTransform + :members: + Vanilla Transforms ------------------ @@ -836,6 +841,12 @@ Post-processing (Dict) :members: :special-members: __call__ +`Invertd` +""""""""" +.. autoclass:: Invertd + :members: + :special-members: __call__ + Spatial (Dict) ^^^^^^^^^^^^^^ diff --git a/monai/data/__init__.py b/monai/data/__init__.py index adb27a608e..e2eec0ef12 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -26,7 +26,6 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .inverse_batch_transform import BatchInverseTransform from .iterable_dataset import IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 9e4497d5e4..49b6c774e2 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -16,10 +16,10 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import list_data_collate, pad_list_data_collate from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode from monai.utils.enums import CommonKeys, InverseKeys diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index d1e79c389b..a28a12dc0f 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -9,18 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from copy import deepcopy from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union import torch from torch.utils.data import DataLoader as TorchDataLoader -from monai.data import BatchInverseTransform from monai.data.utils import no_collation from monai.engines.utils import CommonKeys, IterationEvents -from monai.transforms import InvertibleTransform, ToTensor, allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils import InverseKeys, ensure_tuple, ensure_tuple_rep, exact_version, optional_import +from monai.transforms import Invertd, InvertibleTransform +from monai.utils import ensure_tuple, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") if TYPE_CHECKING: @@ -33,6 +30,7 @@ class TransformInverter: """ Ignite handler to automatically invert `transforms`. It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. + Expect both `engine.state.output` and `engine.state.batch` to be dictionary data. The inverted data are stored in `engine.state.output` with key: "{output_key}_{postfix}". And the inverted meta dict will be stored in `engine.state.batch` with key: "{output_key}_{postfix}_{meta_key_postfix}". @@ -85,22 +83,23 @@ def __init__( Set to `None`, to use the `num_workers` of the input transform data loader. """ - self.transform = transform - self.inverter = BatchInverseTransform( + self.inverter = Invertd( + keys=output_keys, transform=transform, loader=loader, + orig_keys=batch_keys, + meta_key_postfix=meta_key_postfix, collate_fn=collate_fn, + postfix=postfix, + nearest_interp=nearest_interp, + to_tensor=to_tensor, + device=device, + post_func=post_func, num_workers=num_workers, ) self.output_keys = ensure_tuple(output_keys) - self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys)) self.meta_key_postfix = meta_key_postfix self.postfix = postfix - self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.output_keys)) - self.to_tensor = ensure_tuple_rep(to_tensor, len(self.output_keys)) - self.device = ensure_tuple_rep(device, len(self.output_keys)) - self.post_func = ensure_tuple_rep(post_func, len(self.output_keys)) - self._totensor = ToTensor() def attach(self, engine: Engine) -> None: """ @@ -114,42 +113,17 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - for output_key, batch_key, nearest_interp, to_tensor, device, post_func in zip( - self.output_keys, self.batch_keys, self.nearest_interp, self.to_tensor, self.device, self.post_func - ): - transform_key = batch_key + InverseKeys.KEY_SUFFIX - if transform_key not in engine.state.batch: - warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.") - continue - - transform_info = engine.state.batch[transform_key] - if nearest_interp: - transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), - mode="nearest", - align_corners=None, - ) - - output = engine.state.output[output_key] - if isinstance(output, torch.Tensor): - output = output.detach() - segs_dict = { - batch_key: output, - transform_key: transform_info, - } - meta_dict_key = f"{batch_key}_{self.meta_key_postfix}" - if meta_dict_key in engine.state.batch: - segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key] - - with allow_missing_keys_mode(self.transform): # type: ignore - inverted = self.inverter(segs_dict) + # combine `batch` and `output` to temporarily act as 1 dict for post transform + data = dict(engine.state.batch) + data.update(engine.state.output) + ret = self.inverter(data) + for output_key in self.output_keys: # save the inverted data into state.output inverted_key = f"{output_key}_{self.postfix}" - engine.state.output[inverted_key] = [ - post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) for i in inverted - ] - + if inverted_key in ret: + engine.state.output[inverted_key] = ret[inverted_key] # save the inverted meta dict into state.batch - if meta_dict_key in engine.state.batch: - engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [i.get(meta_dict_key) for i in inverted] + meta_dict_key = f"{inverted_key}_{self.meta_key_postfix}" + if meta_dict_key in ret: + engine.state.batch[meta_dict_key] = ret[meta_dict_key] diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index f7247a8886..0d534c40ca 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -156,6 +156,7 @@ ThresholdIntensityDict, ) from .inverse import InvertibleTransform +from .inverse_batch_transform import BatchInverseTransform from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( @@ -178,6 +179,9 @@ DecollateD, DecollateDict, Ensembled, + Invertd, + InvertD, + InvertDict, KeepLargestConnectedComponentd, KeepLargestConnectedComponentD, KeepLargestConnectedComponentDict, diff --git a/monai/data/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py similarity index 94% rename from monai/data/inverse_batch_transform.py rename to monai/transforms/inverse_batch_transform.py index cee1b3bcb2..724169c0d8 100644 --- a/monai/data/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -10,13 +10,12 @@ # limitations under the License. import warnings -from typing import Any, Callable, Dict, Hashable, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence -import numpy as np +from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader as TorchDataLoader from monai.data.dataloader import DataLoader -from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform @@ -33,11 +32,11 @@ def __init__( transform: InvertibleTransform, pad_collation_used: bool, ) -> None: - super().__init__(data, transform) + self.data = data self.invertible_transform = transform self.pad_collation_used = pad_collation_used - def _transform(self, index: int) -> Dict[Hashable, np.ndarray]: + def __getitem__(self, index: int): data = dict(self.data[index]) # If pad collation was used, then we need to undo this first if self.pad_collation_used: @@ -48,6 +47,9 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]: return data return self.invertible_transform.inverse(data) + def __len__(self) -> int: + return len(self.data) + class BatchInverseTransform(Transform): """ diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 13b709e284..2621b3f4a6 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -15,13 +15,19 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +import warnings +from copy import deepcopy +from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union import numpy as np import torch +from torch.utils.data import DataLoader as TorchDataLoader import monai.data from monai.config import KeysCollection +from monai.data.utils import no_collation +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.post.array import ( Activations, AsDiscrete, @@ -32,7 +38,10 @@ VoteEnsemble, ) from monai.transforms.transform import MapTransform +from monai.transforms.utility.array import ToTensor +from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import ensure_tuple_rep +from monai.utils.enums import InverseKeys __all__ = [ "Activationsd", @@ -46,6 +55,9 @@ "ActivationsDict", "AsDiscreteD", "AsDiscreteDict", + "InvertD", + "InvertDict", + "Invertd", "KeepLargestConnectedComponentD", "KeepLargestConnectedComponentDict", "LabelToContourD", @@ -399,6 +411,131 @@ def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): return d +class Invertd(MapTransform): + """ + Utility transform to automatically invert the previous transforms based on transform information. + It extracts the transform information applied on the data with `orig_keys`, then inverts the transforms + on the data with `keys`. several `keys` can share one `orig_keys`. + A typical usage is to invert the pre-transforms (appplied on input `image`) on the model `pred` data. + + Note: + As this transform only accepts 1 input dict while ignite stores model input data in `state.batch` + and stores model output data in `state.output`, so it's not compatible with MONAI engines so far. + For MONAI workflow engines, please use the `TransformInverter` handler instead. + Users can use this transform in a regular PyTorch program which uses dict data for transforms. + + """ + + def __init__( + self, + keys: KeysCollection, + transform: InvertibleTransform, + loader: TorchDataLoader, + orig_keys: Union[str, Sequence[str]], + meta_key_postfix: str = "meta_dict", + collate_fn: Optional[Callable] = no_collation, + postfix: str = "inverted", + nearest_interp: Union[bool, Sequence[bool]] = True, + to_tensor: Union[bool, Sequence[bool]] = True, + device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", + post_func: Union[Callable, Sequence[Callable]] = lambda x: x, + num_workers: Optional[int] = 0, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: the key of expected data in the dict, invert transforms on it. + it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"]. + transform: the previous callable transform that applied on input data. + loader: data loader used to run transforms and generate the batch of data. + orig_keys: the key of the original input data in the dict. will get the applied transform information + for this input data, then invert them for the expected data with `keys`. + It can also be a list of keys, each matches to the `keys` data. + meta_key_postfix: use `{orig_key}_{postfix}` to to fetch the meta data from dict, + default is `meta_dict`, the meta data is a dictionary object. + For example, to handle orig_key `image`, read/write `affine` matrices from the + metadata `image_meta_dict` dictionary's `affine` field. + collate_fn: how to collate data after inverse transformations. + default won't do any collation, so the output will be a list of size batch size. + postfix: will save the inverted result into dict with key `{key}_{postfix}`. + nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, + default to `True`. If `False`, use the same interpolation mode as the original transform. + it also can be a list of bool, each matches to the `keys` data. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. + it also can be a list of bool, each matches to the `keys` data. + device: if converted to Tensor, move the inverted results to target device before `post_func`, + default to "cpu", it also can be a list of string or `torch.device`, + each matches to the `keys` data. + post_func: post processing for the inverted data, should be a callable function. + it also can be a list of callable, each matches to the `keys` data. + num_workers: number of workers when run data loader for inverse transforms, + default to 0 as only run one iteration and multi-processing may be even slower. + Set to `None`, to use the `num_workers` of the input transform data loader. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.transform = transform + self.inverter = BatchInverseTransform( + transform=transform, + loader=loader, + collate_fn=collate_fn, + num_workers=num_workers, + ) + self.orig_keys = ensure_tuple_rep(orig_keys, len(self.keys)) + self.meta_key_postfix = meta_key_postfix + self.postfix = postfix + self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.keys)) + self.to_tensor = ensure_tuple_rep(to_tensor, len(self.keys)) + self.device = ensure_tuple_rep(device, len(self.keys)) + self.post_func = ensure_tuple_rep(post_func, len(self.keys)) + self._totensor = ToTensor() + + def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: + d = dict(data) + for key, orig_key, nearest_interp, to_tensor, device, post_func in self.key_iterator( + d, self.orig_keys, self.nearest_interp, self.to_tensor, self.device, self.post_func + ): + transform_key = orig_key + InverseKeys.KEY_SUFFIX + if transform_key not in d: + warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") + continue + + transform_info = d[transform_key] + if nearest_interp: + transform_info = convert_inverse_interp_mode( + trans_info=deepcopy(transform_info), + mode="nearest", + align_corners=None, + ) + + input = d[key] + if isinstance(input, torch.Tensor): + input = input.detach() + # construct the input dict data for BatchInverseTransform + input_dict = { + orig_key: input, + transform_key: transform_info, + } + meta_dict_key = f"{orig_key}_{self.meta_key_postfix}" + if meta_dict_key in d: + input_dict[meta_dict_key] = d[meta_dict_key] + + with allow_missing_keys_mode(self.transform): # type: ignore + inverted = self.inverter(input_dict) + + # save the inverted data + inverted_key = f"{key}_{self.postfix}" + d[inverted_key] = [ + post_func(self._totensor(i[orig_key]).to(device) if to_tensor else i[orig_key]) for i in inverted + ] + + # save the inverted meta dict + if meta_dict_key in d: + d[f"{inverted_key}_{self.meta_key_postfix}"] = [i.get(meta_dict_key) for i in inverted] + return d + + ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd @@ -407,3 +544,4 @@ def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]): ProbNMSD = ProbNMSDict = ProbNMSd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled DecollateD = DecollateDict = Decollated +InvertD = InvertDict = Invertd diff --git a/tests/min_tests.py b/tests/min_tests.py index 7a2b4a8bc1..80975da900 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -123,6 +123,7 @@ def run_testsuit(): "test_handler_transform_inverter", "test_testtimeaugmentation", "test_cachedataset_persistent_workers", + "test_invertd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b4762fc5a9..14ea0ef65a 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -21,12 +21,12 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d -from monai.data.inverse_batch_transform import BatchInverseTransform from monai.data.utils import decollate_batch from monai.networks.nets import UNet from monai.transforms import ( AddChanneld, Affined, + BatchInverseTransform, BorderPadd, CenterSpatialCropd, Compose, diff --git a/tests/test_invertd.py b/tests/test_invertd.py new file mode 100644 index 0000000000..f46ff4170c --- /dev/null +++ b/tests/test_invertd.py @@ -0,0 +1,102 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np +import torch + +from monai.data import CacheDataset, DataLoader, create_test_image_3d +from monai.transforms import ( + AddChanneld, + CastToTyped, + Compose, + Invertd, + LoadImaged, + Orientationd, + RandAffined, + RandAxisFlipd, + RandFlipd, + RandRotate90d, + RandRotated, + RandZoomd, + ResizeWithPadOrCropd, + ScaleIntensityd, + Spacingd, + ToTensord, +) +from monai.utils.misc import set_determinism +from tests.utils import make_nifti_image + +KEYS = ["image", "label"] + + +class TestInvertd(unittest.TestCase): + def test_invert(self): + set_determinism(seed=0) + im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)] + transform = Compose( + [ + LoadImaged(KEYS), + AddChanneld(KEYS), + Orientationd(KEYS, "RPS"), + Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), + ScaleIntensityd("image", minv=1, maxv=10), + RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]), + RandAxisFlipd(KEYS, prob=0.5), + RandRotate90d(KEYS, spatial_axes=(1, 2)), + RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), + RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True), + RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), + ResizeWithPadOrCropd(KEYS, 100), + ToTensord("image"), # test to support both Tensor and Numpy array when inverting + CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), + ] + ) + data = [{"image": im_fname, "label": seg_fname} for _ in range(12)] + + # num workers = 0 for mac or gpu transforms + num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2 + + dataset = CacheDataset(data, transform=transform, progress=False) + loader = DataLoader(dataset, num_workers=num_workers, batch_size=5) + inverter = Invertd( + keys=["image", "label"], + transform=transform, + loader=loader, + orig_keys="label", + nearest_interp=True, + postfix="inverted", + to_tensor=[True, False], + device="cpu", + num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2, + ) + + # execute 1 epoch + for d in loader: + d = inverter(d) + # this unit test only covers basic function, test_handler_transform_inverter covers more + self.assertTupleEqual(d["image"].shape[1:], (1, 100, 100, 100)) + self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100)) + # check the nearest inerpolation mode + for i in d["image_inverted"]: + torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) + self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + for i in d["label_inverted"]: + np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) + self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + + set_determinism(seed=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 34e8c62bd5..c430a43e74 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -142,12 +142,12 @@ def test_fail_random_but_not_invertible(self): TestTimeAugmentation(transforms, None, None, None) def test_single_transform(self): - transforms = RandFlipd(["image", "label"]) + transforms = RandFlipd(["image", "label"], prob=1.0) tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x) tta(self.get_data(1, (20, 20))) def test_image_no_label(self): - transforms = RandFlipd(["image"]) + transforms = RandFlipd(["image"], prob=1.0) tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, label_key="image") tta(self.get_data(1, (20, 20), include_label=False))