Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np

import monai
from monai.transforms.inverse import InvertibleTransform

# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
Expand Down Expand Up @@ -254,26 +255,27 @@ def __call__(self, data):
_transform = self.transforms[index]
data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats)
# if the data is a mapping (dictionary), append the OneOf transform to the end
if isinstance(data, Mapping):
for key in data.keys():
if self.trace_key(key) in data:
if isinstance(data, monai.data.MetaTensor):
self.push_transform(data, extra_info={"index": index})
elif isinstance(data, Mapping):
for key in data: # dictionary not change size during iteration
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
self.push_transform(data, key, extra_info={"index": index})
return data

def inverse(self, data):
if len(self.transforms) == 0:
return data
if not isinstance(data, Mapping):
raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")

# loop until we get an index and then break (since they'll all be the same)
index = None
for key in data.keys():
if self.trace_key(key) in data:
# get the index of the applied OneOf transform
index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
# and then remove the OneOf transform
self.pop_transform(data, key)
if isinstance(data, monai.data.MetaTensor):
index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
elif isinstance(data, Mapping):
for key in data:
if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
else:
raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.")
if index is None:
# no invertible transforms have been applied
return data
Expand Down
51 changes: 40 additions & 11 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
import numpy as np
from parameterized import parameterized

from monai.data import MetaTensor
from monai.transforms import (
InvertibleTransform,
OneOf,
RandScaleIntensity,
RandScaleIntensityd,
RandShiftIntensity,
RandShiftIntensityd,
Resize,
Resized,
TraceableTransform,
Transform,
Expand Down Expand Up @@ -106,10 +110,10 @@ def __init__(self, keys):

KEYS = ["x", "y"]
TEST_INVERSES = [
(OneOf((InvA(KEYS), InvB(KEYS))), True),
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True),
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True),
(OneOf((NonInv(KEYS), NonInv(KEYS))), False),
(OneOf((InvA(KEYS), InvB(KEYS))), True, True),
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False),
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False),
(OneOf((NonInv(KEYS), NonInv(KEYS))), False, False),
]


Expand Down Expand Up @@ -148,13 +152,17 @@ def _match(a, b):
_match(p, f)

@parameterized.expand(TEST_INVERSES)
def test_inverse(self, transform, invertible):
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
def test_inverse(self, transform, invertible, use_metatensor):
data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
fwd_data = transform(data)

if invertible:
for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
t = (
fwd_data[TraceableTransform.trace_key(k)][-1]
if not use_metatensor
else fwd_data[k].applied_operations[-1]
)
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
Expand All @@ -166,9 +174,11 @@ def test_inverse(self, transform, invertible):
if invertible:
for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
if not use_metatensor:
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)])
< len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
Expand All @@ -186,15 +196,34 @@ def test_inverse_compose(self):
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
]
),
OneOf(
[
RandScaleIntensityd(keys="img", factors=0.5, prob=1.0),
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
]
),
]
)
transform.set_random_state(seed=0)
result = transform({"img": np.ones((1, 101, 102, 103))})

result = transform.inverse(result)
# invert to the original spatial shape
self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103))

def test_inverse_metatensor(self):
transform = Compose(
[
Resize(spatial_size=[100, 100, 100]),
OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
]
)
transform.set_random_state(seed=0)
result = transform(np.ones((1, 101, 102, 103)))
self.assertTupleEqual(result.shape, (1, 100, 100, 100))
result = transform.inverse(result)
self.assertTupleEqual(result.shape, (1, 101, 102, 103))

def test_one_of(self):
p = OneOf((A(), B(), C()), (1, 2, 1))
counts = [0] * 3
Expand Down