DEV Community

CodeTrade India Pvt. Ltd.


How Does Grad-CAM Work in PyTorch?

Grad-CAM (Gradient-weighted Class Activation Mapping) is a visualization technique that provides visual explanations for the decisions of convolutional neural networks (CNNs). It produces coarse localization maps that highlight the regions of the input image that matter most for predicting a particular class.
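For reference, the standard Grad-CAM computation can be written compactly. Here $y^c$ is the score of class $c$ before softmax, $A^k$ is the $k$-th feature map of the chosen convolutional layer, and $Z = H \times W$ is the number of spatial positions in that map:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c \, A^k \right)
```

The weight $\alpha_k^c$ measures how important channel $k$ is for class $c$, and the ReLU keeps only the regions whose activations push the class score up.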

How Grad-CAM Works in PyTorch

Implementing Grad-CAM in PyTorch involves several steps, and each one matters for producing accurate, meaningful visual explanations.

Step 1: Preprocess the Input Image

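The first step is to preprocess the input image so that it matches what the model expects. This involves resizing the image, converting it into a tensor, and normalizing it with the same per-channel mean and standard deviation used during training (the ImageNet statistics below for torchvision's pretrained models).

Preprocessing matters because Grad-CAM explains whatever the model predicts for the tensor it actually receives: an image that is sized or normalized differently from the training data can produce wrong predictions and, in turn, misleading heatmaps.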

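from torchvision import transforms
import cv2

# Define the preprocessing transformation
preprocess = transforms.Compose([
    transforms.ToPILImage(),  # Resize needs a PIL image or tensor, not the NumPy array cv2 returns
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # HxWxC uint8 in [0, 255] -> CxHxW float in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet mean/std
])

# Load the image (OpenCV returns BGR, or None if the file cannot be read) and convert it to RGB
img = cv2.imread('path_to_image.jpg')
if img is None:
    raise FileNotFoundError('Could not read path_to_image.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = preprocess(img).unsqueeze(0)  # add a batch dimension: [1, 3, 224, 224]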
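To make the tensor conversion and normalization concrete, here is a minimal NumPy sketch of what ToTensor and Normalize do to an already resized RGB image: scale 8-bit values to [0, 1], subtract the per-channel ImageNet mean, divide by the per-channel standard deviation, and reorder the axes from height x width x channels to channels x height x width. The helper name to_normalized_chw is hypothetical, and resizing is deliberately left out.

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_rgb: np.ndarray) -> np.ndarray:
    """Hypothetical helper: uint8 HxWx3 RGB image -> float32 1x3xHxW batch."""
    x = img_rgb.astype(np.float32) / 255.0  # like ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                    # like Normalize: per-channel standardization
    x = np.transpose(x, (2, 0, 1))          # HWC -> CHW, the layout ToTensor produces
    return x[np.newaxis, ...]               # add a batch dimension, like unsqueeze(0)


# Usage on a tiny synthetic 2x2 white image
demo = np.full((2, 2, 3), 255, dtype=np.uint8)
batch = to_normalized_chw(demo)
print(batch.shape)  # (1, 3, 2, 2)
```

A pure white pixel therefore ends up well above zero in every channel (about 2.25 in the red channel), which is expected: normalized inputs are centered on the dataset mean, not on mid-gray.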

Step 2: Perform a Forward Pass

Perform a forward pass through the model to obtain the predictions. This step passes the preprocessed image through the network to get the logits or output scores for each class.

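from torchvision import models
# Assumed model: a pretrained torchvision ResNet-50 (Step 3 hooks into its layer4)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Perform the forward pass
model.eval()  # Set the model to evaluation mode
output = model(img_tensor)  # raw logits, one score per class
pred_class = output.argmax(dim=1).item()  # index of the highest-scoring class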
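The forward pass returns raw logits, not probabilities. If you also want to see how confident the model is, or to choose a class other than the top one to explain, a softmax followed by torch.topk is a common pattern. The sketch below uses a made-up logits tensor in place of output so that it runs on its own.

```python
import torch

# Made-up logits for a batch of one image and five classes (stand-in for `output`)
logits = torch.tensor([[2.0, 0.5, 3.0, -1.0, 1.0]])

# Softmax turns logits into probabilities that sum to 1 along the class dimension
probs = torch.softmax(logits, dim=1)

# The two most likely classes; any of these indices could serve as the Grad-CAM target
top_probs, top_classes = torch.topk(probs, k=2, dim=1)
pred_class = logits.argmax(dim=1).item()

print(top_classes.tolist())  # [[2, 0]]
print(pred_class)            # 2
```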

Step 3: Identify the Target Layer

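Grad-CAM requires access to the activations of a convolutional layer and to the gradients of the target class score with respect to those activations. The last convolutional layer is the usual choice because it offers the best compromise between high-level, class-specific features and the spatial layout that the fully connected layers discard; for torchvision's ResNet models, that is the last block of layer4. We register hooks to capture the activations during the forward pass and the gradients during the backward pass. Note that hooks only record passes that run after they have been registered.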

# Identify the target layer (the last residual block of a torchvision ResNet)
target_layer = model.layer4[-1]

# Lists to store activations and gradients
activations = []
gradients = []

# Hooks to capture activations (forward pass) and output gradients (backward pass)
def forward_hook(module, input, output):
    activations.append(output)

def backward_hook(module, grad_input, grad_output):
    gradients.append(grad_output[0])

target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)

# Hooks only fire on passes run after registration, so repeat the forward pass
output = model(img_tensor)

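model.layer4[-1] only exists on torchvision's ResNet family. For other CNNs, a common heuristic (not a rule) is to take the last nn.Conv2d module returned by model.modules(). The sketch below applies that heuristic to a small, randomly initialized stand-in network (TinyCNN and last_conv_layer are hypothetical names), registers the same kind of forward and backward hooks before running the forward pass, and keeps the handles returned by the register calls so the hooks can be removed afterwards.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical stand-in network: two conv layers, global pooling, linear head."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x)
        return self.fc(torch.flatten(self.pool(x), 1))


def last_conv_layer(model: nn.Module) -> nn.Module:
    """Heuristic: return the last nn.Conv2d in module registration order."""
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        raise ValueError("model has no nn.Conv2d layers")
    return convs[-1]


torch.manual_seed(0)
model = TinyCNN().eval()
target_layer = last_conv_layer(model)

activations, gradients = [], []
handles = [
    target_layer.register_forward_hook(lambda m, inp, out: activations.append(out)),
    target_layer.register_full_backward_hook(lambda m, gin, gout: gradients.append(gout[0])),
]

# The hooks exist before this forward/backward pair, so both passes are captured
x = torch.randn(1, 3, 16, 16)
output = model(x)
output[0, output.argmax(dim=1).item()].backward()

for h in handles:
    h.remove()  # detach the hooks so later passes are not recorded

print(activations[0].shape, gradients[0].shape)
```

Keeping the handles matters in notebooks: hooks that are never removed keep appending to the lists on every later pass through the model.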

Step 4: Perform a Backward Pass

After the forward pass, a backward pass computes the gradients of the target class score with respect to the activations of the target layer. These gradients indicate how strongly each activation, and therefore each region of the image, influences the model's prediction.

# Zero the gradients
model.zero_grad()

# Backward pass from the predicted class score (a single scalar); the backward hook stores the gradient
output[0, pred_class].backward()
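Hooks are one way to obtain these gradients. An alternative, sketched here under the assumption that the network can be split into a feature extractor and a classification head (the features and head modules below are hypothetical stand-ins), is to keep the activation tensor yourself and ask torch.autograd.grad for the gradient of the class score with respect to it. Any class index can be targeted this way, not only the predicted one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: convolutional feature extractor and classification head
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5))

x = torch.randn(1, 3, 16, 16)
acts = features(x)    # activations of the target layer, shape [1, 8, 16, 16]
logits = head(acts)   # class scores, shape [1, 5]

target_class = 3      # any class index, not only the argmax
score = logits[0, target_class]

# Gradient of the class score w.r.t. the activations; no hooks or .grad fields involved
(grads,) = torch.autograd.grad(score, acts)

print(grads.shape)  # torch.Size([1, 8, 16, 16])
```

Because this toy head is only average pooling followed by a linear layer, every spatial position of channel k receives the same gradient, the weight w[target_class, k] divided by 16 x 16, which makes a convenient sanity check.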

Step 5: Compute the Heatmap

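Using the captured gradients and activations, compute the Grad-CAM heatmap. Each channel of the activation map gets a weight equal to the average of its gradients over all spatial positions; the heatmap is the weighted sum of the channels, passed through a ReLU so that only regions that increase the class score remain. Finally, the heatmap is scaled to the range 0 to 1 for display. The result highlights the regions of the image that matter most for the prediction.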

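import numpy as np
import torch

# Channel weights: average each channel's gradients over the spatial dimensions;
# keepdim=True gives shape [1, C, 1, 1] so the weights broadcast against [1, C, H, W]
weights = torch.mean(gradients[0], dim=[2, 3], keepdim=True)

# Compute the Grad-CAM heatmap: weighted sum over channels, then ReLU
heatmap = torch.sum(weights * activations[0], dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
heatmap /= np.max(heatmap) + 1e-8  # scale to [0, 1]; epsilon guards against an all-zero map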
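The same computation can be packaged as a small function and checked on hand-made tensors. The name grad_cam_map is hypothetical. In the toy input, channel 0 has an average gradient of +1 and channel 1 an average gradient of -1, so after the ReLU only the positions where channel 0 outweighs channel 1 survive.

```python
import torch


def grad_cam_map(acts: torch.Tensor, grads: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: [1, C, H, W] activations and gradients -> [H, W] map in [0, 1]."""
    weights = grads.mean(dim=(2, 3), keepdim=True)    # per-channel spatial mean, [1, C, 1, 1]
    cam = torch.relu((weights * acts).sum(dim=1))[0]  # weighted channel sum, then ReLU
    return cam / (cam.max() + eps)                    # scale so the maximum is (about) 1


# Toy example with 2 channels on a 2x2 grid
acts = torch.tensor([[[[1.0, 0.0], [0.0, 2.0]],       # channel 0
                      [[0.0, 1.0], [0.0, 1.0]]]])     # channel 1
grads = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]],      # channel 0: mean gradient +1
                       [[-1.0, -1.0], [-1.0, -1.0]]]])  # channel 1: mean gradient -1

print(grad_cam_map(acts, grads).tolist())  # [[1.0, 0.0], [0.0, 1.0]]
```

Working it by hand: the weighted sum is channel 0 minus channel 1, which is [[1, -1], [0, 1]]; the ReLU turns the -1 into 0, and the maximum is already 1.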

Step 6: Visualize the Heatmap

The final step is to overlay the computed heatmap on the original image. This visualization helps in understanding which regions of the image contributed most to the model’s decision.

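import cv2
import numpy as np

# Resize the heatmap to match the original image size (cv2.resize takes (width, height))
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

# Scale to 0-255 and apply a colormap; applyColorMap returns a BGR image
heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)

# Overlay the heatmap on the original image, converted back to BGR so the channels match
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
superimposed_img = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)

# Display the result (on a machine without a display, save it with cv2.imwrite instead)
cv2.imshow('Grad-CAM', superimposed_img)
cv2.waitKey(0)
cv2.destroyAllWindows()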
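cv2.addWeighted(src1, alpha, src2, beta, gamma) computes src1 * alpha + src2 * beta + gamma for every element and saturates the result to the 8-bit range. The NumPy sketch below reproduces that blend without OpenCV, which can help when tuning the overlay weights or testing without a display. The helper name blend is hypothetical, and OpenCV's internal rounding may differ from this sketch by one intensity level in edge cases.

```python
import numpy as np


def blend(img: np.ndarray, heatmap: np.ndarray, alpha: float = 0.6, beta: float = 0.4) -> np.ndarray:
    """Hypothetical helper mirroring addWeighted with gamma = 0 on uint8 images."""
    if img.shape != heatmap.shape:
        raise ValueError("image and heatmap must have the same shape")
    mixed = img.astype(np.float32) * alpha + heatmap.astype(np.float32) * beta
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)  # round, then saturate to 0..255


img = np.full((2, 2, 3), 100, dtype=np.uint8)
heat = np.full((2, 2, 3), 200, dtype=np.uint8)
print(blend(img, heat)[0, 0])  # [140 140 140]
```

With alpha = 0.6 and beta = 0.4 the weights sum to 1, so the overlay stays in the original brightness range; raising beta makes the heatmap dominate.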

By following these steps, you can effectively implement Grad-CAM in PyTorch to visualize and interpret the decision-making process of convolutional neural networks.
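For repeated use, the hooks, forward pass, backward pass, and heatmap math from the steps above can be wrapped in one helper. This is a minimal sketch, not a library API: the GradCAM class and its methods are hypothetical, it assumes one image per batch, and it is demonstrated on a small randomly initialized network instead of a pretrained ResNet so that it runs without downloading weights.

```python
import torch
import torch.nn as nn


class GradCAM:
    """Hypothetical helper: captures one layer's activations/gradients and builds the map."""

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model.eval()
        self.acts = None
        self.grads = None
        self.handles = [
            target_layer.register_forward_hook(self._save_acts),
            target_layer.register_full_backward_hook(self._save_grads),
        ]

    def _save_acts(self, module, inputs, output):
        self.acts = output.detach()

    def _save_grads(self, module, grad_input, grad_output):
        self.grads = grad_output[0].detach()

    def __call__(self, x: torch.Tensor, class_idx=None) -> torch.Tensor:
        self.model.zero_grad()
        output = self.model(x)                       # forward hook records activations
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()  # default: explain the top prediction
        output[0, class_idx].backward()              # backward hook records gradients
        weights = self.grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.acts).sum(dim=1))[0]
        return cam / (cam.max() + 1e-8)              # [H, W] map scaled to [0, 1]

    def remove(self):
        for h in self.handles:
            h.remove()


torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4),
)
cam = GradCAM(net, target_layer=net[2])  # hook the last conv layer
heatmap = cam(torch.randn(1, 3, 16, 16))
cam.remove()
print(heatmap.shape)  # torch.Size([16, 16])
```

With a real model you would pass model and model.layer4[-1] instead, then resize and overlay the returned map as in Step 6.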


Grad-CAM is a powerful tool for visualizing and understanding the decisions of deep learning models. By providing insights into which parts of an image were most influential in a model’s prediction, Grad-CAM enhances model interpretability, trust, and transparency.

As a leading AI & ML software development company, CodeTrade leverages such advanced techniques to deliver robust and explainable AI solutions.
