diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index c3eca3b9c8..c04f9e1701 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -10,7 +10,6 @@ # limitations under the License. import logging -import warnings from typing import TYPE_CHECKING, Callable, Optional, Union import numpy as np @@ -121,7 +120,6 @@ def __init__( squeeze_end_dims=squeeze_end_dims, data_root_dir=data_root_dir, ) - self.resample = resample self.batch_transform = batch_transform self.output_transform = output_transform @@ -150,12 +148,15 @@ def __call__(self, engine: Engine) -> None: engine_output = self.output_transform(engine.state.output) if isinstance(engine_output, (tuple, list)): # if a list of data in shape: [channel, H, W, [D]], save every item separately - if self.resample: - warnings.warn("if saving inverted data, please set `resample=False` as it's already resampled.") - self._saver.save_batch = False for i, d in enumerate(engine_output): - self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) + if isinstance(meta_data, dict): + meta_ = {k: meta_data[k][i] for k in meta_data} + elif isinstance(meta_data, (list, tuple)): + meta_ = meta_data[i] + else: + meta_ = meta_data + self._saver(d, meta_) else: # if the data is in shape: [batch, channel, H, W, [D]] self._saver.save_batch = True diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 5f5c141189..651a47c39c 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -33,7 +33,10 @@ class TransformInverter: """ Ignite handler to automatically invert `transforms`. It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. - The inverted results are stored in `engine.state.output` with key: "{output_key}_{postfix}". + The inverted data are stored in `engine.state.output` with key: "{output_key}_{postfix}". + And the inverted meta dict will be stored in `engine.state.batch` + with key: "{output_key}_{postfix}_{meta_key_postfix}". + """ def __init__( @@ -136,8 +139,14 @@ def __call__(self, engine: Engine) -> None: segs_dict[meta_dict_key] = engine.state.batch[meta_dict_key] with allow_missing_keys_mode(self.transform): # type: ignore - inverted_key = f"{output_key}_{self.postfix}" - engine.state.output[inverted_key] = [ - post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) - for i in self.inverter(segs_dict) - ] + inverted = self.inverter(segs_dict) + + # save the inverted data into state.output + inverted_key = f"{output_key}_{self.postfix}" + engine.state.output[inverted_key] = [ + post_func(self._totensor(i[batch_key]).to(device) if to_tensor else i[batch_key]) for i in inverted + ] + + # save the inverted meta dict into state.batch + if meta_dict_key in engine.state.batch: + engine.state.batch[f"{inverted_key}_{self.meta_key_postfix}"] = [i.get(meta_dict_key) for i in inverted]