diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 544e695d2e..e358e603bd 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -30,3 +30,9 @@ Inferers .. autoclass:: SlidingWindowInferer :members: :special-members: __call__ + +`SaliencyInferer` +~~~~~~~~~~~~~~~~~ +.. autoclass:: SaliencyInferer + :members: + :special-members: __call__ diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index 1cdea77b0f..030344728d 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .inferer import Inferer, SimpleInferer, SlidingWindowInferer +from .inferer import Inferer, SaliencyInferer, SimpleInferer, SlidingWindowInferer from .utils import sliding_window_inference diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index b17afb4e1d..ecb2c2c178 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -10,14 +10,16 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import torch +import torch.nn as nn from monai.inferers.utils import sliding_window_inference from monai.utils import BlendMode, PytorchPadMode +from monai.visualize import CAM, GradCAM, GradCAMpp -__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer"] +__all__ = ["Inferer", "SimpleInferer", "SlidingWindowInferer", "SaliencyInferer"] class Inferer(ABC): @@ -190,3 +192,54 @@ def __call__( *args, **kwargs, ) + + +class SaliencyInferer(Inferer): + """ + SaliencyInferer is inference with activation maps. + + Args: + cam_name: expected CAM method name, should be: "CAM", "GradCAM" or "GradCAMpp". + target_layers: name of the model layer to generate the feature map. + class_idx: index of the class to be visualized. if None, default to argmax(logits). + args: other optional args to be passed to the `__init__` of cam. + kwargs: other optional keyword args to be passed to `__init__` of cam. + + """ + + def __init__(self, cam_name: str, target_layers: str, class_idx: Optional[int] = None, *args, **kwargs) -> None: + Inferer.__init__(self) + if cam_name.lower() not in ("cam", "gradcam", "gradcampp"): + raise ValueError("cam_name should be: 'CAM', 'GradCAM' or 'GradCAMpp'.") + self.cam_name = cam_name.lower() + self.target_layers = target_layers + self.class_idx = class_idx + self.args = args + self.kwargs = kwargs + + def __call__( # type: ignore + self, + inputs: torch.Tensor, + network: nn.Module, + *args: Any, + **kwargs: Any, + ): + """Unified callable function API of Inferers. + + Args: + inputs: model input data for inference. + network: target model to execute inference. + supports callables such as ``lambda x: my_torch_model(x, additional_config)`` + args: other optional args to be passed to the `__call__` of cam. + kwargs: other optional keyword args to be passed to `__call__` of cam. + + """ + cam: Union[CAM, GradCAM, GradCAMpp] + if self.cam_name == "cam": + cam = CAM(network, self.target_layers, *self.args, **self.kwargs) + elif self.cam_name == "gradcam": + cam = GradCAM(network, self.target_layers, *self.args, **self.kwargs) + else: + cam = GradCAMpp(network, self.target_layers, *self.args, **self.kwargs) + + return cam(inputs, self.class_idx, *args, **kwargs) diff --git a/tests/test_sailency_inferer.py b/tests/test_sailency_inferer.py new file mode 100644 index 0000000000..276bc7816c --- /dev/null +++ b/tests/test_sailency_inferer.py @@ -0,0 +1,52 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.inferers import SaliencyInferer +from monai.networks.nets import DenseNet +from monai.visualize.visualizer import default_upsampler + +TEST_CASE_1 = ["CAM"] + +TEST_CASE_2 = ["GradCAM"] + +TEST_CASE_3 = ["GradCAMpp"] + + +class TestGradientClassActivationMap(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, cam_name): + model = DenseNet( + spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) + ) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + + image = torch.rand((2, 1, 6, 6, 6), device=device) + target_layer = "class_layers.relu" + fc_layer = "class_layers.out" + if cam_name == "CAM": + inferer = SaliencyInferer(cam_name, target_layer, None, fc_layer, upsampler=default_upsampler) + result = inferer(inputs=image, network=model, layer_idx=-1) + else: + inferer = SaliencyInferer(cam_name, target_layer, None, upsampler=default_upsampler) + result = inferer(image, model, -1, retain_graph=False) + + self.assertTupleEqual(result.shape, (2, 1, 6, 6, 6)) + + +if __name__ == "__main__": + unittest.main()