From 781ec5c72eaa3716526672b517e2706bb2202e06 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 12 Jul 2021 22:39:16 +0800 Subject: [PATCH 1/8] [DLMED] add DecollateBatch handler Signed-off-by: Nic Ma --- docs/source/handlers.rst | 5 ++++ monai/handlers/__init__.py | 1 + monai/handlers/decollate_batch.py | 49 +++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) create mode 100644 monai/handlers/decollate_batch.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 32da051ab5..096777cdef 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -160,6 +160,11 @@ Post processing .. autoclass:: PostProcessing :members: +Decollate batch +--------------- +.. autoclass:: DecollateBatch + :members: + Utilities --------- .. automodule:: monai.handlers.utils diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 39d75064c2..42a716ced0 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -13,6 +13,7 @@ from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver from .confusion_matrix import ConfusionMatrix +from .decollate_batch import DecollateBatch from .earlystop_handler import EarlyStopHandler from .garbage_collector import GarbageCollector from .hausdorff_distance import HausdorffDistance diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py new file mode 100644 index 0000000000..1282e496e3 --- /dev/null +++ b/monai/handlers/decollate_batch.py @@ -0,0 +1,49 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import TYPE_CHECKING, Callable + +from monai.config import IgniteInfo +from monai.data import decollate_batch, rep_scalar_to_batch +from monai.engines.utils import IterationEvents, engine_apply_transform +from monai.utils import min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class DecollateBatch: + """ + Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`. + So users can set `decollate=False` in the engine and execute some postprocessing logic first + then decollate the batch, otherwise, engine will decollate batch before the postprocessing. + + """ + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + + def __call__(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + # replicate the scalar values to make sure all the items have batch dimension, then decollate + engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) + engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) From 36225cba54079da4298467a9680b59be39e34dbc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 12 Jul 2021 23:06:12 +0800 Subject: [PATCH 2/8] [DLMED] add unit tests Signed-off-by: Nic Ma --- monai/handlers/decollate_batch.py | 19 ++++++-- monai/handlers/postprocessing.py | 13 +++++- tests/test_handler_decollate_batch.py | 63 +++++++++++++++++++++++++++ tests/test_handler_post_processing.py | 3 +- 4 files changed, 91 insertions(+), 7 deletions(-) create mode 100644 tests/test_handler_decollate_batch.py diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 1282e496e3..1e02950173 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -9,12 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING from monai.config import IgniteInfo from monai.data import decollate_batch, rep_scalar_to_batch -from monai.engines.utils import IterationEvents, engine_apply_transform +from monai.engines.utils import IterationEvents from monai.utils import min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -30,14 +29,26 @@ class DecollateBatch: So users can set `decollate=False` in the engine and execute some postprocessing logic first then decollate the batch, otherwise, engine will decollate batch before the postprocessing. + Args: + event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED". + default to "MODEL_COMPLETED". + """ + def __init__(self, event: str = "MODEL_COMPLETED"): + event = event.upper() + if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): + raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.") + self.event = event def attach(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + if self.event == "MODEL_COMPLETED": + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) def __call__(self, engine: Engine) -> None: """ diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index 8732c4ad80..8e9b04b649 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -30,21 +30,30 @@ class PostProcessing: """ - def __init__(self, transform: Callable) -> None: + def __init__(self, transform: Callable, event: str = "MODEL_COMPLETED") -> None: """ Args: transform: callable function to execute on the `engine.state.batch` and `engine.state.output`. can also be composed transforms. + event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED". + default to "MODEL_COMPLETED". """ self.transform = transform + event = event.upper() + if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): + raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.") + self.event = event def attach(self, engine: Engine) -> None: """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + if self.event == "MODEL_COMPLETED": + engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self) + else: + engine.add_event_handler(Events.ITERATION_COMPLETED, self) def __call__(self, engine: Engine) -> None: """ diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py new file mode 100644 index 0000000000..bc74cf5328 --- /dev/null +++ b/tests/test_handler_decollate_batch.py @@ -0,0 +1,63 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.engines import SupervisedEvaluator +from monai.handlers import DecollateBatch, PostProcessing +from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd + + +class TestHandlerDecollateBatch(unittest.TestCase): + def test_compute(self): + data = [ + {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, + {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]}, + ] + + handlers = [ + DecollateBatch(event="MODEL_COMPLETED"), + PostProcessing( + transform=Compose( + [ + Activationsd(keys="pred", sigmoid=True), + CopyItemsd(keys="filename", times=1, names="filename_bak"), + AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), + ] + ) + ), + ] + # set up engine, PostProcessing handler works together with postprocessing transforms of engine + engine = SupervisedEvaluator( + device=torch.device("cpu:0"), + val_data_loader=data, + epoch_length=2, + network=torch.nn.PReLU(), + # set decollate=False and execute some postprocessing first, then decollate in handlers + postprocessing=lambda x: dict(pred=x["pred"] + 1.0), + decollate=False, + val_handlers=handlers, + ) + engine.run() + + expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]) + + for o, e in zip(engine.state.output, expected): + torch.testing.assert_allclose(o["pred"], e) + filename = o.get("filename_bak") + if filename is not None: + self.assertEqual(filename, "test2") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 89638cdac9..753dfdc539 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -28,7 +28,8 @@ CopyItemsd(keys="filename", times=1, names="filename_bak"), AsDiscreted(keys="pred", threshold_values=True, to_onehot=True, n_classes=2), ] - ) + ), + "event": "iteration_completed", }, torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]), ] From 82c630ad983742cb1de9442a8110217a2cbe540f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 13 Jul 2021 07:19:04 +0800 Subject: [PATCH 3/8] [DLMED] enhance doc-string Signed-off-by: Nic Ma --- monai/handlers/decollate_batch.py | 2 +- monai/handlers/postprocessing.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 1e02950173..326e4e16ab 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -26,7 +26,7 @@ class DecollateBatch: """ Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`. - So users can set `decollate=False` in the engine and execute some postprocessing logic first + Typical usage is to set `decollate=False` in the engine and execute some postprocessing logic first then decollate the batch, otherwise, engine will decollate batch before the postprocessing. Args: diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py index 8e9b04b649..8c39631f1a 100644 --- a/monai/handlers/postprocessing.py +++ b/monai/handlers/postprocessing.py @@ -27,6 +27,8 @@ class PostProcessing: """ Ignite handler to execute additional post processing after the post processing in engines. So users can insert other handlers between engine postprocessing and this post processing handler. + If using components from `monai.transforms` as the `transform`, recommend to decollate `engine.state.batch` + and `engine.state.batch` in the engine(set `decollate=True`) or in the `DecollateBatch` handler first. """ From 67b4e34dfd14fc89d0888e25adeadc22d83c5475 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 12 Jul 2021 23:24:08 +0000 Subject: [PATCH 4/8] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/decollate_batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 326e4e16ab..2039e53093 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -34,6 +34,7 @@ class DecollateBatch: default to "MODEL_COMPLETED". """ + def __init__(self, event: str = "MODEL_COMPLETED"): event = event.upper() if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): From 591fa0995c53df690912b66441afcdeb98888746 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 13 Jul 2021 07:51:05 +0800 Subject: [PATCH 5/8] [DLMED] skip in min tests Signed-off-by: Nic Ma --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index a3f140b856..1cd54f35d0 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -134,6 +134,7 @@ def run_testsuit(): "test_unetr", "test_unetr_block", "test_vit", + "test_handler_decollate_batch", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 70a96657a6ff2601b25c2ee14fe7cd4b466613f1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 13 Jul 2021 23:48:12 +0800 Subject: [PATCH 6/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/workflow.py | 8 +-- monai/handlers/decollate_batch.py | 59 ++++++++++++++++++--- monai/transforms/inverse_batch_transform.py | 46 ++++++++++++---- tests/test_decollate.py | 36 +++++++++++++ 4 files changed, 129 insertions(+), 20 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 1d76fcaf83..2d3f069966 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -18,9 +18,8 @@ from torch.utils.data.distributed import DistributedSampler from monai.config import IgniteInfo -from monai.data import decollate_batch, rep_scalar_to_batch from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch -from monai.transforms import Transform +from monai.transforms import Decollated, Transform from monai.utils import ensure_tuple, min_version, optional_import from .utils import engine_apply_transform @@ -186,8 +185,9 @@ def _register_decollate(self): @self.on(IterationEvents.MODEL_COMPLETED) def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate - engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) - engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) + transform = Decollated(keys=None, detach=True, rep_scalar=True) + engine.state.batch = transform(engine.state.batch) + engine.state.output = transform(engine.state.output) def _register_postprocessing(self, posttrans: Callable): """ diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 2039e53093..061bda25d4 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -9,11 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from monai.config import IgniteInfo -from monai.data import decollate_batch, rep_scalar_to_batch +from monai.config import IgniteInfo, KeysCollection from monai.engines.utils import IterationEvents +from monai.transforms import Decollated from monai.utils import min_version, optional_import Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -32,15 +32,59 @@ class DecollateBatch: Args: event: expected EVENT to attach the handler, should be "MODEL_COMPLETED" or "ITERATION_COMPLETED". default to "MODEL_COMPLETED". + detach: whether to detach the tensors. scalars tensors will be detached into number types + instead of torch tensors. + rep_scalar: whether to replicate the scalar values to every decollated item of the list. + decollate_batch: whether to decollate `engine.state.batch` of ignite engine. + batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate + in `engine.state.batch`, note that it will delete other keys not specified. if None, + will decollate all the keys. + decollate_output: whether to decollate `engine.state.output` of ignite engine. + output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate + in `engine.state.output`, note that it will delete other keys not specified. if None, + will decollate all the keys. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, event: str = "MODEL_COMPLETED"): + def __init__( + self, + event: str = "MODEL_COMPLETED", + detach: bool = True, + rep_scalar: bool = True, + decollate_batch: bool = True, + batch_keys: Optional[KeysCollection] = None, + decollate_output: bool = True, + output_keys: Optional[KeysCollection] = None, + allow_missing_keys: bool = False, + ): event = event.upper() if event not in ("MODEL_COMPLETED", "ITERATION_COMPLETED"): raise ValueError("event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.") self.event = event + self.batch_transform = ( + Decollated( + keys=batch_keys, + detach=detach, + rep_scalar=rep_scalar, + allow_missing_keys=allow_missing_keys, + ) + if decollate_batch + else None + ) + + self.output_transform = ( + Decollated( + keys=output_keys, + detach=detach, + rep_scalar=rep_scalar, + allow_missing_keys=allow_missing_keys, + ) + if decollate_output + else None + ) + def attach(self, engine: Engine) -> None: """ Args: @@ -56,6 +100,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - # replicate the scalar values to make sure all the items have batch dimension, then decollate - engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) - engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) + if self.batch_transform is not None: + engine.state.batch = self.batch_transform(engine.state.batch) + if self.output_transform is not None: + engine.state.output = self.output_transform(engine.state.output) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index c6dad2fcd0..cb07c9bd7c 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -10,13 +10,15 @@ # limitations under the License. import warnings -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import torch from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader as TorchDataLoader +from monai.config import KeysCollection from monai.data.dataloader import DataLoader -from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate +from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate, rep_scalar_to_batch from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Transform @@ -103,18 +105,44 @@ def __call__(self, data: Dict[str, Any]) -> Any: class Decollated(MapTransform): """ - Decollate a batch of data. - Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. + Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys. + Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate + all the data in the input. Args: + keys: keys of the corresponding items to decollate, note that it will delete other keys not specified. + if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. + rep_scalar: whether to replicate the scalar values to every decollated item of the list. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, keys="", detach: bool = True) -> None: - super().__init__(keys=keys) + def __init__( + self, + keys: Optional[KeysCollection] = None, + detach: bool = True, + rep_scalar: bool = True, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) self.detach = detach - - def __call__(self, data: dict): - return decollate_batch(data, detach=self.detach) + self.rep_scalar = rep_scalar + + def __call__(self, data: Union[Dict, List]): + d: Union[Dict, List] + if len(self.keys) == 1 and self.keys[0] is None: + # it doesn't support `None` as the key + d = data + else: + if not isinstance(data, dict): + raise TypeError("input data is not a dictionary, but specified keys to decollate.") + d = {} + for key in self.key_iterator(data): + d[key] = data[key] + + if self.rep_scalar: + d = rep_scalar_to_batch(d) + + return decollate_batch(d, detach=self.detach) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 7d4532fbfd..ca2aa93430 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -210,6 +210,42 @@ def test_dict_examples(self): out = decollate_batch(test_case, detach=False) self.assertEqual(out[0]["out"], "test") + def test_decollated(self): + test_case = { + "image": torch.tensor([[[1, 2]], [[3, 4]]]), + "meta": {"out": ["test", "test"]}, + "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, + "loss": 0.85, + } + transform = Decollated(keys=["meta", "image_meta_dict"], detach=False, rep_scalar=False) + out = transform(test_case) + self.assertFalse("loss" in out) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor)) + # decollate all data with rep_scalar=True + transform = Decollated(keys=None, detach=True, rep_scalar=True) + out = transform(test_case) + self.assertEqual(out[1]["loss"], 0.85) + self.assertEqual(out[0]["meta"]["out"], "test") + self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], float)) + + # test list input + test_case = [ + torch.tensor([[[1, 2]], [[3, 4]]]), + {"out": ["test", "test"]}, + {"scl_slope": torch.Tensor((0.0, 0.0))}, + 0.85, + ] + transform = Decollated(keys=None, detach=False, rep_scalar=True) + out = transform(test_case) + # the 4th item in the list is scalar loss value + self.assertEqual(out[1][3], 0.85) + self.assertEqual(out[0][1]["out"], "test") + self.assertEqual(out[0][2]["scl_slope"], 0.0) + self.assertTrue(isinstance(out[0][2]["scl_slope"], torch.Tensor)) + if __name__ == "__main__": unittest.main() From c9b696c897c84e82bc1886027a9aef4eb33e836b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 13 Jul 2021 23:56:40 +0800 Subject: [PATCH 7/8] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/transforms/inverse_batch_transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index cb07c9bd7c..37f2f8382b 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -12,7 +12,6 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Sequence, Union -import torch from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader as TorchDataLoader From 182df874fa168ba70edb2a78da394c21a598e25b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 14 Jul 2021 06:30:53 +0800 Subject: [PATCH 8/8] [DLMED] remove rep_scalar option Signed-off-by: Nic Ma --- monai/engines/workflow.py | 2 +- monai/handlers/decollate_batch.py | 20 ++++---------------- monai/transforms/inverse_batch_transform.py | 9 ++------- tests/test_decollate.py | 8 ++++---- 4 files changed, 11 insertions(+), 28 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 2d3f069966..e72118a213 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -185,7 +185,7 @@ def _register_decollate(self): @self.on(IterationEvents.MODEL_COMPLETED) def _decollate_data(engine: Engine) -> None: # replicate the scalar values to make sure all the items have batch dimension, then decollate - transform = Decollated(keys=None, detach=True, rep_scalar=True) + transform = Decollated(keys=None, detach=True) engine.state.batch = transform(engine.state.batch) engine.state.output = transform(engine.state.output) diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py index 061bda25d4..4e99fc6f04 100644 --- a/monai/handlers/decollate_batch.py +++ b/monai/handlers/decollate_batch.py @@ -34,15 +34,14 @@ class DecollateBatch: default to "MODEL_COMPLETED". detach: whether to detach the tensors. scalars tensors will be detached into number types instead of torch tensors. - rep_scalar: whether to replicate the scalar values to every decollated item of the list. decollate_batch: whether to decollate `engine.state.batch` of ignite engine. batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate in `engine.state.batch`, note that it will delete other keys not specified. if None, - will decollate all the keys. + will decollate all the keys. it replicates the scalar values to every item of the decollated list. decollate_output: whether to decollate `engine.state.output` of ignite engine. output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate in `engine.state.output`, note that it will delete other keys not specified. if None, - will decollate all the keys. + will decollate all the keys. it replicates the scalar values to every item of the decollated list. allow_missing_keys: don't raise exception if key is missing. """ @@ -51,7 +50,6 @@ def __init__( self, event: str = "MODEL_COMPLETED", detach: bool = True, - rep_scalar: bool = True, decollate_batch: bool = True, batch_keys: Optional[KeysCollection] = None, decollate_output: bool = True, @@ -64,23 +62,13 @@ def __init__( self.event = event self.batch_transform = ( - Decollated( - keys=batch_keys, - detach=detach, - rep_scalar=rep_scalar, - allow_missing_keys=allow_missing_keys, - ) + Decollated(keys=batch_keys, detach=detach, allow_missing_keys=allow_missing_keys) if decollate_batch else None ) self.output_transform = ( - Decollated( - keys=output_keys, - detach=detach, - rep_scalar=rep_scalar, - allow_missing_keys=allow_missing_keys, - ) + Decollated(keys=output_keys, detach=detach, allow_missing_keys=allow_missing_keys) if decollate_output else None ) diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index 37f2f8382b..d9c6790840 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -107,13 +107,13 @@ class Decollated(MapTransform): Decollate a batch of data, if input a dictionary, it can also support to only decollate specified keys. Note that unlike most MapTransforms, it will delete other keys not specified and if keys=None, will decollate all the data in the input. + And it replicates the scalar values to every item of the decollated list. Args: keys: keys of the corresponding items to decollate, note that it will delete other keys not specified. if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`. detach: whether to detach the tensors. Scalars tensors will be detached into number types instead of torch tensors. - rep_scalar: whether to replicate the scalar values to every decollated item of the list. allow_missing_keys: don't raise exception if key is missing. """ @@ -122,12 +122,10 @@ def __init__( self, keys: Optional[KeysCollection] = None, detach: bool = True, - rep_scalar: bool = True, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) self.detach = detach - self.rep_scalar = rep_scalar def __call__(self, data: Union[Dict, List]): d: Union[Dict, List] @@ -141,7 +139,4 @@ def __call__(self, data: Union[Dict, List]): for key in self.key_iterator(data): d[key] = data[key] - if self.rep_scalar: - d = rep_scalar_to_batch(d) - - return decollate_batch(d, detach=self.detach) + return decollate_batch(rep_scalar_to_batch(d), detach=self.detach) diff --git a/tests/test_decollate.py b/tests/test_decollate.py index ca2aa93430..521d263663 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -217,14 +217,14 @@ def test_decollated(self): "image_meta_dict": {"scl_slope": torch.Tensor((0.0, 0.0))}, "loss": 0.85, } - transform = Decollated(keys=["meta", "image_meta_dict"], detach=False, rep_scalar=False) + transform = Decollated(keys=["meta", "image_meta_dict"], detach=False) out = transform(test_case) self.assertFalse("loss" in out) self.assertEqual(out[0]["meta"]["out"], "test") self.assertEqual(out[0]["image_meta_dict"]["scl_slope"], 0.0) self.assertTrue(isinstance(out[0]["image_meta_dict"]["scl_slope"], torch.Tensor)) - # decollate all data with rep_scalar=True - transform = Decollated(keys=None, detach=True, rep_scalar=True) + # decollate all data with keys=None + transform = Decollated(keys=None, detach=True) out = transform(test_case) self.assertEqual(out[1]["loss"], 0.85) self.assertEqual(out[0]["meta"]["out"], "test") @@ -238,7 +238,7 @@ def test_decollated(self): {"scl_slope": torch.Tensor((0.0, 0.0))}, 0.85, ] - transform = Decollated(keys=None, detach=False, rep_scalar=True) + transform = Decollated(keys=None, detach=False) out = transform(test_case) # the 4th item in the list is scalar loss value self.assertEqual(out[1][3], 0.85)