diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 0cfefb715a..778ec13900 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -40,6 +40,20 @@ class MetricLogger: useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and `load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training. + Example:: + # construct an evaluator saving mean dice metric values in the key "val_mean_dice" + evaluator = SupervisedEvaluator(..., key_val_metric={"val_mean_dice": MeanDice(...)}) + + # construct the logger and associate with evaluator to extract metric values from + logger = MetricLogger(evaluator=evaluator) + + # construct the trainer with the logger passed in as a handler so that it logs loss values + trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(evaluator, 1)]) + + # run training, logger.loss will be a list of (iteration, loss) values, logger.metrics a dict with key + # "val_mean_dice" storing a list of (iteration, metric) values + trainer.run() + Args: loss_transform: Converts the `output` value from the trainer's state into a loss value metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index 726a11731c..10dfe59f59 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -21,22 +21,6 @@ import numpy as np import torch -# from monai.utils import exact_version, optional_import - -# if TYPE_CHECKING: -# import matplotlib.pyplot as plt -# from ignite.engine import Engine, Events - -# Figure = plt.Figure -# Axes = plt.Axes -# has_matplotlib = True -# else: -# Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") -# Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") -# plt, has_matplotlib = optional_import("matplotlib.pyplot") -# Figure, _ = optional_import("matplotlib.pyplot", name="Figure") -# Axes, _ = optional_import("matplotlib.pyplot", name="Axes") - try: import matplotlib.pyplot as plt diff --git a/tests/test_handler_metric_logger.py b/tests/test_handler_metric_logger.py new file mode 100644 index 0000000000..5812605cd7 --- /dev/null +++ b/tests/test_handler_metric_logger.py @@ -0,0 +1,60 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.utils import optional_import +from tests.utils import SkipIfNoModule + +try: + _, has_ignite = optional_import("ignite") + from ignite.engine import Engine, Events + + from monai.handlers import MetricLogger +except ImportError: + has_ignite = False + + +class TestHandlerMetricLogger(unittest.TestCase): + @SkipIfNoModule("ignite") + def test_metric_logging(self): + dummy_name = "dummy" + + # set up engine + def _train_func(engine, batch): + return torch.tensor(0.0) + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + engine.state.metrics[dummy_name] = 1 + + # set up testing handler + handler = MetricLogger(loss_transform=lambda output: output.item()) + handler.attach(engine) + + engine.run(range(3), max_epochs=2) + + expected_loss = [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0)] + expected_metric = [(4, 1), (5, 1), (6, 1)] + + self.assertSetEqual({dummy_name}, set(handler.metrics)) + + self.assertListEqual(expected_loss, handler.loss) + self.assertListEqual(expected_metric, handler.metrics[dummy_name]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 13608e166c..75612586e8 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -9,27 +9,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import time import unittest import torch -from monai.utils import optional_import +from monai.data import DataLoader +from monai.utils import optional_import, set_determinism from monai.utils.enums import CommonKeys +from tests.utils import SkipIfNoModule try: _, has_ignite = optional_import("ignite") from monai.engines import SupervisedTrainer + from monai.handlers import MetricLogger from monai.utils import ThreadContainer except ImportError: has_ignite = False -from monai.data import DataLoader +compare_images, _ = optional_import("matplotlib.testing.compare", name="compare_images") class TestThreadContainer(unittest.TestCase): - @unittest.skipIf(not has_ignite, "Ignite needed for this test") + @SkipIfNoModule("ignite") def test_container(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) @@ -57,3 +62,43 @@ def test_container(self): self.assertTrue(len(con.status_dict) > 0) con.join() + + @SkipIfNoModule("ignite") + @SkipIfNoModule("matplotlib") + def test_plot(self): + set_determinism(0) + + testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + + net = torch.nn.Conv2d(1, 1, 3, padding=1) + + opt = torch.optim.Adam(net.parameters()) + + img = torch.rand(1, 16, 16) + data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img} + loader = DataLoader([data for _ in range(10)]) + + trainer = SupervisedTrainer( + device=torch.device("cpu"), + max_epochs=1, + train_data_loader=loader, + network=net, + optimizer=opt, + loss_function=torch.nn.L1Loss(), + ) + + logger = MetricLogger() + logger.attach(trainer) + + con = ThreadContainer(trainer) + con.start() + con.join() + + fig = con.plot_status(logger) + + with tempfile.TemporaryDirectory() as tempdir: + tempimg = f"{tempdir}/threadcontainer_plot_test.png" + fig.savefig(tempimg) + comp = compare_images(tempimg, f"{testing_dir}/threadcontainer_plot_test.png", 1e-3) + + self.assertIsNone(comp, comp) # None indicates test passed diff --git a/tests/testing_data/threadcontainer_plot_test.png b/tests/testing_data/threadcontainer_plot_test.png new file mode 100644 index 0000000000..b3576491ec Binary files /dev/null and b/tests/testing_data/threadcontainer_plot_test.png differ