Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.

import logging
import warnings
from typing import TYPE_CHECKING, Callable, Optional, Union

import numpy as np
Expand Down Expand Up @@ -121,7 +120,6 @@ def __init__(
squeeze_end_dims=squeeze_end_dims,
data_root_dir=data_root_dir,
)
self.resample = resample
self.batch_transform = batch_transform
self.output_transform = output_transform

Expand Down Expand Up @@ -150,12 +148,15 @@ def __call__(self, engine: Engine) -> None:
engine_output = self.output_transform(engine.state.output)
if isinstance(engine_output, (tuple, list)):
# if a list of data in shape: [channel, H, W, [D]], save every item separately
if self.resample:
warnings.warn("if saving inverted data, please set `resample=False` as it's already resampled.")

self._saver.save_batch = False
for i, d in enumerate(engine_output):
self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
if isinstance(meta_data, dict):
meta_ = {k: meta_data[k][i] for k in meta_data}
elif isinstance(meta_data, (list, tuple)):
meta_ = meta_data[i]
else:
meta_ = meta_data
self._saver(d, meta_)
else:
# if the data is in shape: [batch, channel, H, W, [D]]
self._saver.save_batch = True
Expand Down
21 changes: 15 additions & 6 deletions monai/handlers/transform_inverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class TransformInverter:
"""
Ignite handler to automatically invert `transforms`.
It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`.
The inverted results are stored in `engine.state.output` with key: "{output_key}_{postfix}".
The inverted data are stored in `engine.state.output` with key: "{output_key}_{postfix}".
And the inverted meta dict will be stored in `engine.state.batch`
with key: "{output_key}_{postfix}_{meta_key_postfix}".

"""

def __init__(
Expand Down Expand Up @@ -136,8 +139,14 @@ def __call__(self, engine: Engine) -> None:
segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key]

with allow_missing_keys_mode(self.transform): # type: ignore
inverted_key = f"{output_key}_{self.postfix}"
engine.state.output[inverted_key] = [
post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key])
for i in self.inverter(segs_dict)
]
inverted = self.inverter(segs_dict)

# save the inverted data into state.output
inverted_key = f"{output_key}_{self.postfix}"
engine.state.output[inverted_key] = [
post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) for i in inverted
]

# save the inverted meta dict into state.batch
if meta_dict_key in engine.state.batch:
engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [i.get(meta_dict_key) for i in inverted]