Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions monai/visualize/class_activation_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ def get_layer(self, layer_id: Union[str, Callable]):
def class_score(self, logits, class_idx):
return logits[:, class_idx].squeeze()

def __call__(self, x, class_idx=None, retain_graph=False):
def __call__(self, x, class_idx=None, retain_graph=False, **kwargs):
train = self.model.training
self.model.eval()
logits = self.model(x)
logits = self.model(x, **kwargs)
self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
acti, grad = None, None
if self.register_forward:
Expand Down Expand Up @@ -175,17 +175,18 @@ def __init__(
self.upsampler = upsampler
self.postprocessing = postprocessing

def feature_map_size(self, input_size, device="cpu", layer_idx=-1):
def feature_map_size(self, input_size, device="cpu", layer_idx=-1, **kwargs):
"""
Computes the actual feature map size given `nn_module` and the target_layer name.
Args:
input_size: shape of the input tensor
device: the device used to initialise the input tensor
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
Returns:
shape of the actual feature map.
"""
return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape
return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx, **kwargs).shape

def compute_map(self, x, class_idx=None, layer_idx=-1):
"""
Expand Down Expand Up @@ -286,8 +287,8 @@ def __init__(
)
self.fc_layers = fc_layers

def compute_map(self, x, class_idx=None, layer_idx=-1):
logits, acti, _ = self.nn_module(x)
def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs):
logits, acti, _ = self.nn_module(x, **kwargs)
acti = acti[layer_idx]
if class_idx is None:
class_idx = logits.max(1)[-1]
Expand All @@ -298,19 +299,20 @@ def compute_map(self, x, class_idx=None, layer_idx=-1):
output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0)
return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class

def __call__(self, x, class_idx=None, layer_idx=-1):
def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs):
"""
Compute the activation map with upsampling and postprocessing.

Args:
x: input tensor, shape must be compatible with `nn_module`.
class_idx: index of the class to be visualized. Default to argmax(logits)
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

Returns:
activation maps
"""
acti_map = self.compute_map(x, class_idx, layer_idx)
acti_map = self.compute_map(x, class_idx, layer_idx, **kwargs)
return self._upsample_and_post_process(acti_map, x)


Expand Down Expand Up @@ -356,15 +358,15 @@ class GradCAM(CAMBase):

"""

def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial))
acti_map = (weights * acti).sum(1, keepdim=True)
return F.relu(acti_map)

def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs):
"""
Compute the activation map with upsampling and postprocessing.

Expand All @@ -373,11 +375,12 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
class_idx: index of the class to be visualized. Default to argmax(logits)
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
retain_graph: whether to retain_graph for torch module backward call.
kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

Returns:
activation maps
"""
acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)
acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx, **kwargs)
return self._upsample_and_post_process(acti_map, x)


Expand All @@ -395,8 +398,8 @@ class GradCAMpp(GradCAM):

