Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions monai/handlers/mlflow_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence

import torch

Expand Down Expand Up @@ -59,6 +59,8 @@ class MLFlowHandler:
global_epoch_transform: a callable that is used to customize global epoch number.
For example, in evaluation, the evaluator engine might want to track synced epoch number
with the trainer engine.
state_attributes: names of the attributes expected on `engine.state`; if provided, their values are
extracted when an epoch completes.
tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`.

For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html.
Expand All @@ -72,6 +74,7 @@ def __init__(
iteration_logger: Optional[Callable[[Engine], Any]] = None,
output_transform: Callable = lambda x: x[0],
global_epoch_transform: Callable = lambda x: x,
state_attributes: Optional[Sequence[str]] = None,
tag_name: str = DEFAULT_TAG,
) -> None:
if tracking_uri is not None:
Expand All @@ -81,6 +84,7 @@ def __init__(
self.iteration_logger = iteration_logger
self.output_transform = output_transform
self.global_epoch_transform = global_epoch_transform
self.state_attributes = state_attributes
self.tag_name = tag_name

def attach(self, engine: Engine) -> None:
Expand Down Expand Up @@ -144,7 +148,8 @@ def iteration_completed(self, engine: Engine) -> None:
def _default_epoch_log(self, engine: Engine) -> None:
"""
Execute epoch level log operation.
Default to track the values from Ignite `engine.state.metrics` dict.
Default to track the values from Ignite `engine.state.metrics` dict and
track the values of specified attributes of `engine.state`.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
Expand All @@ -157,6 +162,10 @@ def _default_epoch_log(self, engine: Engine) -> None:
current_epoch = self.global_epoch_transform(engine.state.epoch)
mlflow.log_metrics(log_dict, step=current_epoch)

if self.state_attributes is not None:
attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
mlflow.log_metrics(attrs, step=current_epoch)

def _default_iteration_log(self, engine: Engine) -> None:
"""
Execute iteration log operation based on Ignite `engine.state.output` data.
Expand Down
32 changes: 21 additions & 11 deletions monai/handlers/stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence

import torch

Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
iteration_print_logger: Optional[Callable[[Engine], Any]] = None,
output_transform: Callable = lambda x: x[0],
global_epoch_transform: Callable = lambda x: x,
state_attributes: Optional[Sequence[str]] = None,
name: Optional[str] = None,
tag_name: str = DEFAULT_TAG,
key_var_format: str = DEFAULT_KEY_VAL_FORMAT,
Expand All @@ -68,6 +69,8 @@ def __init__(
global_epoch_transform: a callable that is used to customize global epoch number.
For example, in evaluation, the evaluator engine might want to print synced epoch number
with the trainer engine.
state_attributes: names of the attributes expected on `engine.state`; if provided, their values are
extracted when an epoch completes.
name: identifier of logging.logger to use, defaulting to ``engine.logger``.
tag_name: when iteration output is a scalar, tag_name is used to print
tag_name: scalar_value to logger. Defaults to ``'Loss'``.
Expand All @@ -80,6 +83,7 @@ def __init__(
self.iteration_print_logger = iteration_print_logger
self.output_transform = output_transform
self.global_epoch_transform = global_epoch_transform
self.state_attributes = state_attributes
self.logger = logging.getLogger(name)
self._name = name

Expand Down Expand Up @@ -150,22 +154,22 @@ def exception_raised(self, engine: Engine, e: Exception) -> None:
def _default_epoch_print(self, engine: Engine) -> None:
"""
Execute epoch level log operation.
Default to print the values from Ignite `engine.state.metrics` dict.
Default to print the values from Ignite `engine.state.metrics` dict and
print the values of specified attributes of `engine.state`.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.

"""
prints_dict = engine.state.metrics
if not prints_dict:
return
current_epoch = self.global_epoch_transform(engine.state.epoch)

out_str = f"Epoch[{current_epoch}] Metrics -- "
for name in sorted(prints_dict):
value = prints_dict[name]
out_str += self.key_var_format.format(name, value)
self.logger.info(out_str)
prints_dict = engine.state.metrics
if prints_dict is not None and len(prints_dict) > 0:
out_str = f"Epoch[{current_epoch}] Metrics -- "
for name in sorted(prints_dict):
value = prints_dict[name]
out_str += self.key_var_format.format(name, value)
self.logger.info(out_str)

if (
hasattr(engine.state, "key_metric_name")
Expand All @@ -175,7 +179,13 @@ def _default_epoch_print(self, engine: Engine) -> None:
out_str = f"Key metric: {engine.state.key_metric_name} " # type: ignore
out_str += f"best value: {engine.state.best_metric} " # type: ignore
out_str += f"at epoch: {engine.state.best_metric_epoch}" # type: ignore
self.logger.info(out_str)
self.logger.info(out_str)

if self.state_attributes is not None and len(self.state_attributes) > 0:
out_str = "State values: "
for attr in self.state_attributes:
out_str += f"{attr}: {getattr(engine.state, attr, None)} "
self.logger.info(out_str)

def _default_iteration_print(self, engine: Engine) -> None:
"""
Expand Down
13 changes: 11 additions & 2 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
iteration_interval: int = 1,
output_transform: Callable = lambda x: x[0],
global_epoch_transform: Callable = lambda x: x,
state_attributes: Optional[Sequence[str]] = None,
tag_name: str = DEFAULT_TAG,
) -> None:
"""
Expand All @@ -107,6 +108,8 @@ def __init__(
global_epoch_transform: a callable that is used to customize global epoch number.
For example, in evaluation, the evaluator engine might want to use trainer engines epoch number
when plotting epoch vs metric curves.
state_attributes: names of the attributes expected on `engine.state`; if provided, their values are
extracted when an epoch completes.
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
"""
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
Expand All @@ -116,6 +119,7 @@ def __init__(
self.iteration_interval = iteration_interval
self.output_transform = output_transform
self.global_epoch_transform = global_epoch_transform
self.state_attributes = state_attributes
self.tag_name = tag_name

def attach(self, engine: Engine) -> None:
Expand Down Expand Up @@ -164,7 +168,8 @@ def iteration_completed(self, engine: Engine) -> None:
def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
"""
Execute epoch level event write operation.
Default to write the values from Ignite `engine.state.metrics` dict.
Default to write the values from Ignite `engine.state.metrics` dict and
write the values of specified attributes of `engine.state`.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
Expand All @@ -175,6 +180,10 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:
summary_dict = engine.state.metrics
for name, value in summary_dict.items():
writer.add_scalar(name, value, current_epoch)

if self.state_attributes is not None:
for attr in self.state_attributes:
writer.add_scalar(attr, getattr(engine.state, attr, None), current_epoch)
writer.flush()

def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def _train_func(engine, batch):
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
engine.state.test = current_metric

# set up testing handler
test_path = os.path.join(tempdir, "mlflow_test")
handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri())
handler = MLFlowHandler(tracking_uri=Path(test_path).as_uri(), state_attributes=["test"])
handler.attach(engine)
engine.run(range(3), max_epochs=2)
handler.close()
Expand Down
41 changes: 41 additions & 0 deletions tests/test_handler_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,47 @@ def _train_func(engine, batch):
with self.assertRaises(RuntimeError):
engine.run(range(3), max_epochs=2)

def test_attributes_print(self):
    """
    Check that `StatsHandler` reports the requested `engine.state` attributes.

    A dummy engine runs 2 epochs of 3 iterations. An epoch-completed hook maintains
    `test1` and `test2` on the engine state, while `test3` is requested from the
    handler but never assigned by this test. Captured log lines 5 and 10 are expected
    to contain "State values" whenever they also contain the logger name.
    """
    logger_name = "test_logging"
    captured_log = StringIO()
    stream_handler = logging.StreamHandler(captured_log)
    stream_handler.setLevel(logging.INFO)

    # dummy training step: every iteration returns a constant zero loss
    def _train_func(engine, batch):
        return [torch.tensor(0.0)]

    trainer = Engine(_train_func)

    # create the state attributes on the first epoch, then increase them on later epochs
    @trainer.on(Events.EPOCH_COMPLETED)
    def _update_metric(engine):
        state = engine.state
        if hasattr(state, "test1"):
            state.test1 += 0.1
            state.test2 += 0.2
            return
        state.test1 = 0.1
        state.test2 = 0.2

    # handler under test, writing to the in-memory stream
    handler_kwargs = {
        "name": logger_name,
        "state_attributes": ["test1", "test2", "test3"],
        "logger_handler": stream_handler,
    }
    StatsHandler(**handler_kwargs).attach(trainer)

    trainer.run(range(3), max_epochs=2)

    # inspect what was logged
    logged_text = captured_log.getvalue()
    stream_handler.close()
    name_pattern = re.compile(f".*{logger_name}.*")
    state_pattern = re.compile(".*State values.*")
    for line_no, text in enumerate(logged_text.split("\n")):
        # only the 6th and 11th captured lines are checked
        if line_no not in (5, 10):
            continue
        if name_pattern.match(text):
            self.assertTrue(state_pattern.match(text))


# run this module's test cases when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()
6 changes: 5 additions & 1 deletion tests/test_handler_tb_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ def _train_func(engine, batch):
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
engine.state.test = current_metric

# set up testing handler
writer = SummaryWriter(log_dir=tempdir)
stats_handler = TensorBoardStatsHandler(
writer, output_transform=lambda x: {"loss": x[0] * 2.0}, global_epoch_transform=lambda x: x * 3.0
summary_writer=writer,
output_transform=lambda x: {"loss": x[0] * 2.0},
global_epoch_transform=lambda x: x * 3.0,
state_attributes=["test"],
)
stats_handler.attach(engine)
engine.run(range(3), max_epochs=2)
Expand Down