Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@

class Compose(Randomizable, InvertibleTransform):
"""
``Compose`` provides the ability to chain a series of calls together in a
sequence. Each transform in the sequence must take a single argument and
return a single value, so that the transforms can be called in a chain.
``Compose`` provides the ability to chain a series of callables together in
a sequential manner. Each transform in the sequence must take a single
argument and return a single value.

``Compose`` can be used in two ways:

Expand All @@ -48,23 +48,31 @@ class Compose(Randomizable, InvertibleTransform):
dictionary. It is required that the dictionary is copied between input
and output of each transform.

If some transform generates a list batch of data in the transform chain,
every item in the list is still a dictionary, and all the following
transforms will apply to every item of the list, for example:
If some transform takes a data item dictionary as input, and returns a
sequence of data items in the transform chain, all following transforms
will be applied to each item of this list if `map_items` is `True` (the
default). If `map_items` is `False`, the returned sequence is passed whole
to the next callable in the chain.

Comment thread
wyli marked this conversation as resolved.
#. transformA normalizes the intensity of 'img' field in the dict data.
#. transformB crops out a list batch of images on 'img' and 'seg' field.
And constructs a list of dict data, other fields are copied::
For example:

{ [{ {
'img': [1, 2], 'img': [1], 'img': [2],
'seg': [1, 2], 'seg': [1], 'seg': [2],
'extra': 123, --> 'extra': 123, 'extra': 123,
'shape': 'CHWD' 'shape': 'CHWD' 'shape': 'CHWD'
} }, }]
A `Compose([transformA, transformB, transformC],
map_items=True)(data_dict)` could achieve the following patch-based
transformation on the `data_dict` input:

#. transformC then randomly rotates or flips 'img' and 'seg' fields of
every dictionary item in the list.
#. transformA normalizes the intensity of 'img' field in the `data_dict`.
#. transformB crops out image patches from the 'img' and 'seg' of
`data_dict`, and return a list of three patch samples::

{'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)}
applying transformB
---------->
[{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},]

#. transformC then randomly rotates or flips 'img' and 'seg' of
each dictionary item in the list returned by transformB.

The composed transforms will be set the same global random seed if user called
`set_determinism()`.
Expand Down Expand Up @@ -93,10 +101,13 @@ class Compose(Randomizable, InvertibleTransform):
them are called on the labels.
"""

def __init__(self, transforms: Optional[Union[Sequence[Callable], Callable]] = None) -> None:
def __init__(
self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True
) -> None:
if transforms is None:
transforms = []
self.transforms = ensure_tuple(transforms)
self.map_items = map_items
self.set_random_state(seed=get_seed())

def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
Expand Down Expand Up @@ -141,7 +152,7 @@ def __len__(self):

def __call__(self, input_):
for _transform in self.transforms:
input_ = apply_transform(_transform, input_)
input_ = apply_transform(_transform, input_, self.map_items)
return input_

def inverse(self, data):
Expand All @@ -151,5 +162,5 @@ def inverse(self, data):

# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(t.inverse, data)
data = apply_transform(t.inverse, data, self.map_items)
return data
23 changes: 23 additions & 0 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ def c(d): # transform to handle dict data
for item in value:
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})

def test_list_dict_compose_no_map(self):
def a(d): # transform to handle dict data
d = dict(d)
d["a"] += 1
return d

def b(d): # transform to generate a batch list of data
d = dict(d)
d["b"] += 1
d = [d] * 5
return d

def c(d): # transform to handle dict data
d = [dict(di) for di in d]
for di in d:
di["c"] += 1
return d

transforms = Compose([a, a, b, c, c], map_items=False)
value = transforms({"a": 0, "b": 0, "c": 0})
for item in value:
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})

def test_random_compose(self):
class _Acc(Randomizable):
self.rand = 0.0
Expand Down