diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e1fecb745d..71b5069705 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -188,8 +188,7 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) @@ -306,8 +305,7 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index e9e31a1b16..d017bce70d 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -120,8 +120,7 @@ def __init__( additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index abc4ca269e..9e1516937f 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -137,16 +137,14 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - if event_names is not None: - if not isinstance(event_names, list): + event_names = [IterationEvents] if event_names is None else event_names + [IterationEvents] + for name in event_names: + if isinstance(name, str): + self.register_events(name, event_to_attr=event_to_attr) + elif issubclass(name, EventEnum): + self.register_events(*name, event_to_attr=event_to_attr) + else: raise ValueError("event_names must be a list or string or EventEnum.") - for name in event_names: - if isinstance(name, str): - self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): - self.register_events(*name, event_to_attr=event_to_attr) - else: - raise ValueError("event_names must be a list or string or EventEnum.") if post_transform is not None: self._register_post_transforms(post_transform)