Have you ever looked at a medical AI demo and wondered, "Sure, it says it's a rash, but what exactly is the model looking at?" In the world of Computer Vision and Medical AI, the "Black Box" problem isn't just a technical hurdle—it's a trust hurdle.
When building a skin lesion screening tool, a simple percentage score isn't enough. To make a truly useful tool, we need Explainable AI (XAI). Today, we’re going to build a deep learning pipeline using PyTorch and FastAI to classify skin conditions, while implementing Grad-CAM to generate heatmaps that highlight exactly which features (texture, color, or borders) influenced the model's decision.
Whether you're interested in Deep Learning, health-tech, or just want to see how to make your models more transparent, this guide will walk you through the full engineering implementation.
The Architecture
Our system follows a classic client-server model but with an added "Explainability Layer." We use a fine-tuned ResNet-50 for the heavy lifting and a React Native frontend for the user experience.
graph TD
A[User Takes Photo] --> B(React Native App)
B --> C{FastAPI Backend}
C --> D[ResNet-50 Classifier]
D --> E[Inference Result]
D --> F[Grad-CAM Hook]
F --> G[Heatmap Generation]
G --> H[Overlay Image]
H --> I[Result + Visualization]
E --> I
I --> B
Prerequisites
To follow along, you'll need:
- FastAI / PyTorch: For model training and fine-tuning.
- Grad-CAM: To extract gradients from the final convolutional layer.
- React Native: For the cross-platform mobile interface.
- A dataset like HAM10000 ("Human Against Machine with 10000 training images").
Step 1: Fine-Tuning ResNet-50 with FastAI
FastAI makes "Transfer Learning" incredibly efficient. We’ll start with a pre-trained ResNet-50 and fine-tune it on skin lesion images.
from fastai.vision.all import *
# Load data - assuming images are organized in folders by label
path = Path('./skin_lesion_data')
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2, item_tfms=Resize(224))
# Initialize Learner with ResNet-50
learn = vision_learner(dls, resnet50, metrics=accuracy)
# Fine-tune: one epoch with the body frozen, then 5 epochs with all layers unfrozen
learn.fine_tune(5, base_lr=3e-3)
# Export for production
learn.export('skin_classifier.pkl')
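Once exported, the model can be reloaded anywhere with load_learner for a quick sanity check. Here's a minimal sketch, assuming the pickle path from the export above (the test image filename is just a placeholder):

from fastai.vision.all import load_learner, PILImage
# Reload the exported learner and classify a single image
learn = load_learner('skin_classifier.pkl')
img = PILImage.create('test_lesion.jpg')  # placeholder test image
pred_class, pred_idx, probs = learn.predict(img)
print(f"{pred_class}: {probs[pred_idx]:.2%}")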
Step 2: Implementing Grad-CAM for Explainability
This is where the magic happens. Grad-CAM (Gradient-weighted Class Activation Mapping) uses the gradients of any target concept flowing into the final convolutional layer to produce a localization map.
import numpy as np
import torch

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Register hooks to capture the layer's output and its gradient
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def generate_heatmap(self, input_tensor, category_idx):
        # Forward pass (triggers the forward hook)
        output = self.model(input_tensor)
        self.model.zero_grad()
        # Backward pass for the specific category (triggers the backward hook)
        score = output[0, category_idx]
        score.backward()
        # Global-average-pool the gradients to get one weight per channel
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        heatmap = torch.sum(weights * self.activations, dim=1).squeeze()
        # ReLU (keep only positive influence) and normalize to [0, 1]
        heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
        heatmap /= (np.max(heatmap) + 1e-8)  # epsilon guards against division by zero
        return heatmap
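To wire this up to the model from Step 1, hook the last convolutional stage of the ResNet-50 body. A sketch, assuming vision_learner's default layout (body at model[0], head at model[1]) and OpenCV for the colorized overlay; the category index here is just an example:

import cv2
import numpy as np

model = learn.model.eval()
cam = GradCAM(model, target_layer=model[0][-1])  # last conv stage of the body

x, _ = dls.valid.one_batch()
heatmap = cam.generate_heatmap(x[:1], category_idx=0)  # coarse 7x7 map for one image

# Upsample to the input resolution and colorize for blending over the photo
heatmap = cv2.resize(heatmap, (224, 224))
overlay = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)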
Note: Integrating this into your API allows you to return both a diagnosis (e.g., "Melanocytic nevi") and a visual heatmap image.
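Here's a minimal FastAPI sketch of that idea, reusing learn, dls, and the cam instance from above. The route name, preprocessing, and base64 transport are illustrative choices, not a fixed contract, and fastai's ImageNet normalization is omitted for brevity:

import base64, io
import numpy as np
from fastapi import FastAPI, UploadFile
from PIL import Image
import torchvision.transforms as T

app = FastAPI()
preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])  # normalization omitted

@app.post("/classify")
async def classify(file: UploadFile):
    # Assumes CPU inference; move tensors to the model's device otherwise
    img = Image.open(io.BytesIO(await file.read())).convert("RGB")
    x = preprocess(img).unsqueeze(0)
    pred_idx = learn.model(x).argmax(dim=1).item()
    heatmap = cam.generate_heatmap(x, pred_idx)
    # Encode the upsampled heatmap as a base64 PNG for the mobile client
    png = Image.fromarray(np.uint8(255 * heatmap)).resize((224, 224))
    buf = io.BytesIO()
    png.save(buf, format="PNG")
    return {"prediction": str(dls.vocab[pred_idx]),
            "heatmap": base64.b64encode(buf.getvalue()).decode()}

On the client, the returned string can be displayed directly by prefixing it with data:image/png;base64, and using that as the heatmapUri.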
The "Official" Way to Scale
While building a prototype is a great start, deploying medical-grade AI requires rigorous testing, data privacy (HIPAA compliance), and robust MLOps.
For more production-ready examples, advanced computer vision patterns, and deep dives into AI safety, I highly recommend checking out the technical resources at WellAlly Tech Blog. They cover the nuances of taking these "Learning in Public" projects into full-scale healthcare ecosystems.
Step 3: Frontend Visualization (React Native)
On the mobile side, we want to show the original image and toggle the heatmap overlay so the user (or clinician) can see the focal points.
import React, { useState } from 'react';
import { View, Image, Button, Text } from 'react-native';

const ResultScreen = ({ route }) => {
  const { originalUri, heatmapUri, prediction } = route.params;
  const [showHeatmap, setShowHeatmap] = useState(false);

  return (
    <View style={{ flex: 1, alignItems: 'center' }}>
      <Text style={{ fontSize: 24, margin: 20 }}>Result: {prediction}</Text>
      <View style={{ width: 300, height: 300 }}>
        {/* The base image sizes the container; the heatmap is absolutely
            positioned on top so the layout doesn't collapse when hidden */}
        <Image
          source={{ uri: originalUri }}
          style={{ width: 300, height: 300 }}
        />
        {showHeatmap && (
          <Image
            source={{ uri: heatmapUri }}
            style={{ width: 300, height: 300, position: 'absolute', opacity: 0.5 }}
          />
        )}
      </View>
      <Button
        title={showHeatmap ? "Hide Heatmap" : "Explain Logic (Grad-CAM)"}
        onPress={() => setShowHeatmap(!showHeatmap)}
      />
    </View>
  );
};

export default ResultScreen;
Conclusion: Bridging the Gap
By combining FastAI for rapid development and Grad-CAM for transparency, we transform a simple classifier into a powerful diagnostic aid. This setup doesn't just provide an answer; it provides a reason.
Key Takeaways:
- Transfer Learning (ResNet-50) saves weeks of training time.
- Interpretability is non-negotiable in high-stakes fields like medicine.
- Hybrid Stacks (Python backend + React Native frontend) provide the best developer experience for AI products.
What are your thoughts on AI interpretability? Would you trust an AI more if it showed you its "thought process" via heatmaps? Let’s discuss in the comments below!
If you enjoyed this tutorial, don't forget to ❤️ and 🔖! For more advanced AI implementation strategies, visit the *WellAlly Blog*.