diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 32da051ab5..096777cdef 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -160,6 +160,11 @@ Post processing .. autoclass:: PostProcessing :members: +Decollate batch +--------------- +.. autoclass:: DecollateBatch + :members: + Utilities --------- .. automodule:: monai.handlers.utils diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1d76fcaf83..e72118a213 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -18,9 +18,8 @@ from torch.utils.data.distributed import DistributedSampler from monai.config import IgniteInfo -from monai.data import decollate_batch, rep_scalar_to_batch from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch -from monai.transforms import Transform +from monai.transforms import Decollated, Transform from monai.utils import ensure_tuple, min_version, optional_import from .utils import engine_apply_transform @@ -186,8 +185,9 @@ def _register_decollate(self): @self.on(IterationEvents.MODEL_COMPLETED) def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate - engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) - engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) + transform = Decollated(keys=None, detach=True) + engine.state.batch = transform(engine.state.batch) + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 39d75064c2..42a716ced0 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .decollate_batch import DecollateBatch from .earlystop_handler import EarlyStopHandler from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py new file mode 100644 index 0000000000..4e99fc6f04 --- /dev/null +++ b/monai/handlers/decollate_batch.py @@ -0,0 +1,94 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Optional + +from monai.config import IgniteInfo, KeysCollection +from monai.engines.utils import IterationEvents +from monai.transforms import Decollated +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class DecollateBatch: + """ + Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`. + Typical usage is to set `decollate=False` in the engine and execute some postprocessing logic first + then decollate the batch, otherwise, engine will decollate batch before the postprocessing. + + Args: + event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED". + default to "MODEL_COMPLETED". + detach: whether to detach the tensors. scalars tensors will be detached into number types + instead of torch tensors. + decollate_batch: whether to decollate `engine.state.batch` of ignite engine. + batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate + in `engine.state.batch`, note that it will delete other keys not specified. if None, + will decollate all the keys. it replicates the scalar values to every item of the decollated list. + decollate_output: whether to decollate `engine.state.output` of ignite engine. + output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate + in `engine.state.output`, note that it will delete other keys not specified. if None, + will decollate all the keys. it replicates the scalar values to every item of the decollated list. + allow_missing_keys: don't raise exception if key is missing. + + """ + + def __init__( + self, + event: str = "MODEL_COMPLETED", + detach: bool = True, + decollate_batch: bool = True, + batch_keys: Optional[KeysCollection] = None, + decollate_output: bool = True, + output_keys: Optional[KeysCollection] = None, + allow_missing_keys: bool = False, + ): + event = event.upper() + if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): + raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.") + self.event = event + + self.batch_transform = ( + Decollated(keys=batch_keys, detach=detach, allow_missing_keys=allow_missing_keys) + if decollate_batch + else None + ) + + self.output_transform = ( + Decollated(keys=output_keys, detach=detach, allow_missing_keys=allow_missing_keys) + if decollate_output + else None + ) + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + if self.event == "MODEL_COMPLETED": + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + if self.batch_transform is not None: + engine.state.batch = self.batch_transform(engine.state.batch) + if self.output_transform is not None: + engine.state.output = self.output_transform(engine.state.output) diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index cb5342456c..05c6bd414d 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -26,24 +26,35 @@ class PostProcessing: """ Ignite handler to execute additional post processing after the post processing in engines. So users can insert other handlers between engine postprocessing and this post processing handler. + If using components from `monai.transforms` as the `transform`, recommend to decollate `engine.state.batch` + and `engine.state.batch` in the engine(set `decollate=True`) or in the `DecollateBatch` handler first. """ - def __init__(self, transform: Callable) -> None: + def __init__(self, transform: Callable, event: str = "MODEL_COMPLETED") -> None: """ Args: transform: callable function to execute on the `engine.state.batch` and `engine.state.output`. can also be composed transforms. + event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED". + default to "MODEL_COMPLETED". """ self.transform = transform + event = event.upper() + if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): + raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.") + self.event = event def attach(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + if self.event == "MODEL_COMPLETED": + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) def __call__(self, engine: Engine) -> None: """ diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index c6dad2fcd0..d9c6790840 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -10,13 +10,14 @@ # limitations under the License. import warnings -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader as TorchDataLoader +from monai.config import KeysCollection from monai.data.dataloader import DataLoader -from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate +from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Transform @@ -103,18 +104,39 @@ def __call__(self, data: Dict[str, Any]) -> Any: class Decollated(MapTransform): """ - Decollate a batch of data. - Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. + Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys. + Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate + all the data in the input. + And it replicates the scalar values to every item of the decollated list. Args: + keys: keys of the corresponding items to decollate, note that it will delete other keys not specified. + if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys="", detach: bool = True) -> None: - super().__init__(keys=keys) + def __init__( + self, + keys: Optional[KeysCollection] = None, + detach: bool = True, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.detach = detach - def __call__(self, data: dict): - return decollate_batch(data, detach=self.detach) + def __call__(self, data: Union[Dict, List]): + d: Union[Dict, List] + if len(self.keys) == 1 and self.keys[0] is None: + # it doesn't support `None` as the key + d = data + else: + if not isinstance(data, dict): + raise TypeError("input data is not a dictionary, but specified keys to decollate.") + d = {} + for key in self.key_iterator(data): + d[key] = data[key] + + return decollate_batch(rep_scalar_to_batch(d), detach=self.detach) diff --git a/tests/min_tests.py b/tests/min_tests.py index a3f140b856..1cd54f35d0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -134,6 +134,7 @@ def run_testsuit(): "test_unetr", "test_unetr_block", "test_vit", + "test_handler_decollate_batch", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 7d4532fbfd..521d263663 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -210,6 +210,42 @@ def test_dict_examples(self): out = decollate_batch(test_case, detach=False) self.assertEqual(out[0]["out"], "test") + def test_decollated(self): + test_case = { + "image": torch.tensor([[[1, 2]], [[3, 4]]]), + "meta": {"out": ["test", "test"]}, + "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, + "loss": 0.85, + } + transform = Decollated(keys=["meta", "image_meta_dict"], detach=False) + out = transform(test_case) + self.assertFalse("loss" in out) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor)) + # decollate all data with keys=None + transform = Decollated(keys=None, detach=True) + out = transform(test_case) + self.assertEqual(out[1]["loss"], 0.85) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], float)) + + # test list input + test_case = [ + torch.tensor([[[1, 2]], [[3, 4]]]), + {"out": ["test", "test"]}, + {"scl_slope": torch.Tensor((0.0, 0.0))}, + 0.85, + ] + transform = Decollated(keys=None, detach=False) + out = transform(test_case) + # the 4th item in the list is scalar loss value + self.assertEqual(out[1][3], 0.85) + self.assertEqual(out[0][1]["out"], "test") + self.assertEqual(out[0][2]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py new file mode 100644 index 0000000000..bc74cf5328 --- /dev/null +++ b/tests/test_handler_decollate_batch.py @@ -0,0 +1,63 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.engines import SupervisedEvaluator +from monai.handlers import DecollateBatch, PostProcessing +from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd + + +class TestHandlerDecollateBatch(unittest.TestCase): + def test_compute(self): + data = [ + {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, + {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]}, + ] + + handlers = [ + DecollateBatch(event="MODEL_COMPLETED"), + PostProcessing( + transform=Compose( + [ + Activationsd(keys="pred", sigmoid=True), + CopyItemsd(keys="filename", times=1, names="filename_bak"), + AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), + ] + ) + ), + ] + # set up engine, PostProcessing handler works together with postprocessing transforms of engine + engine = SupervisedEvaluator( + device=torch.device("cpu:0"), + val_data_loader=data, + epoch_length=2, + network=torch.nn.PReLU(), + # set decollate=False and execute some postprocessing first, then decollate in handlers + postprocessing=lambda x: dict(pred=x["pred"] + 1.0), + decollate=False, + val_handlers=handlers, + ) + engine.run() + + expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]) + + for o, e in zip(engine.state.output, expected): + torch.testing.assert_allclose(o["pred"], e) + filename = o.get("filename_bak") + if filename is not None: + self.assertEqual(filename, "test2") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 589adfb35d..552cde9eb1 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -28,7 +28,8 @@ CopyItemsd(keys="filename", times=1, names="filename_bak"), AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), ] - ) + ), + "event": "iteration_completed", }, True, torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]),