Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ Applications
.. automodule:: monai.apps.pathology.utils
.. autoclass:: PathologyProbNMS
:members:

.. automodule:: monai.apps.pathology.handlers
.. autoclass:: ProbMapProducer
:members:
1 change: 1 addition & 0 deletions monai/apps/pathology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
# limitations under the License.

from .datasets import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCacheDataset
from .handlers import ProbMapProducer
from .utils import ProbNMS
103 changes: 103 additions & 0 deletions monai/apps/pathology/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import logging
import os
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from monai.config import DtypeLike
from monai.utils import exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")


class ProbMapProducer:
"""
Event handler triggered on completing every iteration to save the probability map
"""

def __init__(
self,
output_dir: str = "./",
output_postfix: str = "",
dtype: DtypeLike = np.float64,
name: Optional[str] = None,
Comment thread
wyli marked this conversation as resolved.
) -> None:
"""
Args:
output_dir: output directory to save probability maps.
output_postfix: a string appended to all output file names.
dtype: the data type in which the probability map is stored. Default np.float64.
name: identifier of logging.logger to use, defaulting to `engine.logger`.

"""
self.logger = logging.getLogger(name)
self._name = name
self.output_dir = output_dir
self.output_postfix = output_postfix
self.dtype = dtype
self.prob_map: Dict[str, np.ndarray] = {}
self.level: Dict[str, int] = {}
self.counter: Dict[str, int] = {}
self.num_done_images: int = 0
self.num_images: int = 0

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""

self.num_images = len(engine.data_loader.dataset.data)

for sample in engine.data_loader.dataset.data:
name = sample["name"]
self.prob_map[name] = np.zeros(sample["mask_shape"], dtype=self.dtype)
self.counter[name] = len(sample["mask_locations"])
self.level[name] = sample["level"]

if self._name is None:
self.logger = engine.logger
if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
engine.add_event_handler(Events.ITERATION_COMPLETED, self)
if not engine.has_event_handler(self.finalize, Events.COMPLETED):
engine.add_event_handler(Events.COMPLETED, self.finalize)

def __call__(self, engine: Engine) -> None:
"""
This method assumes self.batch_transform will extract metadata from the input batch.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
names = engine.state.batch["name"]
locs = engine.state.batch["mask_location"]
pred = engine.state.output["pred"]
for i, name in enumerate(names):
self.prob_map[name][locs[0][i], locs[1][i]] = pred[i]
self.counter[name] -= 1
if self.counter[name] == 0:
self.save_prob_map(name)

def save_prob_map(self, name: str) -> None:
"""
This method save the probability map for an image, when its inference is finished,
and delete that probability map from memory.

Args:
name: the name of image to be saved.
"""
file_path = os.path.join(self.output_dir, name)
np.save(file_path + self.output_postfix + ".npy", self.prob_map[name])

self.num_done_images += 1
self.logger.info(f"Inference of '{name}' is done [{self.num_done_images}/{self.num_images}]!")
del self.prob_map[name]
del self.counter[name]
del self.level[name]

def finalize(self, engine: Engine):
self.logger.info(f"Probability map is created for {self.num_done_images}/{self.num_images} images!")
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def run_testsuit():
"test_handler_confusion_matrix_dist",
"test_handler_hausdorff_distance",
"test_handler_mean_dice",
"test_handler_prob_map_generator",
"test_handler_rocauc",
"test_handler_rocauc_dist",
"test_handler_segmentation_saver",
Expand Down
95 changes: 95 additions & 0 deletions tests/test_handler_prob_map_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
import torch
from ignite.engine import Engine
from parameterized import parameterized
from torch.utils.data import DataLoader

from monai.apps.pathology.handlers import ProbMapProducer
from monai.data.dataset import Dataset
from monai.engines import Evaluator
from monai.handlers import ValidationHandler

TEST_CASE_0 = ["image_inference_output_1", 2]
TEST_CASE_1 = ["image_inference_output_2", 9]
TEST_CASE_2 = ["image_inference_output_3", 1000]


class TestDataset(Dataset):
def __init__(self, name, size):
self.data = [
{
"name": name,
"mask_shape": (size, size),
"mask_locations": [[i, i] for i in range(size)],
"level": 0,
}
]
self.len = size

def __len__(self):
return self.len

def __getitem__(self, index):
return {
"name": self.data[0]["name"],
"mask_location": self.data[0]["mask_locations"][index],
"pred": index + 1,
}


class TestEvaluator(Evaluator):
def _iteration(self, engine, batchdata):
return batchdata


class TestHandlerProbMapGenerator(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_0,
TEST_CASE_1,
TEST_CASE_2,
]
)
def test_prob_map_generator(self, name, size):
# set up dataset
dataset = TestDataset(name, size)
data_loader = DataLoader(dataset, batch_size=1)

# set up engine
def inference(enging, batch):
pass

engine = Engine(inference)

# add ProbMapGenerator() to evaluator
output_dir = os.path.join(os.path.dirname(__file__), "testing_data")
prob_map_gen = ProbMapProducer(output_dir=output_dir)

evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen])

# set up validation handler
validation = ValidationHandler(evaluator, interval=1)
validation.attach(engine)

engine.run(data_loader)

prob_map = np.load(os.path.join(output_dir, name + ".npy"))
self.assertListEqual(np.diag(prob_map).astype(int).tolist(), list(range(1, size + 1)))


if __name__ == "__main__":
unittest.main()