diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index c63e8e51d9..992eaecdac 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -17,6 +17,7 @@ import torch.nn as nn import torch.nn.functional as F +from monai.config import NdarrayTensor from monai.transforms import ScaleIntensity from monai.utils import ensure_tuple, get_torch_version_tuple from monai.visualize.visualizer import default_upsampler @@ -24,17 +25,23 @@ __all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks", "default_normalizer"] -def default_normalizer(x) -> np.ndarray: +def default_normalizer(x: NdarrayTensor) -> NdarrayTensor: """ A linear intensity scaling by mapping the (min, max) to (1, 0). + If the input data is PyTorch Tensor, the output data will be Tensor on the same device, + otherwise, output data will be numpy array. - N.B.: This will flip magnitudes (i.e., smallest will become biggest and vice versa). + Note: This will flip magnitudes (i.e., smallest will become biggest and vice versa). """ + + def _compute(data: np.ndarray) -> np.ndarray: + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + return np.stack([scaler(i) for i in data], axis=0) + if isinstance(x, torch.Tensor): - x = x.detach().cpu().numpy() - scaler = ScaleIntensity(minv=1.0, maxv=0.0) - x = [scaler(x) for x in x] - return np.stack(x, axis=0) + return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device) + + return _compute(x) class ModelWithHooks: diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index f8e49f486f..eebf32d70b 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -86,7 +86,7 @@ def test_shape(self, input_data, expected_shape): self.assertTupleEqual(result.shape, expected_shape) # check result is same whether class_idx=None is used or not result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu()) - np.testing.assert_array_almost_equal(result, result2) + torch.testing.assert_allclose(result, result2) if __name__ == "__main__":