Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ Post processing
.. autoclass:: PostProcessing
:members:

Decollate batch
---------------
.. autoclass:: DecollateBatch
:members:

Utilities
---------
.. automodule:: monai.handlers.utils
Expand Down
8 changes: 4 additions & 4 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from torch.utils.data.distributed import DistributedSampler

from monai.config import IgniteInfo
from monai.data import decollate_batch, rep_scalar_to_batch
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.transforms import Transform
from monai.transforms import Decollated, Transform
from monai.utils import ensure_tuple, min_version, optional_import

from .utils import engine_apply_transform
Expand Down Expand Up @@ -186,8 +185,9 @@ def _register_decollate(self):
@self.on(IterationEvents.MODEL_COMPLETED)
def _decollate_data(engine: Engine) -> None:
# replicate the scalar values to make sure all the items have batch dimension, then decollate
engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True)
engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True)
transform = Decollated(keys=None, detach=True)
engine.state.batch = transform(engine.state.batch)
engine.state.output = transform(engine.state.output)

def _register_postprocessing(self, posttrans: Callable):
"""
Expand Down
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .checkpoint_saver import CheckpointSaver
from .classification_saver import ClassificationSaver
from .confusion_matrix import ConfusionMatrix
from .decollate_batch import DecollateBatch
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
Expand Down
94 changes: 94 additions & 0 deletions monai/handlers/decollate_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from monai.config import IgniteInfo, KeysCollection
from monai.engines.utils import IterationEvents
from monai.transforms import Decollated
from monai.utils import min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class DecollateBatch:
"""
Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`.
Typical usage is to set `decollate=False` in the engine and execute some postprocessing logic first
then decollate the batch, otherwise, engine will decollate batch before the postprocessing.

Args:
event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED".
default to "MODEL_COMPLETED".
detach: whether to detach the tensors. scalars tensors will be detached into number types
instead of torch tensors.
decollate_batch: whether to decollate `engine.state.batch` of ignite engine.
batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate
in `engine.state.batch`, note that it will delete other keys not specified. if None,
will decollate all the keys. it replicates the scalar values to every item of the decollated list.
decollate_output: whether to decollate `engine.state.output` of ignite engine.
output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate
in `engine.state.output`, note that it will delete other keys not specified. if None,
will decollate all the keys. it replicates the scalar values to every item of the decollated list.
allow_missing_keys: don't raise exception if key is missing.

"""

def __init__(
self,
event: str = "MODEL_COMPLETED",
detach: bool = True,
decollate_batch: bool = True,
batch_keys: Optional[KeysCollection] = None,
decollate_output: bool = True,
output_keys: Optional[KeysCollection] = None,
allow_missing_keys: bool = False,
):
event = event.upper()
if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"):
raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.")
self.event = event

self.batch_transform = (
Decollated(keys=batch_keys, detach=detach, allow_missing_keys=allow_missing_keys)
if decollate_batch
else None
)

self.output_transform = (
Decollated(keys=output_keys, detach=detach, allow_missing_keys=allow_missing_keys)
if decollate_output
else None
)

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if self.event == "MODEL_COMPLETED":
engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)
else:
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
Comment thread
Nic-Ma marked this conversation as resolved.

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
if self.batch_transform is not None:
engine.state.batch = self.batch_transform(engine.state.batch)
if self.output_transform is not None:
engine.state.output = self.output_transform(engine.state.output)
15 changes: 13 additions & 2 deletions monai/handlers/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,35 @@ class PostProcessing:
"""
Ignite handler to execute additional post processing after the post processing in engines.
So users can insert other handlers between engine postprocessing and this post processing handler.
If using components from `monai.transforms` as the `transform`, recommend to decollate `engine.state.batch`
and `engine.state.batch` in the engine(set `decollate=True`) or in the `DecollateBatch` handler first.

"""

def __init__(self, transform: Callable) -> None:
def __init__(self, transform: Callable, event: str = "MODEL_COMPLETED") -> None:
"""
Args:
transform: callable function to execute on the `engine.state.batch` and `engine.state.output`.
can also be composed transforms.
event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED".
default to "MODEL_COMPLETED".

"""
self.transform = transform
event = event.upper()
if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"):
raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.")
self.event = event

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)
if self.event == "MODEL_COMPLETED":
engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)
else:
engine.add_event_handler(Events.ITERATION_COMPLETED, self)

def __call__(self, engine: Engine) -> None:
"""
Expand Down
38 changes: 30 additions & 8 deletions monai/transforms/inverse_batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
# limitations under the License.

import warnings
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader as TorchDataLoader

from monai.config import KeysCollection
from monai.data.dataloader import DataLoader
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Transform
Expand Down Expand Up @@ -103,18 +104,39 @@ def __call__(self, data: Dict[str, Any]) -> Any:

class Decollated(MapTransform):
"""
Decollate a batch of data.
Note that unlike most MapTransforms, this will decollate all data, so keys are not needed.
Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys.
Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate
all the data in the input.
And it replicates the scalar values to every item of the decollated list.

Args:
keys: keys of the corresponding items to decollate, note that it will delete other keys not specified.
if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`.
detach: whether to detach the tensors. Scalars tensors will be detached into number types
instead of torch tensors.
allow_missing_keys: don't raise exception if key is missing.

