From 07702fd6e9a69ab63fb23c4361fd9b37fb904c86 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 29 Sep 2022 11:26:25 +0100 Subject: [PATCH] Visualisation classes to allow kwargs. combine GradCAM and ++ test files Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/class_activation_maps.py | 31 ++--- monai/visualize/gradient_based.py | 20 ++-- monai/visualize/occlusion_sensitivity.py | 17 +-- tests/test_occlusion_sensitivity.py | 20 +++- tests/test_vis_gradbased.py | 28 ++++- tests/test_vis_gradcam.py | 143 ++++++++++++++++------- tests/test_vis_gradcampp.py | 78 ------------- 7 files changed, 177 insertions(+), 160 deletions(-) delete mode 100644 tests/test_vis_gradcampp.py diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index ba1f5d2589..06999ebf1b 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -125,10 +125,10 @@ def get_layer(self, layer_id: Union[str, Callable]): def class_score(self, logits, class_idx): return logits[:, class_idx].squeeze() - def __call__(self, x, class_idx=None, retain_graph=False): + def __call__(self, x, class_idx=None, retain_graph=False, **kwargs): train = self.model.training self.model.eval() - logits = self.model(x) + logits = self.model(x, **kwargs) self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx acti, grad = None, None if self.register_forward: @@ -175,17 +175,18 @@ def __init__( self.upsampler = upsampler self.postprocessing = postprocessing - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + def feature_map_size(self, input_size, device="cpu", layer_idx=-1, **kwargs): """ Computes the actual feature map size given `nn_module` and the target_layer name. Args: input_size: shape of the input tensor device: the device used to initialise the input tensor layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + kwargs: any extra arguments to be passed on to the module as part of its `__call__`. Returns: shape of the actual feature map. """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx, **kwargs).shape def compute_map(self, x, class_idx=None, layer_idx=-1): """ @@ -286,8 +287,8 @@ def __init__( ) self.fc_layers = fc_layers - def compute_map(self, x, class_idx=None, layer_idx=-1): - logits, acti, _ = self.nn_module(x) + def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs): + logits, acti, _ = self.nn_module(x, **kwargs) acti = acti[layer_idx] if class_idx is None: class_idx = logits.max(1)[-1] @@ -298,7 +299,7 @@ def compute_map(self, x, class_idx=None, layer_idx=-1): output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0) return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class - def __call__(self, x, class_idx=None, layer_idx=-1): + def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs): """ Compute the activation map with upsampling and postprocessing. @@ -306,11 +307,12 @@ def __call__(self, x, class_idx=None, layer_idx=-1): x: input tensor, shape must be compatible with `nn_module`. class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + kwargs: any extra arguments to be passed on to the module as part of its `__call__`. Returns: activation maps """ - acti_map = self.compute_map(x, class_idx, layer_idx) + acti_map = self.compute_map(x, class_idx, layer_idx, **kwargs) return self._upsample_and_post_process(acti_map, x) @@ -356,15 +358,15 @@ class GradCAM(CAMBase): """ - def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): - _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) + def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs): + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) - def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): + def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs): """ Compute the activation map with upsampling and postprocessing. @@ -373,11 +375,12 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. retain_graph: whether to retain_graph for torch module backward call. + kwargs: any extra arguments to be passed on to the module as part of its `__call__`. Returns: activation maps """ - acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) + acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx, **kwargs) return self._upsample_and_post_process(acti_map, x) @@ -395,8 +398,8 @@ class GradCAMpp(GradCAM): """ - def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): - _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) + def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs): + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape alpha_nr = grad.pow(2) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 32b8110b6d..6727c8c239 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -68,17 +68,17 @@ def model(self, m): else: self._model = m # replace the ModelWithHooks - def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor: + def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True, **kwargs) -> torch.Tensor: if x.shape[0] != 1: raise ValueError("expect batch size of 1") x.requires_grad = True - self._model(x, class_idx=index, retain_graph=retain_graph) + self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs) grad: torch.Tensor = x.grad.detach() return grad - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: - return self.get_grad(x, index) + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: + return self.get_grad(x, index, **kwargs) class SmoothGrad(VanillaGrad): @@ -105,7 +105,7 @@ def __init__( else: self.range = range - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: stdev = (self.stdev_spread * (x.max() - x.min())).item() total_gradients = torch.zeros_like(x) for _ in self.range(self.n_samples): @@ -115,7 +115,7 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> x_plus_noise = x_plus_noise.detach() # get gradient and accumulate - grad = self.get_grad(x_plus_noise, index) + grad = self.get_grad(x_plus_noise, index, **kwargs) total_gradients += (grad * grad) if self.magnitude else grad # average @@ -126,12 +126,12 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> class GuidedBackpropGrad(VanillaGrad): - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): - return super().__call__(x, index) + return super().__call__(x, index, **kwargs) class GuidedBackpropSmoothGrad(SmoothGrad): - def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): - return super().__call__(x, index) + return super().__call__(x, index, **kwargs) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index d87b93396a..0630fd0539 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -68,10 +68,10 @@ def _check_input_bounding_box(b_box, im_shape): return b_box_min, b_box_max -def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims): +def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims, **kwargs): """Infer given images. Append to previous evaluations. Store each class separately.""" batch_images = torch.cat(batch_images, dim=0) - scores = model(batch_images).detach() + scores = model(batch_images, **kwargs).detach() for i in range(scores.shape[1]): sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i])) return sensitivity_ims @@ -183,14 +183,14 @@ def __init__( self.per_channel = per_channel self.verbose = verbose - def _compute_occlusion_sensitivity(self, x, b_box): + def _compute_occlusion_sensitivity(self, x, b_box, **kwargs): # Get bounding box im_shape = np.array(x.shape[1:]) b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) # Get the number of prediction classes - num_classes = self.nn_module(x).numel() + num_classes = self.nn_module(x, **kwargs).numel() # If pad val not supplied, get the mean of the image pad_val = x.mean() if self.pad_val is None else self.pad_val @@ -266,7 +266,7 @@ def _compute_occlusion_sensitivity(self, x, b_box): # Once the batch is complete (or on last iteration) if len(batch_images) == self.n_batch or i == num_required_predictions - 1: # Do the predictions and append to sensitivity maps - sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims) + sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims, **kwargs) # Clear lists batch_images = [] @@ -276,7 +276,9 @@ def _compute_occlusion_sensitivity(self, x, b_box): return sensitivity_ims, output_im_shape - def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__( + self, x: torch.Tensor, b_box: Optional[Sequence] = None, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Image to use for inference. Should be a tensor consisting of 1 batch. @@ -286,6 +288,7 @@ def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[t be useful for larger images. * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``. * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension. + kwargs: any extra arguments to be passed on to the module as part of its `__call__`. Returns: * Occlusion map: @@ -305,7 +308,7 @@ def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[t _check_input_image(x) # Generate sensitivity images - sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box) + sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box, **kwargs) # Loop over image for each classification for i, sens_i in enumerate(sensitivity_ims_list): diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index f258dfc557..ce29b55edf 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -17,6 +17,14 @@ from monai.networks.nets import DenseNet, DenseNet121 from monai.visualize import OcclusionSensitivity + +class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): + if adjoint_info != 42: + raise ValueError + return super().__call__(x) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") out_channels_2d = 4 out_channels_3d = 3 @@ -25,9 +33,12 @@ model_3d = DenseNet( spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,) ).to(device) +model_2d_adjoint = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device) model_2d.eval() model_2d_2c.eval() model_3d.eval() +model_2d_adjoint.eval() + # 2D w/ bounding box TEST_CASE_0 = [ @@ -59,10 +70,17 @@ (1, 1, 48, 64, out_channels_2d), (1, 1, 48, 64), ] +# 2D w/ bounding box and adjoint +TEST_CASE_ADJOINT = [ + {"nn_module": model_2d_adjoint}, + {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 40, 1, 62], "adjoint_info": 42}, + (1, 1, 39, 62, out_channels_2d), + (1, 1, 39, 62), +] class TestComputeOcclusionSensitivity(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL, TEST_CASE_ADJOINT]) def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape): occ_sens = OcclusionSensitivity(**init_data) m, most_prob = occ_sens(**call_data) diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py index 7655ca661e..035cb2967b 100644 --- a/tests/test_vis_gradbased.py +++ b/tests/test_vis_gradbased.py @@ -17,32 +17,48 @@ from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad + +class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): + if adjoint_info != 42: + raise ValueError + return super().__call__(x) + + DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) +DENSENET2DADJOINT = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3) + TESTS = [] for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad): # 2D densenet - TESTS.append([type, DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)]) + TESTS.append([type, DENSENET2D, (1, 1, 48, 64)]) # 3D densenet - TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)]) + TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)]) # 2D senet - TESTS.append([type, SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)]) + TESTS.append([type, SENET2D, (1, 3, 64, 64)]) # 3D senet - TESTS.append([type, SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)]) + TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)]) + # 2D densenet - adjoint + TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)]) class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) - def test_shape(self, vis_type, model, shape, expected_shape): + def test_shape(self, vis_type, model, shape): device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # optionally test for adjoint info + kwargs = {"adjoint_info": 42} if isinstance(model, DenseNetAdjoint) else {} + model.to(device) model.eval() vis = vis_type(model) x = torch.rand(shape, device=device) - result = vis(x) + result = vis(x, **kwargs) self.assertTupleEqual(result.shape, x.shape) diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 08a1d8deb0..d81007aa15 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -10,81 +10,136 @@ # limitations under the License. import unittest +from typing import Any, List import numpy as np import torch from parameterized import parameterized from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 -from monai.visualize import GradCAM +from monai.visualize import GradCAM, GradCAMpp from tests.utils import assert_allclose -# 2D -TEST_CASE_0 = [ - { - "model": "densenet2d", - "shape": (2, 1, 48, 64), - "feature_shape": (2, 1, 1, 2), - "target_layers": "class_layers.relu", - }, - (2, 1, 48, 64), -] -# 3D -TEST_CASE_1 = [ - { - "model": "densenet3d", - "shape": (2, 1, 6, 6, 6), - "feature_shape": (2, 1, 2, 2, 2), - "target_layers": "class_layers.relu", - }, - (2, 1, 6, 6, 6), -] -# 2D -TEST_CASE_2 = [ - {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, - (2, 1, 64, 64), -] - -# 3D -TEST_CASE_3 = [ - {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, - (2, 1, 8, 8, 48), -] + +class DenseNetAdjoint(DenseNet121): + def __call__(self, x, adjoint_info): + if adjoint_info != 42: + raise ValueError + return super().__call__(x) + + +TESTS: List[Any] = [] +TESTS_ILL: List[Any] = [] + +for cam in (GradCAM, GradCAMpp): + # 2D + TESTS.append( + [ + cam, + { + "model": "densenet2d", + "shape": (2, 1, 48, 64), + "feature_shape": (2, 1, 1, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 48, 64), + ] + ) + # 3D + TESTS.append( + [ + cam, + { + "model": "densenet3d", + "shape": (2, 1, 6, 6, 6), + "feature_shape": (2, 1, 2, 2, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 6, 6, 6), + ] + ) + # 2D + TESTS.append( + [ + cam, + {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, + (2, 1, 64, 64), + ] + ) + + # 3D + TESTS.append( + [ + cam, + { + "model": "senet3d", + "shape": (2, 3, 8, 8, 48), + "feature_shape": (2, 1, 1, 1, 2), + "target_layers": "layer4", + }, + (2, 1, 8, 8, 48), + ] + ) + + # adjoint info + TESTS.append( + [ + cam, + { + "model": "adjoint", + "shape": (2, 1, 48, 64), + "feature_shape": (2, 1, 1, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 48, 64), + ] + ) + + TESTS_ILL.append([cam]) class TestGradientClassActivationMap(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_data, expected_shape): + @parameterized.expand(TESTS) + def test_shape(self, cam_class, input_data, expected_shape): if input_data["model"] == "densenet2d": model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) - if input_data["model"] == "densenet3d": + elif input_data["model"] == "densenet3d": model = DenseNet( spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) ) - if input_data["model"] == "senet2d": + elif input_data["model"] == "senet2d": model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) - if input_data["model"] == "senet3d": + elif input_data["model"] == "senet3d": model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) + elif input_data["model"] == "adjoint": + model = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3) + + # optionally test for adjoint info + kwargs = {"adjoint_info": 42} if input_data["model"] == "adjoint" else {} + device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() - cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"]) + cam = cam_class(nn_module=model, target_layers=input_data["target_layers"]) image = torch.rand(input_data["shape"], device=device) - result = cam(x=image, layer_idx=-1) - np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), model(image).max(1)[-1].cpu()) - fea_shape = cam.feature_map_size(input_data["shape"], device=device) + inferred = model(image, **kwargs).max(1)[-1].cpu() + result = cam(x=image, layer_idx=-1, **kwargs) + np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), inferred) + + fea_shape = cam.feature_map_size(input_data["shape"], device=device, **kwargs) self.assertTupleEqual(fea_shape, input_data["feature_shape"]) self.assertTupleEqual(result.shape, expected_shape) # check result is same whether class_idx=None is used or not - result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu()) + result2 = cam(x=image, layer_idx=-1, class_idx=inferred, **kwargs) assert_allclose(result, result2) - def test_ill(self): + @parameterized.expand(TESTS_ILL) + def test_ill(self, cam_class): model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) for name, x in model.named_parameters(): if "features" in name: x.requires_grad = False - cam = GradCAM(nn_module=model, target_layers="class_layers.relu") + cam = cam_class(nn_module=model, target_layers="class_layers.relu") image = torch.rand((2, 1, 48, 64)) with self.assertRaises(IndexError): cam(x=image) diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py deleted file mode 100644 index a261b6055b..0000000000 --- a/tests/test_vis_gradcampp.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 -from monai.visualize import GradCAMpp - -# 2D -TEST_CASE_0 = [ - { - "model": "densenet2d", - "shape": (2, 1, 48, 64), - "feature_shape": (2, 1, 1, 2), - "target_layers": "class_layers.relu", - }, - (2, 1, 48, 64), -] -# 3D -TEST_CASE_1 = [ - { - "model": "densenet3d", - "shape": (2, 1, 6, 6, 6), - "feature_shape": (2, 1, 2, 2, 2), - "target_layers": "class_layers.relu", - }, - (2, 1, 6, 6, 6), -] -# 2D -TEST_CASE_2 = [ - {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"}, - (2, 1, 64, 64), -] - -# 3D -TEST_CASE_3 = [ - {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"}, - (2, 1, 8, 8, 48), -] - - -class TestGradientClassActivationMapPP(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_shape(self, input_data, expected_shape): - if input_data["model"] == "densenet2d": - model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) - if input_data["model"] == "densenet3d": - model = DenseNet( - spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) - ) - if input_data["model"] == "senet2d": - model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) - if input_data["model"] == "senet3d": - model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - cam = GradCAMpp(nn_module=model, target_layers=input_data["target_layers"]) - image = torch.rand(input_data["shape"], device=device) - result = cam(x=image, layer_idx=-1) - fea_shape = cam.feature_map_size(input_data["shape"], device=device) - self.assertTupleEqual(fea_shape, input_data["feature_shape"]) - self.assertTupleEqual(result.shape, expected_shape) - - -if __name__ == "__main__": - unittest.main()