diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e3e70d8b32..c3e8c456b7 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -14,7 +14,7 @@ import torch from torch.utils.data import DataLoader -from monai.config import IgniteInfo +from monai.config import IgniteInfo, KeysCollection from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.engines.workflow import Workflow from monai.inferers import Inferer, SimpleInferer @@ -281,6 +281,7 @@ class EnsembleEvaluator(Evaluator): networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`. pred_keys: the keys to store every prediction data. the length must exactly match the number of networks. + if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. prepare_batch: function to parse expected data (usually `image`, `label` and other network args) @@ -321,7 +322,7 @@ def __init__( device: torch.device, val_data_loader: Union[Iterable, DataLoader], networks: Sequence[torch.nn.Module], - pred_keys: Sequence[str], + pred_keys: Optional[KeysCollection] = None, epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, @@ -358,7 +359,11 @@ def __init__( ) self.networks = ensure_tuple(networks) - self.pred_keys = ensure_tuple(pred_keys) + self.pred_keys = ( + [f"{Keys.PRED}_{i}" for i in range(len(self.networks))] if pred_keys is None else ensure_tuple(pred_keys) + ) + if len(self.pred_keys) != len(self.networks): + raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index c7554e9421..dab46f366f 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -13,12 +13,18 @@ import torch from ignite.engine import EventEnum, Events +from parameterized import parameterized from monai.engines import EnsembleEvaluator +TEST_CASE_1 = [["pred_0", "pred_1", "pred_2", "pred_3", "pred_4"]] + +TEST_CASE_2 = [None] + class TestEnsembleEvaluator(unittest.TestCase): - def test_content(self): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_content(self, pred_keys): device = torch.device("cpu:0") class TestDataset(torch.utils.data.Dataset): @@ -52,7 +58,7 @@ class CustomEvents(EventEnum): device=device, val_data_loader=val_loader, networks=[net0, net1, net2, net3, net4], - pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"], + pred_keys=pred_keys, event_names=["bwd_event", "opt_event", CustomEvents], event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"}, ) @@ -61,7 +67,7 @@ class CustomEvents(EventEnum): def run_transform(engine): for i in range(5): expected_value = engine.state.iteration + i - torch.testing.assert_allclose(engine.state.output[0][f"pred{i}"].item(), expected_value) + torch.testing.assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value) @val_engine.on(Events.EPOCH_COMPLETED) def trigger_custom_event():