"""

def __init__(self, keys="", detach: bool = True) -> None:
super().__init__(keys=keys)
def __init__(
self,
keys: Optional[KeysCollection] = None,
detach: bool = True,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.detach = detach

def __call__(self, data: dict):
return decollate_batch(data, detach=self.detach)
def __call__(self, data: Union[Dict, List]):
d: Union[Dict, List]
if len(self.keys) == 1 and self.keys[0] is None:
# it doesn't support `None` as the key
d = data
else:
if not isinstance(data, dict):
raise TypeError("input data is not a dictionary, but specified keys to decollate.")
d = {}
for key in self.key_iterator(data):
d[key] = data[key]

return decollate_batch(rep_scalar_to_batch(d), detach=self.detach)
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def run_testsuit():
"test_unetr",
"test_unetr_block",
"test_vit",
"test_handler_decollate_batch",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
36 changes: 36 additions & 0 deletions tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,42 @@ def test_dict_examples(self):
out = decollate_batch(test_case, detach=False)
self.assertEqual(out[0]["out"], "test")

def test_decollated(self):
test_case = {
"image": torch.tensor([[[1, 2]], [[3, 4]]]),
"meta": {"out": ["test", "test"]},
"image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))},
"loss": 0.85,
}
transform = Decollated(keys=["meta", "image_meta_dict"], detach=False)
out = transform(test_case)
self.assertFalse("loss" in out)
self.assertEqual(out[0]["meta"]["out"], "test")
self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0)
self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor))
# decollate all data with keys=None
transform = Decollated(keys=None, detach=True)
out = transform(test_case)
self.assertEqual(out[1]["loss"], 0.85)
self.assertEqual(out[0]["meta"]["out"], "test")
self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0)
self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], float))

# test list input
test_case = [
torch.tensor([[[1, 2]], [[3, 4]]]),
{"out": ["test", "test"]},
{"scl_slope": torch.Tensor((0.0, 0.0))},
0.85,
]
transform = Decollated(keys=None, detach=False)
out = transform(test_case)
# the 4th item in the list is scalar loss value
self.assertEqual(out[1][3], 0.85)
self.assertEqual(out[0][1]["out"], "test")
self.assertEqual(out[0][2]["scl_slope"], 0.0)
self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor))


if __name__ == "__main__":
unittest.main()
63 changes: 63 additions & 0 deletions tests/test_handler_decollate_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from monai.engines import SupervisedEvaluator
from monai.handlers import DecollateBatch, PostProcessing
from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd


class TestHandlerDecollateBatch(unittest.TestCase):
def test_compute(self):
data = [
{"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]},
{"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]},
]

handlers = [
DecollateBatch(event="MODEL_COMPLETED"),
PostProcessing(
transform=Compose(
[
Activationsd(keys="pred", sigmoid=True),
CopyItemsd(keys="filename", times=1, names="filename_bak"),
AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2),
]
)
),
]
# set up engine, PostProcessing handler works together with postprocessing transforms of engine
engine = SupervisedEvaluator(
device=torch.device("cpu:0"),
val_data_loader=data,
epoch_length=2,
network=torch.nn.PReLU(),
# set decollate=False and execute some postprocessing first, then decollate in handlers
postprocessing=lambda x: dict(pred=x["pred"] + 1.0),
decollate=False,
val_handlers=handlers,
)
engine.run()

expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])

for o, e in zip(engine.state.output, expected):
torch.testing.assert_allclose(o["pred"], e)
filename = o.get("filename_bak")
if filename is not None:
self.assertEqual(filename, "test2")


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion tests/test_handler_post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
CopyItemsd(keys="filename", times=1, names="filename_bak"),
AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2),
]
)
),
"event": "iteration_completed",
},
True,
torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]),
Expand Down