diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 1d60c34c3e..7b55a993a1 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -17,6 +17,7 @@ import numpy as np +import monai from monai.transforms.inverse import InvertibleTransform # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) @@ -254,26 +255,27 @@ def __call__(self, data): _transform = self.transforms[index] data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats) # if the data is a mapping (dictionary), append the OneOf transform to the end - if isinstance(data, Mapping): - for key in data.keys(): - if self.trace_key(key) in data: + if isinstance(data, monai.data.MetaTensor): + self.push_transform(data, extra_info={"index": index}) + elif isinstance(data, Mapping): + for key in data: # dictionary not change size during iteration + if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: self.push_transform(data, key, extra_info={"index": index}) return data def inverse(self, data): if len(self.transforms) == 0: return data - if not isinstance(data, Mapping): - raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") - # loop until we get an index and then break (since they'll all be the same) index = None - for key in data.keys(): - if self.trace_key(key) in data: - # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] - # and then remove the OneOf transform - self.pop_transform(data, key) + if isinstance(data, monai.data.MetaTensor): + index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"] + elif isinstance(data, Mapping): + for key in data: + if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data: + index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] + else: + raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.") if index is None: # no invertible transforms have been applied return data diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 29d13d7d0c..8171acb6da 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -15,11 +15,15 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms import ( InvertibleTransform, OneOf, + RandScaleIntensity, RandScaleIntensityd, + RandShiftIntensity, RandShiftIntensityd, + Resize, Resized, TraceableTransform, Transform, @@ -106,10 +110,10 @@ def __init__(self, keys): KEYS = ["x", "y"] TEST_INVERSES = [ - (OneOf((InvA(KEYS), InvB(KEYS))), True), - (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True), - (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True), - (OneOf((NonInv(KEYS), NonInv(KEYS))), False), + (OneOf((InvA(KEYS), InvB(KEYS))), True, True), + (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False), + (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False), + (OneOf((NonInv(KEYS), NonInv(KEYS))), False, False), ] @@ -148,13 +152,17 @@ def _match(a, b): _match(p, f) @parameterized.expand(TEST_INVERSES) - def test_inverse(self, transform, invertible): - data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} + def test_inverse(self, transform, invertible, use_metatensor): + data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)} fwd_data = transform(data) if invertible: for k in KEYS: - t = fwd_data[TraceableTransform.trace_key(k)][-1] + t = ( + fwd_data[TraceableTransform.trace_key(k)][-1] + if not use_metatensor + else fwd_data[k].applied_operations[-1] + ) # make sure the OneOf index was stored self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) # make sure index exists and is in bounds @@ -166,9 +174,11 @@ def test_inverse(self, transform, invertible): if invertible: for k in KEYS: # check transform was removed - self.assertTrue( - len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) - ) + if not use_metatensor: + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) + < len(fwd_data[TraceableTransform.trace_key(k)]) + ) # check data is same as original (and different from forward) self.assertEqual(fwd_inv_data[k], data[k]) self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) @@ -186,15 +196,34 @@ def test_inverse_compose(self): RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), ] ), + OneOf( + [ + RandScaleIntensityd(keys="img", factors=0.5, prob=1.0), + RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), + ] + ), ] ) transform.set_random_state(seed=0) result = transform({"img": np.ones((1, 101, 102, 103))}) - result = transform.inverse(result) # invert to the original spatial shape self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103)) + def test_inverse_metatensor(self): + transform = Compose( + [ + Resize(spatial_size=[100, 100, 100]), + OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]), + OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]), + ] + ) + transform.set_random_state(seed=0) + result = transform(np.ones((1, 101, 102, 103))) + self.assertTupleEqual(result.shape, (1, 100, 100, 100)) + result = transform.inverse(result) + self.assertTupleEqual(result.shape, (1, 101, 102, 103)) + def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1)) counts = [0] * 3