diff --git a/docs/source/apps.rst b/docs/source/apps.rst index d81607c6b4..fa92a2bc2d 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -77,3 +77,7 @@ Applications .. automodule:: monai.apps.pathology.utils .. autoclass:: PathologyProbNMS :members: + +.. automodule:: monai.apps.pathology.handlers +.. autoclass:: ProbMapProducer + :members: \ No newline at end of file diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py index 591edf1dad..3474a7c10a 100644 --- a/monai/apps/pathology/__init__.py +++ b/monai/apps/pathology/__init__.py @@ -10,4 +10,5 @@ # limitations under the License. from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset +from .handlers import ProbMapProducer from .utils import ProbNMS diff --git a/monai/apps/pathology/handlers.py b/monai/apps/pathology/handlers.py new file mode 100644 index 0000000000..046e403e0f --- /dev/null +++ b/monai/apps/pathology/handlers.py @@ -0,0 +1,103 @@ +import logging +import os +from typing import TYPE_CHECKING, Dict, Optional + +import numpy as np + +from monai.config import DtypeLike +from monai.utils import exact_version, optional_import + +Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + + +class ProbMapProducer: + """ + Event handler triggered on completing every iteration to save the probability map + """ + + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "", + dtype: DtypeLike = np.float64, + name: Optional[str] = None, + ) -> None: + """ + Args: + output_dir: output directory to save probability maps. + output_postfix: a string appended to all output file names. + dtype: the data type in which the probability map is stored. Default np.float64. + name: identifier of logging.logger to use, defaulting to `engine.logger`. + + """ + self.logger = logging.getLogger(name) + self._name = name + self.output_dir = output_dir + self.output_postfix = output_postfix + self.dtype = dtype + self.prob_map: Dict[str, np.ndarray] = {} + self.level: Dict[str, int] = {} + self.counter: Dict[str, int] = {} + self.num_done_images: int = 0 + self.num_images: int = 0 + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + + self.num_images = len(engine.data_loader.dataset.data) + + for sample in engine.data_loader.dataset.data: + name = sample["name"] + self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype) + self.counter[name] = len(sample["mask_locations"]) + self.level[name] = sample["level"] + + if self._name is None: + self.logger = engine.logger + if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): + engine.add_event_handler(Events.ITERATION_COMPLETED, self) + if not engine.has_event_handler(self.finalize, Events.COMPLETED): + engine.add_event_handler(Events.COMPLETED, self.finalize) + + def __call__(self, engine: Engine) -> None: + """ + This method assumes self.batch_transform will extract metadata from the input batch. + + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + names = engine.state.batch["name"] + locs = engine.state.batch["mask_location"] + pred = engine.state.output["pred"] + for i, name in enumerate(names): + self.prob_map[name][locs[0][i], locs[1][i]] = pred[i] + self.counter[name] -= 1 + if self.counter[name] == 0: + self.save_prob_map(name) + + def save_prob_map(self, name: str) -> None: + """ + This method save the probability map for an image, when its inference is finished, + and delete that probability map from memory. + + Args: + name: the name of image to be saved. + """ + file_path = os.path.join(self.output_dir, name) + np.save(file_path + self.output_postfix + ".npy", self.prob_map[name]) + + self.num_done_images += 1 + self.logger.info(f"Inference of '{name}' is done [{self.num_done_images}/{self.num_images}]!") + del self.prob_map[name] + del self.counter[name] + del self.level[name] + + def finalize(self, engine: Engine): + self.logger.info(f"Probability map is created for {self.num_done_images}/{self.num_images} images!") diff --git a/tests/min_tests.py b/tests/min_tests.py index 83c1ceea9f..e896e81c70 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -43,6 +43,7 @@ def run_testsuit(): "test_handler_confusion_matrix_dist", "test_handler_hausdorff_distance", "test_handler_mean_dice", + "test_handler_prob_map_generator", "test_handler_rocauc", "test_handler_rocauc_dist", "test_handler_segmentation_saver", diff --git a/tests/test_handler_prob_map_generator.py b/tests/test_handler_prob_map_generator.py new file mode 100644 index 0000000000..4882060be9 --- /dev/null +++ b/tests/test_handler_prob_map_generator.py @@ -0,0 +1,95 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +import torch +from ignite.engine import Engine +from parameterized import parameterized +from torch.utils.data import DataLoader + +from monai.apps.pathology.handlers import ProbMapProducer +from monai.data.dataset import Dataset +from monai.engines import Evaluator +from monai.handlers import ValidationHandler + +TEST_CASE_0 = ["image_inference_output_1", 2] +TEST_CASE_1 = ["image_inference_output_2", 9] +TEST_CASE_2 = ["image_inference_output_3", 1000] + + +class TestDataset(Dataset): + def __init__(self, name, size): + self.data = [ + { + "name": name, + "mask_shape": (size, size), + "mask_locations": [[i, i] for i in range(size)], + "level": 0, + } + ] + self.len = size + + def __len__(self): + return self.len + + def __getitem__(self, index): + return { + "name": self.data[0]["name"], + "mask_location": self.data[0]["mask_locations"][index], + "pred": index + 1, + } + + +class TestEvaluator(Evaluator): + def _iteration(self, engine, batchdata): + return batchdata + + +class TestHandlerProbMapGenerator(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + ] + ) + def test_prob_map_generator(self, name, size): + # set up dataset + dataset = TestDataset(name, size) + data_loader = DataLoader(dataset, batch_size=1) + + # set up engine + def inference(enging, batch): + pass + + engine = Engine(inference) + + # add ProbMapGenerator() to evaluator + output_dir = os.path.join(os.path.dirname(__file__), "testing_data") + prob_map_gen = ProbMapProducer(output_dir=output_dir) + + evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen]) + + # set up validation handler + validation = ValidationHandler(evaluator, interval=1) + validation.attach(engine) + + engine.run(data_loader) + + prob_map = np.load(os.path.join(output_dir, name + ".npy")) + self.assertListEqual(np.diag(prob_map).astype(int).tolist(), list(range(1, size + 1))) + + +if __name__ == "__main__": + unittest.main()