From 10db7811c3e877219dae927c1b798a83884a91dc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 20 May 2021 16:54:05 +0800 Subject: [PATCH 1/3] [DLMED] temp combine batch and output Signed-off-by: Nic Ma --- monai/engines/workflow.py | 15 +++++++++++++- tests/test_integration_workflows.py | 32 +++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 4018dabc40..9486972eac 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -162,7 +162,20 @@ def _register_post_transforms(self, posttrans: Callable): @self.on(IterationEvents.MODEL_COMPLETED) def run_post_transform(engine: Engine) -> None: - engine.state.output = apply_transform(posttrans, engine.state.output) + if isinstance(engine.state.batch, dict) and isinstance(engine.state.output, dict): + # if `batch` and `output` are dictionaries, temporarily combine them for post transforms + data = dict(engine.state.batch) + data.update(engine.state.output) + data = apply_transform(posttrans, data) + for k, v in data.items(): + # split the output data of post transforms into `output` and `batch`, + # `batch` should be read-only, so save the generated key-value into `output` + if k in engine.state.output or k not in engine.state.batch: + engine.state.output[k] = v + else: + engine.state.batch[k] = v + else: + engine.state.output = apply_transform(posttrans, engine.state.output) def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 00d097b2b6..2184c29b99 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -48,6 +48,7 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImaged, ScaleIntensityd, ToTensord, ) @@ -237,6 +238,14 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold_values=True), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` + SaveImaged( + keys="pred", + meta_keys="image_meta_dict", + output_dir=root_dir, + output_postfix="seg_transform", + save_batch=True, + ), ] ) val_handlers = [ @@ -244,6 +253,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}), SegmentationSaver( output_dir=root_dir, + output_postfix="seg_handler", batch_transform=lambda batch: batch["image_meta_dict"], output_transform=lambda output: output["pred"], ), @@ -308,14 +318,20 @@ def train_and_infer(self, idx=0): self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2)) results.append(best_metric) results.append(infer_metric) - output_files = sorted(glob(os.path.join(self.data_dir, "img*", "*.nii.gz"))) - for output in output_files: - ave = np.mean(nib.load(output).get_fdata()) - results.append(ave) - if idx == 2: - self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=results[2:], rtol=1e-2)) - else: - self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[2:], rtol=1e-2)) + + def _test_saved_files(postfix): + output_files = sorted(glob(os.path.join(self.data_dir, "img*", f"*{postfix}.nii.gz"))) + values = [] + for output in output_files: + ave = np.mean(nib.load(output).get_fdata()) + values.append(ave) + if idx == 2: + self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=values, rtol=1e-2)) + else: + self.assertTrue(test_integration_value(TASK, key="output_sums", data=values, rtol=1e-2)) + + _test_saved_files(postfix="seg_handler") + _test_saved_files(postfix="seg_transform") try: os.remove(model_file) except Exception as e: From 328cd13005355fe2e4f06b615b930afac8321a5f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 21 May 2021 11:07:00 +0800 Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/engines/__init__.py | 9 ++++++++- monai/engines/utils.py | 28 +++++++++++++++++++++++++++- monai/engines/workflow.py | 21 ++++++--------------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d3a14f6104..57605c39ca 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,4 +12,11 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer -from .utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec +from .utils import ( + GanKeys, + IterationEvents, + default_make_latent, + default_prepare_batch, + get_devices_spec, + engine_apply_transform, +) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 265a63ee0c..071f5e70da 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union import torch +from monai.transforms import apply_transform from monai.utils import exact_version, optional_import from monai.utils.enums import CommonKeys @@ -27,6 +28,7 @@ "get_devices_spec", "default_prepare_batch", "default_make_latent", + "engine_apply_transform", ] @@ -124,3 +126,27 @@ def default_make_latent( non_blocking: bool = False, ) -> torch.Tensor: return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) + + +def engine_apply_transform(batch: Any, output: Any, transform: Callable): + """ + Apply transform for the engine.state.batch and engine.state.output. + If `batch` and `output` are dictionaries, temporarily combine them for the transform, + otherwise, apply the transform for `output` data only. + + """ + if isinstance(batch, dict) and isinstance(output, dict): + data = dict(batch) + data.update(output) + data = apply_transform(transform, data) + for k, v in data.items(): + # split the output data of post transforms into `output` and `batch`, + # `batch` should be read-only, so save the generated key-value into `output` + if k in output or k not in batch: + output[k] = v + else: + batch[k] = v + else: + output = apply_transform(transform, output) + + return batch, output diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 9486972eac..2ee35a46c2 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -17,8 +17,8 @@ from torch.utils.data.distributed import DistributedSampler from monai.engines.utils import IterationEvents, default_prepare_batch -from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import +from .utils import engine_apply_transform IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") @@ -162,20 +162,11 @@ def _register_post_transforms(self, posttrans: Callable): @self.on(IterationEvents.MODEL_COMPLETED) def run_post_transform(engine: Engine) -> None: - if isinstance(engine.state.batch, dict) and isinstance(engine.state.output, dict): - # if `batch` and `output` are dictionaries, temporarily combine them for post transforms - data = dict(engine.state.batch) - data.update(engine.state.output) - data = apply_transform(posttrans, data) - for k, v in data.items(): - # split the output data of post transforms into `output` and `batch`, - # `batch` should be read-only, so save the generated key-value into `output` - if k in engine.state.output or k not in engine.state.batch: - engine.state.output[k] = v - else: - engine.state.batch[k] = v - else: - engine.state.output = apply_transform(posttrans, engine.state.output) + engine.state.batch, engine.state.output = engine_apply_transform( + batch=engine.state.batch, + output=engine.state.output, + transform=posttrans, + ) def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ From 04e23fc2974c0421ec2618eea3b5bdd6a21cd569 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Fri, 21 May 2021 03:23:02 +0000 Subject: [PATCH 3/3] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/engines/__init__.py | 2 +- monai/engines/utils.py | 2 +- monai/engines/workflow.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 57605c39ca..36719ae61c 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -17,6 +17,6 @@ IterationEvents, default_make_latent, default_prepare_batch, - get_devices_spec, engine_apply_transform, + get_devices_spec, ) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 071f5e70da..988f9e79d4 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 2ee35a46c2..abc4ca269e 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -18,6 +18,7 @@ from monai.engines.utils import IterationEvents, default_prepare_batch from monai.utils import ensure_tuple, exact_version, optional_import + from .utils import engine_apply_transform IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")