From ff0892e34be1200ad997273a6625e3a41c68b8b7 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 29 Apr 2021 20:09:50 +0100 Subject: [PATCH 1/2] update compose Signed-off-by: Wenqi Li --- monai/transforms/compose.py | 50 ++++++++++++++++++++++--------------- tests/test_compose.py | 23 +++++++++++++++++ 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ce965b8b18..8654deadc5 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -34,9 +34,9 @@ class Compose(Randomizable, InvertibleTransform): """ - ``Compose`` provides the ability to chain a series of calls together in a - sequence. Each transform in the sequence must take a single argument and - return a single value, so that the transforms can be called in a chain. + ``Compose`` provides the ability to chain a series of callables together in + a sequential manner. Each transform in the sequence must take a single + argument and return a single value. ``Compose`` can be used in two ways: @@ -48,23 +48,30 @@ class Compose(Randomizable, InvertibleTransform): dictionary. It is required that the dictionary is copied between input and output of each transform. - If some transform generates a list batch of data in the transform chain, - every item in the list is still a dictionary, and all the following - transforms will apply to every item of the list, for example: + If some transform takes a data item dictionary as input, and return a list + of data items in the transform chain, and all the following transforms will + be applied to each item of the list, this behaviour is enabled by default + with `map_items=True`. - #. transformA normalizes the intensity of 'img' field in the dict data. - #. transformB crops out a list batch of images on 'img' and 'seg' field. - And constructs a list of dict data, other fields are copied:: + For example: - { [{ { - 'img': [1, 2], 'img': [1], 'img': [2], - 'seg': [1, 2], 'seg': [1], 'seg': [2], - 'extra': 123, --> 'extra': 123, 'extra': 123, - 'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD' - } }, }] + A `Compose([transformA, transformB, transformC], + map_items=True)(data_dict)` could achieve the following patch-based + transformation on the `data_dict` input: - #. transformC then randomly rotates or flips 'img' and 'seg' fields of - every dictionary item in the list. + #. transformA normalizes the intensity of 'img' field in the `data_dict`. + #. transformB crops out image patches from the 'img' and 'seg' of + `data_dict`, and return a list of three patch samples:: + + {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)} + applying transformB + ----------> + [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)}, + {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},] + + #. transformC then randomly rotates or flips 'img' and 'seg' of + each dictionary item in the list returned by transformB. The composed transforms will be set the same global random seed if user called `set_determinism()`. @@ -93,10 +100,13 @@ class Compose(Randomizable, InvertibleTransform): them are called on the labels. """ - def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None) -> None: + def __init__( + self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True + ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) + self.map_items = map_items self.set_random_state(seed=get_seed()) def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": @@ -141,7 +151,7 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = apply_transform(_transform, input_) + input_ = apply_transform(_transform, input_, self.map_items) return input_ def inverse(self, data): @@ -151,5 +161,5 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform(t.inverse, data) + data = apply_transform(t.inverse, data, self.map_items) return data diff --git a/tests/test_compose.py b/tests/test_compose.py index 97b044af8f..77736a4c77 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -79,6 +79,29 @@ def c(d): # transform to handle dict data for item in value: self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + def test_list_dict_compose_no_map(self): + def a(d): # transform to handle dict data + d = dict(d) + d["a"] += 1 + return d + + def b(d): # transform to generate a batch list of data + d = dict(d) + d["b"] += 1 + d = [d] * 5 + return d + + def c(d): # transform to handle dict data + d = [dict(di) for di in d] + for di in d: + di["c"] += 1 + return d + + transforms = Compose([a, a, b, c, c], map_items=False) + value = transforms({"a": 0, "b": 0, "c": 0}) + for item in value: + self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + def test_random_compose(self): class _Acc(Randomizable): self.rand = 0.0 From bf914d4e937ee1c96358aa924d1dcc9ae5a1de5e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 29 Apr 2021 21:25:12 +0100 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/transforms/compose.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 8654deadc5..085d5dcfc4 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -48,10 +48,11 @@ class Compose(Randomizable, InvertibleTransform): dictionary. It is required that the dictionary is copied between input and output of each transform. - If some transform takes a data item dictionary as input, and return a list - of data items in the transform chain, and all the following transforms will - be applied to each item of the list, this behaviour is enabled by default - with `map_items=True`. + If some transform takes a data item dictionary as input, and returns a + sequence of data items in the transform chain, all following transforms + will be applied to each item of this list if `map_items` is `True` (the + default). If `map_items` is `False`, the returned sequence is passed whole + to the next callable in the chain. For example: