DEV Community

Beck_Moulton

More Than Just Labels: Building a Skin Lesion Classifier with ResNet-50 and Explainable AI (Grad-CAM)

Have you ever looked at a medical AI demo and wondered, "Sure, it says it's a rash, but what exactly is the model looking at?" In the world of Computer Vision and Medical AI, the "Black Box" problem isn't just a technical hurdle—it's a trust hurdle.

When building a skin lesion screening tool, a simple percentage score isn't enough. To make a truly useful tool, we need Explainable AI (XAI). Today, we’re going to build a deep learning pipeline using PyTorch and FastAI to classify skin conditions, while implementing Grad-CAM to generate heatmaps that highlight exactly which features (texture, color, or borders) influenced the model's decision.

Whether you're interested in Deep Learning, health-tech, or just want to see how to make your models more transparent, this guide will walk you through the full engineering implementation.


The Architecture

Our system follows a classic client-server model but with an added "Explainability Layer." We use a fine-tuned ResNet-50 for the heavy lifting and a React Native frontend for the user experience.

graph TD
    A[User Takes Photo] --> B(React Native App)
    B --> C{FastAPI Backend}
    C --> D[ResNet-50 Classifier]
    D --> E[Inference Result]
    D --> F[Grad-CAM Hook]
    F --> G[Heatmap Generation]
    G --> H[Overlay Image]
    H --> I[Result + Visualization]
    E --> I
    I --> B

Prerequisites

To follow along, you'll need:

  • FastAI / PyTorch: For model training and fine-tuning.
  • Grad-CAM: To extract gradients from the final convolutional layer.
  • React Native: For the cross-platform mobile interface.
  • A dataset like HAM10000 (Human-Against-Machine with 10,000 training images).

Step 1: Fine-Tuning ResNet-50 with FastAI

FastAI makes "Transfer Learning" incredibly efficient. We’ll start with a pre-trained ResNet-50 and fine-tune it on skin lesion images.

from fastai.vision.all import *

# Load data - assuming images are organized in folders by label
path = Path('./skin_lesion_data')
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2, item_tfms=Resize(224))

# Initialize Learner with ResNet-50
learn = vision_learner(dls, resnet50, metrics=accuracy)

# Fine-tune: one frozen epoch to train the new head, then 5 epochs on all layers
learn.fine_tune(5, base_lr=3e-3)

# Export for production
learn.export('skin_classifier.pkl')
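One detail worth knowing before serving the exported model: `vision_learner` normalizes inputs with ImageNet statistics, so if your API ever feeds raw pixels to the network directly, it must apply the same transform. A minimal numpy sketch of that preprocessing (the `preprocess` helper is illustrative, not part of fastai):

```python
import numpy as np

# ImageNet channel statistics used by fastai's pretrained ResNet-50
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(img_uint8: np.ndarray) -> np.ndarray:
    """Scale an HxWx3 uint8 image to [0, 1], normalize with ImageNet
    statistics, and move channels first (CHW) as PyTorch expects."""
    x = img_uint8.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD
    return x.transpose(2, 0, 1)  # HWC -> CHW

# Example: a dummy 224x224 RGB image
dummy = np.full((224, 224, 3), 128, dtype=np.uint8)
tensor = preprocess(dummy)
print(tensor.shape)  # (3, 224, 224)
```

When you stay inside fastai (`load_learner` + `predict`), this is handled for you; the sketch only matters for custom inference paths.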

Step 2: Implementing Grad-CAM for Explainability

This is where the magic happens. Grad-CAM (Gradient-weighted Class Activation Mapping) uses the gradients of any target concept flowing into the final convolutional layer to produce a localization map.

import numpy as np
import torch

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def generate_heatmap(self, input_tensor, category_idx):
        # Forward pass
        output = self.model(input_tensor)
        self.model.zero_grad()

        # Backward pass on the raw score for the target category
        score = output[0, category_idx]
        score.backward()

        # Weight the activations by the gradients
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        heatmap = torch.sum(weights * self.activations, dim=1).squeeze()

        # ReLU, then normalize to [0, 1] (epsilon guards an all-zero map)
        heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
        heatmap /= (np.max(heatmap) + 1e-8)
        return heatmap
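To actually display the map, you upsample it to the input resolution and alpha-blend it over the original image. Production code would typically use `cv2.resize` and `cv2.applyColorMap`; here is a dependency-light numpy sketch (the `overlay_heatmap` helper and the single-channel "colormap" are simplifications for illustration):

```python
import numpy as np

def overlay_heatmap(image: np.ndarray, heatmap: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a normalized (0..1) Grad-CAM heatmap onto an RGB image.
    The small (e.g. 7x7) map is upsampled by block repetition; real code
    would use bilinear interpolation for a smoother overlay."""
    h, w = image.shape[:2]
    scale_h, scale_w = h // heatmap.shape[0], w // heatmap.shape[1]
    big = np.kron(heatmap, np.ones((scale_h, scale_w)))
    # Map intensity onto the red channel as a minimal "colormap"
    color = np.zeros_like(image, dtype=np.float32)
    color[..., 0] = big * 255.0
    blended = (1 - alpha) * image.astype(np.float32) + alpha * color
    return blended.astype(np.uint8)

img = np.zeros((224, 224, 3), dtype=np.uint8)
cam = np.random.rand(7, 7)  # stand-in for a GradCAM output
out = overlay_heatmap(img, cam)
print(out.shape)  # (224, 224, 3)
```

For ResNet-50, a natural `target_layer` to hook is the last bottleneck block (e.g. `model.layer4[-1]` on a plain torchvision ResNet), since that is the final convolutional stage before pooling.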

Note: Integrating this into your API allows you to return both a diagnosis (e.g., "Melanocytic nevi") and a visual heatmap image.
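One way to make that note concrete is a JSON response carrying the label plus the heatmap PNG as a data URI, using the same field names the React Native screen destructures (`prediction`, `heatmapUri`). A stdlib-only sketch, where `build_response` and the payload shape are assumptions rather than a fixed API:

```python
import base64
import json

def build_response(label: str, confidence: float, heatmap_png: bytes) -> str:
    """Bundle the diagnosis and the rendered heatmap into one JSON
    payload; the mobile client can show the image via a data URI."""
    payload = {
        "prediction": label,
        "confidence": round(confidence, 4),
        "heatmapUri": "data:image/png;base64,"
        + base64.b64encode(heatmap_png).decode("ascii"),
    }
    return json.dumps(payload)

# Round-trip check with a stub PNG byte string
resp = json.loads(build_response("Melanocytic nevi", 0.9231, b"\x89PNG"))
print(resp["prediction"])  # Melanocytic nevi
```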


The "Official" Way to Scale

While building a prototype is a great start, deploying medical-grade AI requires rigorous testing, data privacy (HIPAA compliance), and robust MLOps.

For more production-ready examples, advanced computer vision patterns, and deep dives into AI safety, I highly recommend checking out the technical resources at WellAlly Tech Blog. They cover the nuances of taking these "Learning in Public" projects into full-scale healthcare ecosystems.


Step 3: Frontend Visualization (React Native)

On the mobile side, we want to show the original image and toggle the heatmap overlay so the user (or clinician) can see the focal points.

import React, { useState } from 'react';
import { View, Image, Button, Text } from 'react-native';

const ResultScreen = ({ route }) => {
  const { originalUri, heatmapUri, prediction } = route.params;
  const [showHeatmap, setShowHeatmap] = useState(false);

  return (
    <View style={{ flex: 1, alignItems: 'center' }}>
      <Text style={{ fontSize: 24, margin: 20 }}>Result: {prediction}</Text>

      <View style={{ width: 300, height: 300 }}>
        <Image 
          source={{ uri: originalUri }} 
          style={{ width: 300, height: 300 }} 
        />
        {showHeatmap && (
          <Image 
            source={{ uri: heatmapUri }} 
            style={{ width: 300, height: 300, position: 'absolute', opacity: 0.5 }} 
          />
        )}
      </View>

      <Button 
        title={showHeatmap ? "Hide Heatmap" : "Explain Logic (Grad-CAM)"} 
        onPress={() => setShowHeatmap(!showHeatmap)} 
      />
    </View>
  );
};

Conclusion: Bridging the Gap

By combining FastAI for rapid development and Grad-CAM for transparency, we transform a simple classifier into a powerful diagnostic aid. This setup doesn't just provide an answer; it provides a reason.

Key Takeaways:

  1. Transfer Learning (ResNet-50) saves weeks of training time.
  2. Interpretability is non-negotiable in high-stakes fields like medicine.
  3. Hybrid Stacks (Python backend + React Native frontend) provide the best developer experience for AI products.

What are your thoughts on AI interpretability? Would you trust an AI more if it showed you its "thought process" via heatmaps? Let’s discuss in the comments below!


If you enjoyed this tutorial, don't forget to ❤️ and 🔖! For more advanced AI implementation strategies, visit the *WellAlly Blog*.