"""

def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
alpha_nr = grad.pow(2)
Expand Down
20 changes: 10 additions & 10 deletions monai/visualize/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ def model(self, m):
else:
self._model = m # replace the ModelWithHooks

def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor:
def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True, **kwargs) -> torch.Tensor:
if x.shape[0] != 1:
raise ValueError("expect batch size of 1")
x.requires_grad = True

self._model(x, class_idx=index, retain_graph=retain_graph)
self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
grad: torch.Tensor = x.grad.detach()
return grad

def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
return self.get_grad(x, index)
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
return self.get_grad(x, index, **kwargs)


class SmoothGrad(VanillaGrad):
Expand All @@ -105,7 +105,7 @@ def __init__(
else:
self.range = range

def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
stdev = (self.stdev_spread * (x.max() - x.min())).item()
total_gradients = torch.zeros_like(x)
for _ in self.range(self.n_samples):
Expand All @@ -115,7 +115,7 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) ->
x_plus_noise = x_plus_noise.detach()

# get gradient and accumulate
grad = self.get_grad(x_plus_noise, index)
grad = self.get_grad(x_plus_noise, index, **kwargs)
total_gradients += (grad * grad) if self.magnitude else grad

# average
Expand All @@ -126,12 +126,12 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) ->


class GuidedBackpropGrad(VanillaGrad):
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
return super().__call__(x, index)
return super().__call__(x, index, **kwargs)


class GuidedBackpropSmoothGrad(SmoothGrad):
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
return super().__call__(x, index)
return super().__call__(x, index, **kwargs)
17 changes: 10 additions & 7 deletions monai/visualize/occlusion_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def _check_input_bounding_box(b_box, im_shape):
return b_box_min, b_box_max


def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims):
def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims, **kwargs):
"""Infer given images. Append to previous evaluations. Store each class separately."""
batch_images = torch.cat(batch_images, dim=0)
scores = model(batch_images).detach()
scores = model(batch_images, **kwargs).detach()
for i in range(scores.shape[1]):
sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i]))
return sensitivity_ims
Expand Down Expand Up @@ -183,14 +183,14 @@ def __init__(
self.per_channel = per_channel
self.verbose = verbose

def _compute_occlusion_sensitivity(self, x, b_box):
def _compute_occlusion_sensitivity(self, x, b_box, **kwargs):

# Get bounding box
im_shape = np.array(x.shape[1:])
b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape)

# Get the number of prediction classes
num_classes = self.nn_module(x).numel()
num_classes = self.nn_module(x, **kwargs).numel()

# If pad val not supplied, get the mean of the image
pad_val = x.mean() if self.pad_val is None else self.pad_val
Expand Down Expand Up @@ -266,7 +266,7 @@ def _compute_occlusion_sensitivity(self, x, b_box):
# Once the batch is complete (or on last iteration)
if len(batch_images) == self.n_batch or i == num_required_predictions - 1:
# Do the predictions and append to sensitivity maps
sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims)
sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims, **kwargs)
# Clear lists
batch_images = []

Expand All @@ -276,7 +276,9 @@ def _compute_occlusion_sensitivity(self, x, b_box):

return sensitivity_ims, output_im_shape

def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[torch.Tensor, torch.Tensor]:
def __call__(
self, x: torch.Tensor, b_box: Optional[Sequence] = None, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Image to use for inference. Should be a tensor consisting of 1 batch.
Expand All @@ -286,6 +288,7 @@ def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[t
be useful for larger images.
* Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.
* Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension.
kwargs: any extra arguments to be passed on to the module as part of its `__call__`.

Returns:
* Occlusion map:
Expand All @@ -305,7 +308,7 @@ def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[t
_check_input_image(x)

# Generate sensitivity images
sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box)
sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box, **kwargs)

# Loop over image for each classification
for i, sens_i in enumerate(sensitivity_ims_list):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_occlusion_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
from monai.networks.nets import DenseNet, DenseNet121
from monai.visualize import OcclusionSensitivity


class DenseNetAdjoint(DenseNet121):
def __call__(self, x, adjoint_info):
if adjoint_info != 42:
raise ValueError
return super().__call__(x)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out_channels_2d = 4
out_channels_3d = 3
Expand All @@ -25,9 +33,12 @@
model_3d = DenseNet(
spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,)
).to(device)
model_2d_adjoint = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)
model_2d.eval()
model_2d_2c.eval()
model_3d.eval()
model_2d_adjoint.eval()


# 2D w/ bounding box
TEST_CASE_0 = [
Expand Down Expand Up @@ -59,10 +70,17 @@
(1, 1, 48, 64, out_channels_2d),
(1, 1, 48, 64),
]
# 2D w/ bounding box and adjoint
TEST_CASE_ADJOINT = [
{"nn_module": model_2d_adjoint},
{"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 40, 1, 62], "adjoint_info": 42},
(1, 1, 39, 62, out_channels_2d),
(1, 1, 39, 62),
]


class TestComputeOcclusionSensitivity(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL])
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL, TEST_CASE_ADJOINT])
def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape):
occ_sens = OcclusionSensitivity(**init_data)
m, most_prob = occ_sens(**call_data)
Expand Down
28 changes: 22 additions & 6 deletions tests/test_vis_gradbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,48 @@
from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad


class DenseNetAdjoint(DenseNet121):
def __call__(self, x, adjoint_info):
if adjoint_info != 42:
raise ValueError
return super().__call__(x)


DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
DENSENET2DADJOINT = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)


TESTS = []
for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
# 2D densenet
TESTS.append([type, DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)])
TESTS.append([type, DENSENET2D, (1, 1, 48, 64)])
# 3D densenet
TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)])
TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)])
# 2D senet
TESTS.append([type, SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)])
TESTS.append([type, SENET2D, (1, 3, 64, 64)])
# 3D senet
TESTS.append([type, SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)])
TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)])
# 2D densenet - adjoint
TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)])


class TestGradientClassActivationMap(unittest.TestCase):
@parameterized.expand(TESTS)
def test_shape(self, vis_type, model, shape, expected_shape):
def test_shape(self, vis_type, model, shape):
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# optionally test for adjoint info
kwargs = {"adjoint_info": 42} if isinstance(model, DenseNetAdjoint) else {}

model.to(device)
model.eval()
vis = vis_type(model)
x = torch.rand(shape, device=device)
result = vis(x)
result = vis(x, **kwargs)
self.assertTupleEqual(result.shape, x.shape)


Expand Down
Loading