Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions monai/visualize/class_activation_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.utils import eval_mode, train_mode
from monai.transforms import ScaleIntensity
from monai.utils import ensure_tuple
from monai.visualize.visualizer import default_upsampler
Expand Down Expand Up @@ -110,26 +109,24 @@ def get_layer(self, layer_id: Union[str, Callable]):
return mod
raise NotImplementedError(f"Could not find {layer_id}.")

def class_score(self, logits, class_idx=None):
if class_idx is not None:
return logits[:, class_idx].squeeze(), class_idx
class_idx = logits.max(1)[-1]
return logits[:, class_idx].squeeze(), class_idx
def class_score(self, logits, class_idx):
return logits[:, class_idx].squeeze()

def __call__(self, x, class_idx=None, retain_graph=False):
# Use train_mode if grad is required, else eval_mode
mode = train_mode if self.register_backward else eval_mode
with mode(self.model):
logits = self.model(x)
acti, grad = None, None
if self.register_forward:
acti = tuple(self.activations[layer] for layer in self.target_layers)
if self.register_backward:
score, class_idx = self.class_score(logits, class_idx)
self.model.zero_grad()
self.score, self.class_idx = score, class_idx
score.sum().backward(retain_graph=retain_graph)
grad = tuple(self.gradients[layer] for layer in self.target_layers)
train = self.model.training
self.model.eval()
logits = self.model(x)
self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
acti, grad = None, None
if self.register_forward:
acti = tuple(self.activations[layer] for layer in self.target_layers)
if self.register_backward:
self.score = self.class_score(logits, self.class_idx)
self.model.zero_grad()
self.score.sum().backward(retain_graph=retain_graph)
grad = tuple(self.gradients[layer] for layer in self.target_layers)
if train:
self.model.train()
return logits, acti, grad

def get_wrapped_net(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_vis_gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

Expand Down Expand Up @@ -79,9 +80,13 @@ def test_shape(self, input_data, expected_shape):
cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"])
image = torch.rand(input_data["shape"], device=device)
result = cam(x=image, layer_idx=-1)
np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), model(image).max(1)[-1].cpu())
fea_shape = cam.feature_map_size(input_data["shape"], device=device)
self.assertTupleEqual(fea_shape, input_data["feature_shape"])
self.assertTupleEqual(result.shape, expected_shape)
# check result is same whether class_idx=None is used or not
result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu())
np.testing.assert_array_almost_equal(result, result2)


if __name__ == "__main__":
Expand Down