diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py index 97a46cc0b7..b3ee887983 100644 --- a/monai/handlers/mlflow_handler.py +++ b/monai/handlers/mlflow_handler.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import torch @@ -59,6 +59,8 @@ class MLFlowHandler: global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to track synced epoch number with the trainer engine. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`. For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html. @@ -72,6 +74,7 @@ def __init__( iteration_logger: Optional[Callable[[Engine], Any]] = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, tag_name: str = DEFAULT_TAG, ) -> None: if tracking_uri is not None: @@ -81,6 +84,7 @@ def __init__( self.iteration_logger = iteration_logger self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.tag_name = tag_name def attach(self, engine: Engine) -> None: @@ -144,7 +148,8 @@ def iteration_completed(self, engine: Engine) -> None: def _default_epoch_log(self, engine: Engine) -> None: """ Execute epoch level log operation. - Default to track the values from Ignite `engine.state.metrics` dict. + Default to track the values from Ignite `engine.state.metrics` dict and + track the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -157,6 +162,10 @@ def _default_epoch_log(self, engine: Engine) -> None: current_epoch = self.global_epoch_transform(engine.state.epoch) mlflow.log_metrics(log_dict, step=current_epoch) + if self.state_attributes is not None: + attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes} + mlflow.log_metrics(attrs, step=current_epoch) + def _default_iteration_log(self, engine: Engine) -> None: """ Execute iteration log operation based on Ignite `engine.state.output` data. diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index b536ffaebb..f0fcf166e8 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -11,7 +11,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import torch @@ -47,6 +47,7 @@ def __init__( iteration_print_logger: Optional[Callable[[Engine], Any]] = None, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, name: Optional[str] = None, tag_name: str = DEFAULT_TAG, key_var_format: str = DEFAULT_KEY_VAL_FORMAT, @@ -68,6 +69,8 @@ def __init__( global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to print synced epoch number with the trainer engine. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. name: identifier of logging.logger to use, defaulting to ``engine.logger``. tag_name: when iteration output is a scalar, tag_name is used to print tag_name: scalar_value to logger. Defaults to ``'Loss'``. @@ -80,6 +83,7 @@ def __init__( self.iteration_print_logger = iteration_print_logger self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.logger = logging.getLogger(name) self._name = name @@ -150,22 +154,22 @@ def exception_raised(self, engine: Engine, e: Exception) -> None: def _default_epoch_print(self, engine: Engine) -> None: """ Execute epoch level log operation. - Default to print the values from Ignite `engine.state.metrics` dict. + Default to print the values from Ignite `engine.state.metrics` dict and + print the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - prints_dict = engine.state.metrics - if not prints_dict: - return current_epoch = self.global_epoch_transform(engine.state.epoch) - out_str = f"Epoch[{current_epoch}] Metrics -- " - for name in sorted(prints_dict): - value = prints_dict[name] - out_str += self.key_var_format.format(name, value) - self.logger.info(out_str) + prints_dict = engine.state.metrics + if prints_dict is not None and len(prints_dict) > 0: + out_str = f"Epoch[{current_epoch}] Metrics -- " + for name in sorted(prints_dict): + value = prints_dict[name] + out_str += self.key_var_format.format(name, value) + self.logger.info(out_str) if ( hasattr(engine.state, "key_metric_name") @@ -175,7 +179,13 @@ def _default_epoch_print(self, engine: Engine) -> None: out_str = f"Key metric: {engine.state.key_metric_name} " # type: ignore out_str += f"best value: {engine.state.best_metric} " # type: ignore out_str += f"at epoch: {engine.state.best_metric_epoch}" # type: ignore - self.logger.info(out_str) + self.logger.info(out_str) + + if self.state_attributes is not None and len(self.state_attributes) > 0: + out_str = "State values: " + for attr in self.state_attributes: + out_str += f"{attr}: {getattr(engine.state, attr, None)} " + self.logger.info(out_str) def _default_iteration_print(self, engine: Engine) -> None: """ diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 9d23662ba1..1dd89c7efb 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -10,7 +10,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence import numpy as np import torch @@ -85,6 +85,7 @@ def __init__( iteration_interval: int = 1, output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, + state_attributes: Optional[Sequence[str]] = None, tag_name: str = DEFAULT_TAG, ) -> None: """ @@ -107,6 +108,8 @@ def __init__( global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to use trainer engines epoch number when plotting epoch vs metric curves. + state_attributes: expected attributes from `engine.state`, if provided, will extract them + when epoch completed. tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) @@ -116,6 +119,7 @@ def __init__( self.iteration_interval = iteration_interval self.output_transform = output_transform self.global_epoch_transform = global_epoch_transform + self.state_attributes = state_attributes self.tag_name = tag_name def attach(self, engine: Engine) -> None: @@ -164,7 +168,8 @@ def iteration_completed(self, engine: Engine) -> None: def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ Execute epoch level event write operation. - Default to write the values from Ignite `engine.state.metrics` dict. + Default to write the values from Ignite `engine.state.metrics` dict and + write the values of specified attributes of `engine.state`. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -175,6 +180,10 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: summary_dict = engine.state.metrics for name, value in summary_dict.items(): writer.add_scalar(name, value, current_epoch) + + if self.state_attributes is not None: + for attr in self.state_attributes: + writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch) writer.flush() def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None: diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py index f210ebfacc..808ebffe33 100644 --- a/tests/test_handler_mlflow.py +++ b/tests/test_handler_mlflow.py @@ -35,10 +35,11 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric # set up testing handler test_path = os.path.join(tempdir, "mlflow_test") - handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri()) + handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri(), state_attributes=["test"]) handler.attach(engine) engine.run(range(3), max_epochs=2) handler.close() diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 9b7ad19dcc..e5d71bcc8a 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -163,6 +163,47 @@ def _train_func(engine, batch): with self.assertRaises(RuntimeError): engine.run(range(3), max_epochs=2) + def test_attributes_print(self): + log_stream = StringIO() + log_handler = logging.StreamHandler(log_stream) + log_handler.setLevel(logging.INFO) + key_to_handler = "test_logging" + + # set up engine + def _train_func(engine, batch): + return [torch.tensor(0.0)] + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + if not hasattr(engine.state, "test1"): + engine.state.test1 = 0.1 + engine.state.test2 = 0.2 + else: + engine.state.test1 += 0.1 + engine.state.test2 += 0.2 + + # set up testing handler + stats_handler = StatsHandler( + name=key_to_handler, + state_attributes=["test1", "test2", "test3"], + logger_handler=log_handler, + ) + stats_handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + # check logging output + output_str = log_stream.getvalue() + log_handler.close() + grep = re.compile(f".*{key_to_handler}.*") + has_key_word = re.compile(".*State values.*") + for idx, line in enumerate(output_str.split("\n")): + if grep.match(line) and idx in [5, 10]: + self.assertTrue(has_key_word.match(line)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 1d722e7f66..f0c4d49fd0 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -57,11 +57,15 @@ def _train_func(engine, batch): def _update_metric(engine): current_metric = engine.state.metrics.get("acc", 0.1) engine.state.metrics["acc"] = current_metric + 0.1 + engine.state.test = current_metric # set up testing handler writer = SummaryWriter(log_dir=tempdir) stats_handler = TensorBoardStatsHandler( - writer, output_transform=lambda x: {"loss": x[0] * 2.0}, global_epoch_transform=lambda x: x * 3.0 + summary_writer=writer, + output_transform=lambda x: {"loss": x[0] * 2.0}, + global_epoch_transform=lambda x: x * 3.0, + state_attributes=["test"], ) stats_handler.attach(engine) engine.run(range(3), max_epochs=2)