Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion monai/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,11 @@
from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
from .trainer import GanTrainer, SupervisedTrainer, Trainer
from .utils import GanKeys, IterationEvents, default_make_latent, default_prepare_batch, get_devices_spec
from .utils import (
GanKeys,
IterationEvents,
default_make_latent,
default_prepare_batch,
engine_apply_transform,
get_devices_spec,
)
28 changes: 27 additions & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch

from monai.transforms import apply_transform
from monai.utils import exact_version, optional_import
from monai.utils.enums import CommonKeys

Expand All @@ -27,6 +28,7 @@
"get_devices_spec",
"default_prepare_batch",
"default_make_latent",
"engine_apply_transform",
]


Expand Down Expand Up @@ -124,3 +126,27 @@ def default_make_latent(
non_blocking: bool = False,
) -> torch.Tensor:
return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking)


def engine_apply_transform(batch: Any, output: Any, transform: Callable):
"""
Apply transform for the engine.state.batch and engine.state.output.
If `batch` and `output` are dictionaries, temporarily combine them for the transform,
otherwise, apply the transform for `output` data only.

"""
if isinstance(batch, dict) and isinstance(output, dict):
data = dict(batch)
data.update(output)
data = apply_transform(transform, data)
for k, v in data.items():
# split the output data of post transforms into `output` and `batch`,
# `batch` should be read-only, so save the generated key-value into `output`
if k in output or k not in batch:
output[k] = v
else:
batch[k] = v
else:
output = apply_transform(transform, output)

return batch, output
9 changes: 7 additions & 2 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from torch.utils.data.distributed import DistributedSampler

from monai.engines.utils import IterationEvents, default_prepare_batch
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, exact_version, optional_import

from .utils import engine_apply_transform

IgniteEngine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
State, _ = optional_import("ignite.engine", "0.4.4", exact_version, "State")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
Expand Down Expand Up @@ -162,7 +163,11 @@ def _register_post_transforms(self, posttrans: Callable):

@self.on(IterationEvents.MODEL_COMPLETED)
def run_post_transform(engine: Engine) -> None:
Comment thread
Nic-Ma marked this conversation as resolved.
engine.state.output = apply_transform(posttrans, engine.state.output)
engine.state.batch, engine.state.output = engine_apply_transform(
batch=engine.state.batch,
output=engine.state.output,
transform=posttrans,
)

def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):
"""
Expand Down
32 changes: 24 additions & 8 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
LoadImaged,
RandCropByPosNegLabeld,
RandRotate90d,
SaveImaged,
ScaleIntensityd,
ToTensord,
)
Expand Down Expand Up @@ -237,13 +238,22 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor
Activationsd(keys="pred", sigmoid=True),
AsDiscreted(keys="pred", threshold_values=True),
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
# test the case that `pred` in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`
SaveImaged(
keys="pred",
meta_keys="image_meta_dict",
output_dir=root_dir,
output_postfix="seg_transform",
save_batch=True,
),
]
)
val_handlers = [
StatsHandler(output_transform=lambda x: None),
CheckpointLoader(load_path=f"{model_file}", load_dict={"net": net}),
SegmentationSaver(
output_dir=root_dir,
output_postfix="seg_handler",
batch_transform=lambda batch: batch["image_meta_dict"],
output_transform=lambda output: output["pred"],
),
Expand Down Expand Up @@ -308,14 +318,20 @@ def train_and_infer(self, idx=0):
self.assertTrue(test_integration_value(TASK, key="infer_metric", data=infer_metric, rtol=1e-2))
results.append(best_metric)
results.append(infer_metric)
output_files = sorted(glob(os.path.join(self.data_dir, "img*", "*.nii.gz")))
for output in output_files:
ave = np.mean(nib.load(output).get_fdata())
results.append(ave)
if idx == 2:
self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=results[2:], rtol=1e-2))
else:
self.assertTrue(test_integration_value(TASK, key="output_sums", data=results[2:], rtol=1e-2))

def _test_saved_files(postfix):
output_files = sorted(glob(os.path.join(self.data_dir, "img*", f"*{postfix}.nii.gz")))
values = []
for output in output_files:
ave = np.mean(nib.load(output).get_fdata())
values.append(ave)
if idx == 2:
self.assertTrue(test_integration_value(TASK, key="output_sums_2", data=values, rtol=1e-2))
else:
self.assertTrue(test_integration_value(TASK, key="output_sums", data=values, rtol=1e-2))

_test_saved_files(postfix="seg_handler")
_test_saved_files(postfix="seg_transform")
try:
os.remove(model_file)
except Exception as e:
Expand Down