DEV Community

wellallyTech

From Pixels to Diagnosis: Building a Lightning-Fast Skin Lesion Classifier with MobileNetV3 and ONNX Runtime

In an era where privacy and latency are among the biggest obstacles to AI adoption, on-device machine learning is emerging as a game-changer for healthcare. Imagine performing skin lesion classification, flagging potential issues like melanoma, directly on a smartphone without ever sending a single pixel to the cloud. By combining MobileNetV3, ONNX Runtime, and Flutter, we can create a high-performance screening tool that works offline and in real time. 🚀

This tutorial dives deep into the engineering pipeline of fine-tuning a lightweight computer vision model and deploying it to a mobile environment. We’ll focus on the synergy between on-device inference and high-accuracy diagnostic models. If you are looking for production-grade insights on edge AI, the experts over at WellAlly Tech Blog provide fantastic deep-dives into scaling these types of medical-grade architectures. 🩺

🏗️ The System Architecture

Before we touch the code, let’s look at the data flow. We fine-tune a MobileNetV3 model in PyTorch, export it to the portable ONNX format, quantize it to shrink the file, and then run it from native C++ code (built with the Android NDK) behind a Flutter UI.

graph TD
    A[Raw Dataset: HAM10000] --> B[PyTorch Fine-tuning: MobileNetV3]
    B --> C[ONNX Export & Quantization]
    C --> D[Mobile Deployment]
    subgraph "On-Device Inference"
    D --> E[Flutter UI - Camera Stream]
    E --> F[Android NDK / C++ Layer]
    F --> G[ONNX Runtime Engine]
    G --> H[Result: Lesion Type + Confidence]
    end
    H --> E

🛠️ Prerequisites

To follow along, you'll need:

  • Python 3.10+ (PyTorch, torchvision, ONNX, ONNX Runtime, Optimum)
  • Flutter SDK
  • Android NDK (for low-latency C++ bindings)
  • A dataset of skin lesion images, such as HAM10000 (Human Against Machine with 10000 training images).

Step 1: Fine-tuning MobileNetV3 with PyTorch

MobileNetV3 is designed specifically for mobile CPUs. Its architecture was found with platform-aware Neural Architecture Search (NAS), complemented by the NetAdapt algorithm, to balance accuracy against latency.

import torch
import torch.nn as nn
from torchvision import models

def get_skin_model(num_classes=7):
    # We use the 'small' version for maximum speed on mobile
    model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')

    # Freeze the backbone for initial training
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head for skin lesion categories
    # (e.g., Melanoma, Basal cell carcinoma, etc.).
    # Layers created here are new, so they stay trainable while the backbone is frozen.
    last_channel = model.classifier[0].in_features
    model.classifier = nn.Sequential(
        nn.Linear(last_channel, 1024),
        nn.Hardswish(inplace=True),
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(1024, num_classes)
    )
    return model

model = get_skin_model()
print("MobileNetV3 ready for fine-tuning! 🚀")

Step 2: Exporting to ONNX and Quantization

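Once trained, we don't want to ship a bulky .pth file. We export the model to ONNX (Open Neural Network Exchange) and apply INT8 quantization: storing weights as 1-byte integers instead of 4-byte floats can shrink the quantized layers by up to ~75%, usually with a small accuracy loss that is worth measuring on a validation set. Dynamic quantization is used below for simplicity; ONNX Runtime's documentation generally recommends static quantization for CNNs.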

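import torch
import torch.onnx

# Switch to inference mode so Dropout is disabled and BatchNorm uses its running statistics
model.eval()

# Dummy input matching the preprocessed input size (1 x 3 x 224 x 224)
dummy_input = torch.randn(1, 3, 224, 224)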

torch.onnx.export(
    model, 
    dummy_input, 
    "skin_classifier.onnx",
    export_params=True,
    opset_version=12,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

# Pro Tip: Use ONNX Runtime tools to quantize
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("skin_classifier.onnx", "skin_classifier_quant.onnx", weight_type=QuantType.QUInt8)
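
To see where the "up to ~75%" figure comes from, here is a minimal pure-Python sketch of affine 8-bit quantization. It illustrates the general idea only and is not ONNX Runtime's exact implementation; the helper names `quantize_uint8` and `dequantize` and the sample weights are made up for this example.

```python
def quantize_uint8(weights):
    """Affine (asymmetric) quantization of floats to the uint8 range [0, 255]."""
    lo, hi = min(weights), max(weights)
    # Make sure the range contains zero so that 0.0 maps to an exact integer.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for an all-zero tensor
    zero_point = round(-lo / scale)
    q = [max(0, min(255, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map the stored integers back to approximate float values."""
    return [(v - zero_point) * scale for v in q]


# Hypothetical weights, chosen only for illustration.
weights = [-0.42, -0.1, 0.0, 0.07, 0.31, 0.85]
q, scale, zp = quantize_uint8(weights)
restored = dequantize(q, scale, zp)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

fp32_bytes = 4 * len(weights)  # float32: 4 bytes per weight
int8_bytes = 1 * len(weights)  # uint8: 1 byte per weight
saving = 1 - int8_bytes / fp32_bytes  # ignores the small scale/zero-point overhead

print(f"max round-trip error: {max_error:.5f}, size saving: {saving:.0%}")
```

Each weight is stored in one byte instead of four, which is where the 75% comes from. The real saving on a whole model is lower when some layers are left in float32.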

Step 3: Bridging to Flutter via Android NDK

While Flutter handles the UI, we use the Android NDK and C++ to call ONNX Runtime directly. Keeping per-frame tensor work in native code helps avoid stalls from the Dart VM's garbage collector when processing 30 frames per second.

The C++ Interface (Simplified)

#include <onnxruntime_cxx_api.h>

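// Initialize the session once. model_path, input_node_names, and output_node_names
// are assumed to be defined elsewhere in the native layer.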
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "SkinClassifier");
Ort::SessionOptions session_options;
Ort::Session session(env, model_path, session_options);

// Run Inference
void run_inference(float* input_data) {
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data, ...);

    auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names, &input_tensor, 1, output_node_names, 1);
    // Process results...
}
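
The `input_data` pointer has to reference a tensor laid out the way the model saw images during training. Here is a minimal Python sketch of that conversion (the same loop ports directly to C++). It assumes the camera frame has already been converted to interleaved RGB and resized to the model's input size, and that training used the standard ImageNet mean and standard deviation that normally accompany the pretrained weights. The function name `rgb_bytes_to_chw` is hypothetical.

```python
# Standard ImageNet normalization constants (assumed to match training).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def rgb_bytes_to_chw(buf, height, width):
    """Convert interleaved RGB bytes (HWC) into a flat, normalized CHW float list."""
    if len(buf) != height * width * 3:
        raise ValueError("buffer size does not match height * width * 3")
    plane = height * width
    out = [0.0] * (3 * plane)
    for i in range(plane):
        for c in range(3):
            value = buf[i * 3 + c] / 255.0  # scale to [0, 1]
            out[c * plane + i] = (value - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
    return out


# Tiny 2x2 test frame where every pixel is (R=255, G=0, B=128).
frame = bytes([255, 0, 128] * 4)
tensor = rgb_bytes_to_chw(frame, 2, 2)
print(len(tensor), tensor[0], tensor[4], tensor[8])
```

For the real model you would call it with `height=224, width=224`, producing the 3 x 224 x 224 values that the exported graph expects for a batch of one.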

Step 4: The Flutter UI Integration 🥑

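In Flutter, we can pass the camera image buffer to the C++ layer either through a MethodChannel (routed through the Android platform code) or through dart:ffi (Foreign Function Interface), which calls native code directly. The example below uses a MethodChannel.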

import 'dart:typed_data';
import 'package:flutter/services.dart';

// Dart side (inside a State class): passing the image buffer to native code
static const platform = MethodChannel('tech.wellally.skin/inference');

Future<void> analyzeImage(Uint8List bytes) async {
  try {
    final List result = await platform.invokeMethod('predict', {"data": bytes});
    setState(() {
      _prediction = result[0]; // e.g., "Melanocytic nevi"
      _confidence = result[1]; // e.g., 0.98
    });
  } on PlatformException catch (e) {
    print("Failed to run inference: ${e.message}");
  }
}
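
The label and confidence returned over the channel have to be derived from the model's raw output somewhere in the native layer. Here is a minimal Python sketch of that step, assuming the model emits one logit per class. The seven names are HAM10000's diagnostic categories, but their order here is an assumption and must match the label encoding used in training. `CLASS_NAMES`, `top_prediction`, and the sample logits are hypothetical.

```python
import math

# HAM10000's seven diagnostic categories.
# The ORDER is an assumption: it must match the training label encoding.
CLASS_NAMES = [
    "Actinic keratoses",
    "Basal cell carcinoma",
    "Benign keratosis-like lesions",
    "Dermatofibroma",
    "Melanoma",
    "Melanocytic nevi",
    "Vascular lesions",
]


def softmax(logits):
    """Turn raw logits into probabilities (shifted by the max for numerical stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits):
    """Return the most likely class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASS_NAMES[best], probs[best]


# Made-up logits, for illustration only.
label, confidence = top_prediction([0.1, -1.2, 0.3, -0.5, 1.0, 4.2, -2.0])
print(label, round(confidence, 3))
```

A softmax score is the model's relative preference among the seven classes, not a calibrated clinical probability, so it is worth presenting it with care in the UI.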

🌟 The "Official" Way to Build Medical AI

While this tutorial provides a solid foundation for a prototype, building production-ready medical screening tools requires rigorous attention to preprocessing pipelines, model explainability (Grad-CAM), and secure data handling.

For advanced architectural patterns and more production-ready examples of on-device vision, I highly recommend exploring the resources at https://www.wellally.tech/blog. They offer excellent insights into optimizing ML models for real-world constraints.


Conclusion

By moving our skin lesion screening logic from the cloud to the device using MobileNetV3 and ONNX Runtime, we’ve achieved:

  1. Privacy: No medical images leave the device.
  2. Speed: Inference runs locally, so there is no network round trip.
  3. Accessibility: The app works in remote areas without internet.

On-device AI is the future of personalized healthcare. What are you planning to build at the edge? Let me know in the comments! 👇
