Gradient Class-Activation Maps (Grad-CAM)#

In this notebook we will use the Grad-CAM algorithm to visualize which regions in an image dominate the decision for a specific class in a classification neural network. We visualize where the network is looking.
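In short, Grad-CAM (Selvaraju et al., 2017) averages the gradients of the class score $y^c$ over the spatial positions of each feature map $A^k$ of a convolutional layer to obtain one weight per feature map, and then sums the weighted maps, keeping only positive contributions:

$$\alpha_k^c = \frac{1}{Z}\sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c = \mathrm{ReLU}\left(\sum_k \alpha_k^c A^k\right)$$

Here, $Z$ is the number of spatial positions in a feature map. We will implement exactly these two steps below.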

import torchvision.models
from torchvision.models import resnet50
from torchvision import transforms
from skimage.transform import resize
import torch
from skimage.io import imread
import stackview
import numpy as np
from functools import partial
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
from utilities import visualize_image_list
from torchvision.models import ResNet50_Weights as W

Loading the model#

We use the ResNet architecture and more specifically ResNet50, a pretrained model for classifying images.

resnet_model = resnet50(weights=W.DEFAULT, progress=False)
model = resnet_model.eval()

It was trained on ImageNet. For illustration, we can print out some of the class names available in ImageNet.

classes = W.DEFAULT.meta["categories"]
classes[:5]
['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead']
len(classes)
1000

Here we use a dictionary as a lookup table to get the index of a specific class name. We will use it later to compute class-activation maps for specific classes.

class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
class_to_idx["flagpole"]
557
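
If you are unsure about the exact spelling of a class name, a simple substring search over the list helps; a minimal sketch:

# find all ImageNet class names containing a given substring
[c for c in classes if "shark" in c]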

Example image data for classification#

In the following we will use a cropped image licensed CC-BY-SA by HTW Dresden / Peter Sebb (Source).

original_image = imread("data/htw-front-cc-by-sa.png")[...,:3]

stackview.insight(original_image)
shape: (400, 400, 3)
dtype: uint8
size: 468.8 kB
min: 0
max: 255

Next, we convert this image to a PyTorch tensor, which is required for processing it with the neural network.

input_tensor = transforms.ToTensor()(original_image).unsqueeze(0)
input_tensor.shape
torch.Size([1, 3, 400, 400])
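
Note that we only scale the raw image to the range [0, 1] here. The pretrained weights also ship with a bundled preprocessing transform (resizing, cropping and ImageNet mean/std normalization) that typically improves classification accuracy; a minimal sketch, keeping in mind that the resizing would also change the resolution of the CAM computed below:

# optional: the preprocessing bundled with the pretrained weights
preprocess = W.DEFAULT.transforms()
normalized_tensor = preprocess(transforms.ToTensor()(original_image)).unsqueeze(0)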

Storing feature images and gradients#

In the following code, we register callback functions (hooks) in the network so that we can access intermediate results of the algorithm, such as feature images and gradients. This slows down processing a bit and is not recommended in production.

layers = [model.layer1, model.layer2, model.layer3, model.layer4]

features = [None] * len(layers)
gradients = [None] * len(layers)

def save_feature_maps(i, module, inp, out):
    # forward hook: remember the feature maps produced by layer i
    features[i] = out

def save_gradients(i, module, inp, out):
    # backward hook: remember the gradient with respect to layer i's output
    gradients[i] = out[0]

for i, layer in enumerate(layers):
    layer.register_forward_hook(partial(save_feature_maps, i))
    layer.register_full_backward_hook(partial(save_gradients, i))
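
Each register call returns a handle object; outside of demo notebooks you would typically keep these handles so the hooks can be removed again when no longer needed. A minimal sketch:

# keep the handle so the hook can be detached later
handle = model.layer1.register_forward_hook(partial(save_feature_maps, 0))
handle.remove()  # the hook is detached again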

Prediction step#

To visualize what the network is doing, we need to execute it. After this step, the classification result is available.

output = model(input_tensor)

output.shape
torch.Size([1, 1000])

The class index and class name of the classification are:

class_idx = output.argmax(axis=1).detach()
class_idx, classes[class_idx]
(tensor([436]), 'beach wagon')
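
To judge how confident the model is, we can also turn the raw scores into probabilities and inspect the top matches; a minimal sketch:

# softmax converts raw scores to probabilities; topk lists the best matches
probabilities = torch.softmax(output, dim=1)[0].detach()
top_probs, top_idxs = probabilities.topk(5)
[(classes[int(i)], round(float(p), 3)) for i, p in zip(top_idxs, top_probs)]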

Inspecting the intermediate results#

As we stored intermediate results (feature images), we can now visualize them.

features[0].shape
torch.Size([1, 256, 100, 100])
for i, layer in enumerate(layers):

    layer_features = features[i][0].detach().numpy()
    
    display(Markdown(f"### Layer {i+1} {layer_features.shape}"))

    # show the first few feature images of this layer
    num_figs = 5
    fig, axes = plt.subplots(1, num_figs, figsize=(10, 2))
    for j in range(num_figs):
        stackview.imshow(layer_features[j], plot=axes[j])
    plt.show()

Layer 1 (256, 100, 100)

../_images/2d245f6abd44b3093cb11b8b20e59d38523085614be1ef60fe286442c6af3a3d.png

Layer 2 (512, 50, 50)

../_images/79c4417e32936bfb3a7af08d212ddba4cd22b63194cab7c20e07f87afed5b57d.png

Layer 3 (1024, 25, 25)

../_images/c5bc4295eaf198ef60ee5587199d8b76ef618caba22fa29050db332aa513460b.png

Layer 4 (2048, 13, 13)

../_images/dfa12ae6b1d4021e53f7bdfb1c53f0b2bae5499b55ec8977d09350d2060d58e4.png

The deeper the layer (the higher the layer number), the less interpretable the feature images become.

Determining gradients#

To determine the gradients, we use a single back-propagation step with the class we just determined. This is similar to what we would do during training to improve the classification quality for this one specific class, given this one specific input image.

model.zero_grad()
one_hot = torch.zeros_like(output)
one_hot[0][class_idx[0]] = 1
output.backward(gradient=one_hot)
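
As a side note, back-propagating this one-hot gradient is equivalent to back-propagating the scalar score of the selected class directly; a sketch, assuming a fresh forward pass (the computation graph is freed after each backward call):

# equivalent formulation: back-propagate the selected class score directly
output = model(input_tensor)
model.zero_grad()
output[0, class_idx[0]].backward()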

We can summarize these gradients into one weight-adaption number per feature image in the last convolutional layer by averaging over the spatial dimensions. The higher this number, the more relevant the specific feature image is for classifying this one specific class.

weight_adaption = torch.mean(gradients[-1], dim=(2, 3))[0]
weight_adaption.shape
torch.Size([2048])
weight_adaption[:3]
tensor([-3.8505e-05, -2.3461e-04, -6.0150e-05])
weight_adaption.max()
tensor(0.0012)

Summarizing feature images#

For visualization purposes, we show just the first few feature images of the last/deepest convolutional layer again. These images will be multiplied with the weight-adaption numbers explained above.

num_features = 4
images = []
image_names = []

for i, f in enumerate(features[-1][0][:num_features]):
    images.append(f.detach().cpu().numpy())
    image_names.append(f"Feature {i+1}")
    
visualize_image_list(images, image_names)
../_images/80678a44ec43092263e4a2574a9dc36021f5bb767c6db3828c2c27e49fa2a726.png
# Create CAM
cam = torch.zeros(features[-1].shape[2:], dtype=torch.float32)

# Multiply weights with feature maps and sum
for i, w in enumerate(weight_adaption):
    cam += w * features[-1][0][i]
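
The same weighted sum can be written without a loop, for example using `torch.einsum`:

# equivalent, vectorized: contract the channel axis of the feature maps
# with the per-channel weights
cam = torch.einsum("c,chw->hw", weight_adaption, features[-1][0])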

After multiplication, these images are summed into a single image that shows us where in the image the network is looking when checking for the specified class. This is the class-activation map. Negative contributions are clipped to zero (a ReLU):

projected_cam = torch.maximum(cam, torch.tensor(0)).detach().cpu().numpy()

stackview.insight(projected_cam)
shape: (13, 13)
dtype: float32
size: 676.0 B
min: 0.0
max: 0.16129449
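
The CAM values live on an arbitrary scale. When comparing maps for different classes, it can therefore help to normalize them to the range [0, 1] first; a minimal sketch:

# normalize for comparison across classes; the epsilon avoids division by zero
normalized_cam = projected_cam / (projected_cam.max() + 1e-8)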

Overlay#

To visualize this map on top of the original image, we create an upsampled, interpolated version of it.

upsampled_cam = resize(projected_cam, (original_image.shape[0], original_image.shape[1]))

stackview.insight(upsampled_cam)
shape: (400, 400)
dtype: float32
size: 625.0 kB
min: 0.0
max: 0.16114752

These two images can be visualized on top of each other.

stackview.imshow(original_image, continue_drawing=True)
stackview.imshow(upsampled_cam, colormap='viridis', alpha=0.6)
../_images/c5c4d5512775f39d6e48499bcca8c23cdd7aff92b467075082485af7d196c563.png
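
As a side note, the upsampling could also be done directly in PyTorch, for example with bilinear interpolation:

import torch.nn.functional as F

# equivalent upsampling on the torch side, converted back to numpy for display
cam_tensor = torch.from_numpy(projected_cam)[None, None]
upsampled = F.interpolate(cam_tensor, size=original_image.shape[:2],
                          mode="bilinear")[0, 0].numpy()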

Class activation maps for different classes#

We can compute this visualization for other classes as well. To simplify this, we write a Python helper function that does the same as above.

def show_cam_for_class(class_name):
    class_idx = [class_to_idx[class_name]]
    output = model(input_tensor)
    
    model.zero_grad()
    one_hot = torch.zeros_like(output)
    one_hot[0][class_idx[0]] = 1
    output.backward(gradient=one_hot)
    
    feature_weights = torch.mean(gradients[-1], dim=(2, 3))[0]
    
    # Create CAM
    cam = torch.zeros(features[-1].shape[2:], dtype=torch.float32)
    
    # Multiply weights with feature maps and sum
    for i, w in enumerate(feature_weights):
        cam += w * features[-1][0][i]
    
    projected_cam = torch.maximum(cam, torch.tensor(0)).detach().cpu().numpy()
    
    upsampled_cam = resize(projected_cam, (original_image.shape[0], original_image.shape[1]))
    
    stackview.imshow(original_image, continue_drawing=True)
    stackview.imshow(upsampled_cam, colormap='viridis', alpha=0.6)

show_cam_for_class("palace")
../_images/0529aa28489e7597e63e59cd29e9839c13e9550bc01b2a77077fab7f4bdf30c1.png
show_cam_for_class("flagpole")
../_images/f644ab7d83edf5f33cc8eeaa5819dec6a6e124b7d6b18ac2497e682c1defe22c.png
show_cam_for_class("great white shark")
../_images/7bc2adb2765178f3f27d13760341abc713c109554ad316dfadea9f14c4228fc1.png

Exercise#

What needs to be changed above to make sure the classification returns “car”?

Exercise#

Write a Python function that takes an image filename as parameter and returns the class name as string and a corresponding CAM image. Call this function in a loop which iterates over all images in the folder ‘data’.