Skip to content
Merged
4 changes: 0 additions & 4 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ ThreadBuffer
.. autoclass:: monai.data.ThreadBuffer


BatchInverseTransform
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.BatchInverseTransform

TestTimeAugmentation
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: monai.data.TestTimeAugmentation
11 changes: 11 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ Generic Interfaces
.. autoclass:: InvertibleTransform
:members:

`BatchInverseTransform`
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: BatchInverseTransform
:members:


Vanilla Transforms
------------------
Expand Down Expand Up @@ -836,6 +841,12 @@ Post-processing (Dict)
:members:
:special-members: __call__

`Invertd`
"""""""""
.. autoclass:: Invertd
:members:
:special-members: __call__

Spatial (Dict)
^^^^^^^^^^^^^^

Expand Down
1 change: 0 additions & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .inverse_batch_transform import BatchInverseTransform
from .iterable_dataset import IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
2 changes: 1 addition & 1 deletion monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.inverse_batch_transform import BatchInverseTransform
from monai.data.utils import list_data_collate, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse_batch_transform import BatchInverseTransform
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode
from monai.utils.enums import CommonKeys, InverseKeys
Expand Down
70 changes: 22 additions & 48 deletions monai/handlers/transform_inverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union

import torch
from torch.utils.data import DataLoader as TorchDataLoader

from monai.data import BatchInverseTransform
from monai.data.utils import no_collation
from monai.engines.utils import CommonKeys, IterationEvents
from monai.transforms import InvertibleTransform, ToTensor, allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import InverseKeys, ensure_tuple, ensure_tuple_rep, exact_version, optional_import
from monai.transforms import Invertd, InvertibleTransform
from monai.utils import ensure_tuple, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
Expand All @@ -33,6 +30,7 @@ class TransformInverter:
"""
Ignite handler to automatically invert `transforms`.
It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`.
Expect both `engine.state.output` and `engine.state.batch` to be dictionary data.
The inverted data are stored in `engine.state.output` with key: "{output_key}_{postfix}".
And the inverted meta dict will be stored in `engine.state.batch`
with key: "{output_key}_{postfix}_{meta_key_postfix}".
Expand Down Expand Up @@ -85,22 +83,23 @@ def __init__(
Set to `None`, to use the `num_workers` of the input transform data loader.

"""
self.transform = transform
self.inverter = BatchInverseTransform(
self.inverter = Invertd(
keys=output_keys,
transform=transform,
loader=loader,
orig_keys=batch_keys,
meta_key_postfix=meta_key_postfix,
collate_fn=collate_fn,
postfix=postfix,
nearest_interp=nearest_interp,
to_tensor=to_tensor,
device=device,
post_func=post_func,
num_workers=num_workers,
)
self.output_keys = ensure_tuple(output_keys)
self.batch_keys = ensure_tuple_rep(batch_keys, len(self.output_keys))
self.meta_key_postfix = meta_key_postfix
self.postfix = postfix
self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.output_keys))
self.to_tensor = ensure_tuple_rep(to_tensor, len(self.output_keys))
self.device = ensure_tuple_rep(device, len(self.output_keys))
self.post_func = ensure_tuple_rep(post_func, len(self.output_keys))
self._totensor = ToTensor()

def attach(self, engine: Engine) -> None:
"""
Expand All @@ -114,42 +113,17 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
for output_key, batch_key, nearest_interp, to_tensor, device, post_func in zip(
self.output_keys, self.batch_keys, self.nearest_interp, self.to_tensor, self.device, self.post_func
):
transform_key = batch_key + InverseKeys.KEY_SUFFIX
if transform_key not in engine.state.batch:
warnings.warn(f"all the transforms on `{batch_key}` are not InvertibleTransform.")
continue

transform_info = engine.state.batch[transform_key]
if nearest_interp:
transform_info = convert_inverse_interp_mode(
trans_info=deepcopy(transform_info),
mode="nearest",
align_corners=None,
)

output = engine.state.output[output_key]
if isinstance(output, torch.Tensor):
output = output.detach()
segs_dict = {
batch_key: output,
transform_key: transform_info,
}
meta_dict_key = f"{batch_key}_{self.meta_key_postfix}"
if meta_dict_key in engine.state.batch:
segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key]

with allow_missing_keys_mode(self.transform): # type: ignore
inverted = self.inverter(segs_dict)
# combine `batch` and `output` to temporarily act as 1 dict for post transform
data = dict(engine.state.batch)
Comment thread
Nic-Ma marked this conversation as resolved.
data.update(engine.state.output)
ret = self.inverter(data)

for output_key in self.output_keys:
# save the inverted data into state.output
inverted_key = f"{output_key}_{self.postfix}"
engine.state.output[inverted_key] = [
post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) for i in inverted
]

if inverted_key in ret:
engine.state.output[inverted_key] = ret[inverted_key]
# save the inverted meta dict into state.batch
if meta_dict_key in engine.state.batch:
engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [i.get(meta_dict_key) for i in inverted]
meta_dict_key = f"{inverted_key}_{self.meta_key_postfix}"
if meta_dict_key in ret:
engine.state.batch[meta_dict_key] = ret[meta_dict_key]
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
ThresholdIntensityDict,
)
from .inverse import InvertibleTransform
from .inverse_batch_transform import BatchInverseTransform
from .io.array import LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .post.array import (
Expand All @@ -178,6 +179,9 @@
DecollateD,
DecollateDict,
Ensembled,
Invertd,
InvertD,
InvertDict,
KeepLargestConnectedComponentd,
KeepLargestConnectedComponentD,
KeepLargestConnectedComponentDict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
# limitations under the License.

import warnings
from typing import Any, Callable, Dict, Hashable, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader as TorchDataLoader

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
Expand All @@ -33,11 +32,11 @@ def __init__(
transform: InvertibleTransform,
pad_collation_used: bool,
) -> None:
super().__init__(data, transform)
self.data = data
self.invertible_transform = transform
self.pad_collation_used = pad_collation_used

def _transform(self, index: int) -> Dict[Hashable, np.ndarray]:
def __getitem__(self, index: int):
data = dict(self.data[index])
# If pad collation was used, then we need to undo this first
if self.pad_collation_used:
Expand All @@ -48,6 +47,9 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]:
return data
return self.invertible_transform.inverse(data)

def __len__(self) -> int:
return len(self.data)


class BatchInverseTransform(Transform):
"""
Expand Down
Loading