Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions monai/visualize/class_activation_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,31 @@
import torch.nn as nn
import torch.nn.functional as F

from monai.config import NdarrayTensor
from monai.transforms import ScaleIntensity
from monai.utils import ensure_tuple, get_torch_version_tuple
from monai.visualize.visualizer import default_upsampler

__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"]


def default_normalizer(x) -> np.ndarray:
def default_normalizer(x: NdarrayTensor) -> NdarrayTensor:
"""
A linear intensity scaling by mapping the (min, max) to (1, 0).
If the input data is PyTorch Tensor, the output data will be Tensor on the same device,
otherwise, output data will be numpy array.

Comment thread
Nic-Ma marked this conversation as resolved.
N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa).
Note: This will flip magnitudes (i.e., smallest will become biggest and vice versa).
"""

def _compute(data: np.ndarray) -> np.ndarray:
scaler = ScaleIntensity(minv=1.0, maxv=0.0)
return np.stack([scaler(i) for i in data], axis=0)

if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
scaler = ScaleIntensity(minv=1.0, maxv=0.0)
x = [scaler(x) for x in x]
return np.stack(x, axis=0)
return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device)

return _compute(x)


class ModelWithHooks:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vis_gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_shape(self, input_data, expected_shape):
self.assertTupleEqual(result.shape, expected_shape)
# check result is same whether class_idx=None is used or not
result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu())
np.testing.assert_array_almost_equal(result, result2)
torch.testing.assert_allclose(result, result2)


if __name__ == "__main__":
Expand Down