From 486b150efbb4b6270456e107a924da44585305e2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 3 Jun 2021 12:10:37 +0800 Subject: [PATCH] [DLMED] move IterationEvents into workflow Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 6 ++---- monai/engines/trainer.py | 3 +-- monai/engines/workflow.py | 16 +++++++--------- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index e1fecb745d..71b5069705 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -188,8 +188,7 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) @@ -306,8 +305,7 @@ def __init__( val_handlers=val_handlers, amp=amp, mode=mode, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index e9e31a1b16..d017bce70d 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -120,8 +120,7 @@ def __init__( additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, - # add the iteration events - event_names=[IterationEvents] if event_names is None else event_names + [IterationEvents], + event_names=event_names, event_to_attr=event_to_attr, ) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index abc4ca269e..9e1516937f 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -137,16 +137,14 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - if event_names is not None: - if not isinstance(event_names, list): + event_names = [IterationEvents] if event_names is None else event_names + [IterationEvents] + for name in event_names: + if isinstance(name, str): + self.register_events(name, event_to_attr=event_to_attr) + elif issubclass(name, EventEnum): + self.register_events(*name, event_to_attr=event_to_attr) + else: raise ValueError("event_names must be a list or string or EventEnum.") - for name in event_names: - if isinstance(name, str): - self.register_events(name, event_to_attr=event_to_attr) - elif issubclass(name, EventEnum): - self.register_events(*name, event_to_attr=event_to_attr) - else: - raise ValueError("event_names must be a list or string or EventEnum.") if post_transform is not None: self._register_post_transforms(post_transform)