diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d3a14f6104..36719ae61c 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -12,4 +12,11 @@ from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer from .trainer import GanTrainer, SupervisedTrainer, Trainer -from .utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec +from .utils import ( + GanKeys, + IterationEvents, + default_make_latent, + default_prepare_batch, + engine_apply_transform, + get_devices_spec, +) diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 265a63ee0c..988f9e79d4 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -9,10 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch +from monai.transforms import apply_transform from monai.utils import exact_version, optional_import from monai.utils.enums import CommonKeys @@ -27,6 +28,7 @@ "get_devices_spec", "default_prepare_batch", "default_make_latent", + "engine_apply_transform", ] @@ -124,3 +126,27 @@ def default_make_latent( non_blocking: bool = False, ) -> torch.Tensor: return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) + + +def engine_apply_transform(batch: Any, output: Any, transform: Callable): + """ + Apply transform for the engine.state.batch and engine.state.output. + If `batch` and `output` are dictionaries, temporarily combine them for the transform, + otherwise, apply the transform for `output` data only. + + """ + if isinstance(batch, dict) and isinstance(output, dict): + data = dict(batch) + data.update(output) + data = apply_transform(transform, data) + for k, v in data.items(): + # split the output data of post transforms into `output` and `batch`, + # `batch` should be read-only, so save the generated key-value into `output` + if k in output or k not in batch: + output[k] = v + else: + batch[k] = v + else: + output = apply_transform(transform, output) + + return batch, output diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 4018dabc40..abc4ca269e 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -17,9 +17,10 @@ from torch.utils.data.distributed import DistributedSampler from monai.engines.utils import IterationEvents, default_prepare_batch -from monai.transforms import apply_transform from monai.utils import ensure_tuple, exact_version, optional_import +from .utils import engine_apply_transform + IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State") Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") @@ -162,7 +163,11 @@ def _register_post_transforms(self, posttrans: Callable): @self.on(IterationEvents.MODEL_COMPLETED) def run_post_transform(engine: Engine) -> None: - engine.state.output = apply_transform(posttrans, engine.state.output) + engine.state.batch, engine.state.output = engine_apply_transform( + batch=engine.state.batch, + output=engine.state.output, + transform=posttrans, + ) def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 00d097b2b6..2184c29b99 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -48,6 +48,7 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImaged, ScaleIntensityd, ToTensord, ) @@ -237,6 +238,14 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold_values=True), KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + # test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch` + SaveImaged( + keys="pred", + meta_keys="image_meta_dict", + output_dir=root_dir, + output_postfix="seg_transform", + save_batch=True, + ), ] ) val_handlers = [ @@ -244,6 +253,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}), SegmentationSaver( output_dir=root_dir, + output_postfix="seg_handler", batch_transform=lambda batch: batch["image_meta_dict"], output_transform=lambda output: output["pred"], ), @@ -308,14 +318,20 @@ def train_and_infer(self, idx=0): self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2)) results.append(best_metric) results.append(infer_metric) - output_files = sorted(glob(os.path.join(self.data_dir, "img*", "*.nii.gz"))) - for output in output_files: - ave = np.mean(nib.load(output).get_fdata()) - results.append(ave) - if idx == 2: - self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=results[2:], rtol=1e-2)) - else: - self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[2:], rtol=1e-2)) + + def _test_saved_files(postfix): + output_files = sorted(glob(os.path.join(self.data_dir, "img*", f"*{postfix}.nii.gz"))) + values = [] + for output in output_files: + ave = np.mean(nib.load(output).get_fdata()) + values.append(ave) + if idx == 2: + self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=values, rtol=1e-2)) + else: + self.assertTrue(test_integration_value(TASK, key="output_sums", data=values, rtol=1e-2)) + + _test_saved_files(postfix="seg_handler") + _test_saved_files(postfix="seg_transform") try: os.remove(model_file) except Exception as e: