Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from torch.utils.data import DataLoader

from monai.config import IgniteInfo
from monai.config import IgniteInfo, KeysCollection
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
Expand Down Expand Up @@ -281,6 +281,7 @@ class EnsembleEvaluator(Evaluator):
networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`.
pred_keys: the keys to store every prediction data.
the length must exactly match the number of networks.
if None, use "pred_{index}" as key corresponding to N networks, index from `0` to `N-1`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
Expand Down Expand Up @@ -321,7 +322,7 @@ def __init__(
device: torch.device,
val_data_loader: Union[Iterable, DataLoader],
networks: Sequence[torch.nn.Module],
pred_keys: Sequence[str],
pred_keys: Optional[KeysCollection] = None,
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
Expand Down Expand Up @@ -358,7 +359,11 @@ def __init__(
)

self.networks = ensure_tuple(networks)
self.pred_keys = ensure_tuple(pred_keys)
self.pred_keys = (
[f"{Keys.PRED}_{i}" for i in range(len(self.networks))] if pred_keys is None else ensure_tuple(pred_keys)
)
if len(self.pred_keys) != len(self.networks):
raise ValueError("length of `pred_keys` must be same as the length of `networks`.")
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
Expand Down
12 changes: 9 additions & 3 deletions tests/test_ensemble_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@

import torch
from ignite.engine import EventEnum, Events
from parameterized import parameterized

from monai.engines import EnsembleEvaluator

TEST_CASE_1 = [["pred_0", "pred_1", "pred_2", "pred_3", "pred_4"]]

TEST_CASE_2 = [None]


class TestEnsembleEvaluator(unittest.TestCase):
def test_content(self):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_content(self, pred_keys):
device = torch.device("cpu:0")

class TestDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -52,7 +58,7 @@ class CustomEvents(EventEnum):
device=device,
val_data_loader=val_loader,
networks=[net0, net1, net2, net3, net4],
pred_keys=["pred0", "pred1", "pred2", "pred3", "pred4"],
pred_keys=pred_keys,
event_names=["bwd_event", "opt_event", CustomEvents],
event_to_attr={CustomEvents.FOO_EVENT: "foo", "opt_event": "opt"},
)
Expand All @@ -61,7 +67,7 @@ class CustomEvents(EventEnum):
def run_transform(engine):
for i in range(5):
expected_value = engine.state.iteration + i
torch.testing.assert_allclose(engine.state.output[0][f"pred{i}"].item(), expected_value)
torch.testing.assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value)

@val_engine.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
Expand Down