diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ce965b8b18..085d5dcfc4 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -34,9 +34,9 @@ class Compose(Randomizable, InvertibleTransform): """ - ``Compose`` provides the ability to chain a series of calls together in a - sequence. Each transform in the sequence must take a single argument and - return a single value, so that the transforms can be called in a chain. + ``Compose`` provides the ability to chain a series of callables together in + a sequential manner. Each transform in the sequence must take a single + argument and return a single value. ``Compose`` can be used in two ways: @@ -48,23 +48,31 @@ class Compose(Randomizable, InvertibleTransform): dictionary. It is required that the dictionary is copied between input and output of each transform. - If some transform generates a list batch of data in the transform chain, - every item in the list is still a dictionary, and all the following - transforms will apply to every item of the list, for example: + If some transform takes a data item dictionary as input, and returns a + sequence of data items in the transform chain, all following transforms + will be applied to each item of this list if `map_items` is `True` (the + default). If `map_items` is `False`, the returned sequence is passed whole + to the next callable in the chain. - #. transformA normalizes the intensity of 'img' field in the dict data. - #. transformB crops out a list batch of images on 'img' and 'seg' field. - And constructs a list of dict data, other fields are copied:: + For example: - { [{ { - 'img': [1, 2], 'img': [1], 'img': [2], - 'seg': [1, 2], 'seg': [1], 'seg': [2], - 'extra': 123, --> 'extra': 123, 'extra': 123, - 'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD' - } }, }] + A `Compose([transformA, transformB, transformC], + map_items=True)(data_dict)` could achieve the following patch-based + transformation on the `data_dict` input: - #. transformC then randomly rotates or flips 'img' and 'seg' fields of - every dictionary item in the list. + #. transformA normalizes the intensity of 'img' field in the `data_dict`. + #. transformB crops out image patches from the 'img' and 'seg' of + `data_dict`, and return a list of three patch samples:: + + {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)} + applying transformB + ----------> + [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},] + + #. transformC then randomly rotates or flips 'img' and 'seg' of + each dictionary item in the list returned by transformB. The composed transforms will be set the same global random seed if user called `set_determinism()`. @@ -93,10 +101,13 @@ class Compose(Randomizable, InvertibleTransform): them are called on the labels. """ - def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None) -> None: + def __init__( + self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True + ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) + self.map_items = map_items self.set_random_state(seed=get_seed()) def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": @@ -141,7 +152,7 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = apply_transform(_transform, input_) + input_ = apply_transform(_transform, input_, self.map_items) return input_ def inverse(self, data): @@ -151,5 +162,5 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform(t.inverse, data) + data = apply_transform(t.inverse, data, self.map_items) return data diff --git a/tests/test_compose.py b/tests/test_compose.py index 97b044af8f..77736a4c77 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -79,6 +79,29 @@ def c(d): # transform to handle dict data for item in value: self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + def test_list_dict_compose_no_map(self): + def a(d): # transform to handle dict data + d = dict(d) + d["a"] += 1 + return d + + def b(d): # transform to generate a batch list of data + d = dict(d) + d["b"] += 1 + d = [d] * 5 + return d + + def c(d): # transform to handle dict data + d = [dict(di) for di in d] + for di in d: + di["c"] += 1 + return d + + transforms = Compose([a, a, b, c, c], map_items=False) + value = transforms({"a": 0, "b": 0, "c": 0}) + for item in value: + self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + def test_random_compose(self): class _Acc(Randomizable): self.rand = 0.0