Privacy and latency are two of the biggest obstacles to AI adoption in healthcare, and on-device machine learning addresses both. Imagine running skin lesion classification (flagging potential issues like melanoma) directly on a smartphone without ever sending a single pixel to the cloud. By combining MobileNetV3, ONNX Runtime, and Flutter, we can build a screening tool that works offline and in real time. 🚀
This tutorial dives deep into the engineering pipeline of fine-tuning a lightweight computer vision model and deploying it to a mobile environment. We’ll focus on the synergy between on-device inference and high-accuracy diagnostic models. If you are looking for production-grade insights on edge AI, the experts over at WellAlly Tech Blog provide fantastic deep-dives into scaling these types of medical-grade architectures. 🩺
🏗️ The System Architecture
Before we touch the code, let’s look at the data flow. We fine-tune a MobileNetV3 model in PyTorch, export it to the portable ONNX format, quantize it, and then run it with ONNX Runtime from a C++ layer built with the Android NDK, which the Flutter app calls into.
graph TD
    A[Raw Dataset: HAM10000] --> B[PyTorch Fine-tuning: MobileNetV3]
    B --> C[ONNX Export & Quantization]
    C --> D[Mobile Deployment]
    subgraph "On-Device Inference"
        D --> E[Flutter UI - Camera Stream]
        E --> F[Android NDK / C++ Layer]
        F --> G[ONNX Runtime Engine]
        G --> H[Result: Lesion Type + Confidence]
    end
    H --> E
🛠️ Prerequisites
To follow along, you'll need:
- Python 3.10+ (PyTorch, torchvision, ONNX, ONNX Runtime, Optimum)
- Flutter SDK
- Android NDK (for low-latency C++ bindings)
- A dataset like HAM10000 ("Human Against Machine with 10000 training images") of dermatoscopic skin lesion images.
Step 1: Fine-tuning MobileNetV3 with PyTorch
MobileNetV3 is designed for mobile CPUs. Its architecture was found with platform-aware neural architecture search (NAS), complemented by the NetAdapt algorithm, to balance accuracy against latency.
import torch
import torch.nn as nn
from torchvision import models
def get_skin_model(num_classes=7):
    # We use the 'small' version for maximum speed on mobile
    model = models.mobilenet_v3_small(weights='IMAGENET1K_V1')
    # Freeze the backbone for initial training
    for param in model.parameters():
        param.requires_grad = False
    # Replace the classifier head for skin lesion categories
    # (e.g., Melanoma, Basal cell carcinoma, etc.)
    last_channel = model.classifier[0].in_features
    model.classifier = nn.Sequential(
        nn.Linear(last_channel, 1024),
        nn.Hardswish(inplace=True),
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(1024, num_classes)
    )
    return model
model = get_skin_model()
print("MobileNetV3 ready for fine-tuning! 🚀")
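To get a feel for how little is actually trained while the backbone is frozen, here is a small arithmetic sketch that counts the parameters in the new head. It assumes torchvision's `mobilenet_v3_small`, whose classifier input has 576 features; if your build differs, check `model.classifier[0].in_features`. The helper names `linear_params` and `head_params` are made up for this illustration.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


def head_params(in_features: int = 576, hidden: int = 1024, num_classes: int = 7) -> int:
    """Trainable parameters in the replacement head.

    Hardswish and Dropout have no parameters, so only the two
    Linear layers contribute. The default of 576 is an assumption
    based on torchvision's mobilenet_v3_small.
    """
    return linear_params(in_features, hidden) + linear_params(hidden, num_classes)


print(head_params())  # 598023
```

With the defaults, the first layer contributes 590,848 parameters and the second 7,175, so the initial training phase only updates about 0.6M values.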
Step 2: Exporting to ONNX and Quantization
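Once trained, we don't want to ship a bulky .pth file. We export the model to ONNX (Open Neural Network Exchange) and apply INT8 quantization. Storing weights as 8-bit integers instead of 32-bit floats cuts the size of the quantized weights by up to ~75%. The accuracy loss is usually small, but it should be measured on a held-out validation set before shipping. ONNX Runtime's documentation generally recommends static quantization for CNNs, so treat the dynamic variant below as the simplest starting point.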
import torch.onnx
# Switch to inference mode so Dropout and BatchNorm behave deterministically in the exported graph
model.eval()
# Dummy input matching our camera resolution/preprocessing (224x224)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "skin_classifier.onnx",
    export_params=True,
    opset_version=12,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
# Pro Tip: Use ONNX Runtime tools to quantize
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("skin_classifier.onnx", "skin_classifier_quant.onnx", weight_type=QuantType.QUInt8)
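The ~75% figure comes from simple arithmetic: an FP32 weight takes 4 bytes and an INT8 weight takes 1. The sketch below makes that explicit and shows why 75% is an upper bound when only part of the model is quantized. It is a rough estimate only: it ignores scale/zero-point overhead and graph metadata, and the function names and the 1,500,000-parameter figure are hypothetical, not measurements of this model.

```python
def estimated_size_mb(num_params: int, quantized_fraction: float) -> float:
    """Approximate weight storage in MB.

    Quantized parameters cost 1 byte each, the rest stay at 4 bytes (FP32).
    """
    if not 0.0 <= quantized_fraction <= 1.0:
        raise ValueError("quantized_fraction must be between 0 and 1")
    quantized = num_params * quantized_fraction
    full_precision = num_params - quantized
    return (quantized * 1 + full_precision * 4) / 1e6


def size_reduction(quantized_fraction: float) -> float:
    """Fraction of the FP32 weight storage saved by quantization."""
    return 0.75 * quantized_fraction


# Hypothetical parameter count, for illustration only.
NUM_PARAMS = 1_500_000
for fraction in (0.0, 0.5, 1.0):
    size = estimated_size_mb(NUM_PARAMS, fraction)
    saved = size_reduction(fraction)
    print(f"{fraction:.0%} quantized: {size:.2f} MB, {saved:.1%} saved")
```

Comparing the two `.onnx` files on disk after running `quantize_dynamic` is the real check; this estimate just tells you what to expect.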
Step 3: Bridging to Flutter via Android NDK
While Flutter handles the UI, we use the Android NDK and C++ to call ONNX Runtime directly. Keeping per-frame pixel processing in native code helps avoid pressure on the Dart VM's garbage collector when handling a 30 frames-per-second camera stream.
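At 30 frames per second each frame has a budget of roughly 33 ms. If inference takes longer than that on a given device, one simple option is to classify only every Nth frame. The sketch below computes that stride; `frame_stride` is a hypothetical helper and the latencies in the usage lines are made-up numbers, not benchmarks.

```python
import math


def frame_stride(inference_ms: float, camera_fps: float = 30.0) -> int:
    """Return N such that classifying every Nth frame keeps up with the camera.

    A stride of 1 means every frame can be processed within its time budget.
    """
    if inference_ms <= 0 or camera_fps <= 0:
        raise ValueError("inference_ms and camera_fps must be positive")
    # Number of frame intervals that one inference call spans.
    frames_per_inference = inference_ms * camera_fps / 1000.0
    return max(1, math.ceil(frames_per_inference))


print(frame_stride(20))   # 1: fits inside the ~33 ms budget
print(frame_stride(50))   # 2: classify every second frame
print(frame_stride(100))  # 3: classify every third frame
```

The same calculation carries over to the native layer; the point is to measure inference time on the target device first and derive the stride from it.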
The C++ Interface (Simplified)
#include <onnxruntime_cxx_api.h>
// Initialize the session once and reuse it for every frame.
// Note: model_path, input_node_names, and output_node_names are not defined in this simplified snippet.
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "SkinClassifier");
Ort::SessionOptions session_options;
Ort::Session session(env, model_path, session_options);
// Run Inference
void run_inference(float* input_data) {
    // Wrap the caller's buffer in a tensor without copying it
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_data, ...);
    // One input tensor in, one output tensor out
    auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names, &input_tensor, 1, output_node_names, 1);
    // Process results...
}
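The `input_data` buffer handed to `run_inference` has to match the 1x3x224x224 float layout used at export time, while camera frames usually arrive as interleaved RGB bytes. Here is a minimal reference sketch of that conversion in Python, which can be handy for checking the native implementation against known values. It assumes the model was trained with the standard ImageNet mean/std normalization used by torchvision's pretrained weights; the source pipeline does not state this, so adjust it to whatever your training transforms did. The function name is hypothetical.

```python
from typing import List, Sequence

# Assumption: standard ImageNet normalization constants.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def rgb_hwc_to_nchw(pixels: Sequence[int], width: int, height: int) -> List[float]:
    """Convert interleaved RGB bytes (HWC) into normalized planar floats (CHW).

    The result is a flat list of 3 * width * height floats: all red values
    first, then all green, then all blue.
    """
    if len(pixels) != width * height * 3:
        raise ValueError("expected width * height * 3 bytes")
    plane = width * height
    out = [0.0] * (3 * plane)
    for i in range(plane):
        for c in range(3):
            value = pixels[i * 3 + c] / 255.0  # scale to [0, 1]
            out[c * plane + i] = (value - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
    return out


# A 2x1 image: one pure red pixel, one pure green pixel.
tensor = rgb_hwc_to_nchw(bytes([255, 0, 0, 0, 255, 0]), width=2, height=1)
print(len(tensor))  # 6
```

In the real app the frame would first be resized or cropped to 224x224; that step is omitted here.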
Step 4: The Flutter UI Integration 🥑
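In Flutter, we pass the camera image buffer to the native side through either a MethodChannel (which lands in the platform's Kotlin/Java code and reaches C++ from there via JNI) or dart:ffi (Foreign Function Interface), which calls the C++ functions directly.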
// Dart side: Passing image buffer to Native
static const platform = MethodChannel('tech.wellally.skin/inference');
Future<void> analyzeImage(Uint8List bytes) async {
  try {
    final List result = await platform.invokeMethod('predict', {"data": bytes});
    setState(() {
      _prediction = result[0]; // e.g., "Melanocytic nevi"
      _confidence = result[1]; // e.g., 0.98
    });
  } on PlatformException catch (e) {
    print("Failed to run inference: ${e.message}");
  }
}
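The `predict` call returns a label and a confidence, but the model itself only emits seven raw logits, so something on the native side has to turn those into a `(label, confidence)` pair. A minimal sketch of that step, written in Python for readability: the label order below uses the HAM10000 diagnosis abbreviations but is an assumption, and it must match the class indices used during training. `top_prediction` is a hypothetical helper name.

```python
import math
from typing import List, Sequence, Tuple

# Assumed label order; it must match the training label encoding.
LABELS = ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits: Sequence[float],
                   labels: Sequence[str] = LABELS) -> Tuple[str, float]:
    """Return the most likely label and its softmax probability."""
    if len(logits) != len(labels):
        raise ValueError("one logit per label is required")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


label, confidence = top_prediction([0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0])
print(label, round(confidence, 2))
```

Keep in mind that a softmax score is not a calibrated clinical probability, so a value like 0.98 should be presented to users with care.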
🌟 The "Official" Way to Build Medical AI
While this tutorial provides a solid foundation for a prototype, building production-ready medical screening tools requires rigorous attention to preprocessing pipelines, model explainability (Grad-CAM), and secure data handling.
For advanced architectural patterns and more production-ready examples of on-device vision, I highly recommend exploring the resources at https://www.wellally.tech/blog. They offer excellent insights into optimizing ML models for real-world constraints.
Conclusion
By moving our skin lesion screening logic from the cloud to the device using MobileNetV3 and ONNX Runtime, we’ve achieved:
- Privacy: No medical images leave the device.
- Speed: Inference runs locally with no network round trip (actual latency depends on the device, so measure it on your target hardware).
- Accessibility: The app works in remote areas without internet.
On-device AI is the future of personalized healthcare. What are you planning to build at the edge? Let me know in the comments! 👇