---
title: "WebGPU Compute Shaders: On-Device LLM Inference Beyond NNAPI"
published: true
description: "Build a hybrid architecture using WebGPU compute shaders in Android WebViews for GPU-accelerated LLM inference that bypasses NNAPI limitations."
tags: android, kotlin, architecture, mobile
canonical_url: https://blog.mvpfactory.co/webgpu-compute-shaders-on-device-llm-inference-beyond-nnapi
---
## What We Will Build
In this tutorial, I'll walk you through a hybrid on-device LLM inference pipeline where WebGPU compute shaders handle attention-layer matrix multiplications via Android WebView, while CPU threads manage non-matmul operations. By the end, you'll have a working split architecture, a tuned WGSL compute shader for quantized GEMM, and a strategy for minimizing bridge overhead.
## Prerequisites
- Android 12+ with a Chrome 121+ WebView (the first Chrome release to ship WebGPU on Android)
- Kotlin project targeting a recent `compileSdk`
- A quantized LLM in the 1–4B parameter range (INT4)
- Familiarity with Android `WebView` and coroutines
## Step 1: Understand Why NNAPI Falls Short
Before writing code, let me show you the problem. NNAPI delegates to the best accelerator on paper — GPU, DSP, NPU. In practice, you hit three walls:
1. **Operator coverage gaps.** Custom or fused ops silently fall back to CPU.
2. **Vendor-specific bugs.** Identical models produce different results on Qualcomm vs. MediaTek vs. Samsung Exynos.
3. **Quantization inconsistencies.** INT8/INT4 support varies wildly across HAL implementations.
For transformer attention layers — batched GEMM, softmax, layer normalization — NNAPI's coverage is incomplete on most shipping devices. WebGPU gives you a standardized GPU compute interface updated via the Play Store, no vendor HAL required.
| Factor | NNAPI | WebGPU via WebView |
|---|---|---|
| GPU access | Via vendor HAL | Direct via standardized API |
| Operator coverage | Vendor-dependent, partial | You write the shaders, full control |
| Quantization support | INT8 on some, INT4 rare | Custom, implement what you need |
| Update mechanism | OS/firmware update | Play Store WebView update |
| Debugging | Opaque vendor stack | Chrome DevTools, shader logging |
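Because availability varies by device and WebView version, it's worth feature-detecting before committing to this path. A minimal sketch, with a helper name of my own: the JS expression reports whether `navigator.gpu` exists, and the Kotlin side parses the JSON-quoted string that `evaluateJavascript` hands back.

```kotlin
// Parses the quoted string evaluateJavascript returns for the expression
// "String('gpu' in navigator)": "true" means the WebGPU API is exposed.
fun hasWebGpu(jsResult: String?): Boolean =
    jsResult?.trim('"') == "true"
```

On the Android side: `webView.evaluateJavascript("String('gpu' in navigator)") { result -> if (hasWebGpu(result)) enableWebGpuPath() }`. Note this only confirms the API is present; request an adapter on the JS side before committing, since `requestAdapter()` can still return null on unsupported GPUs.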
## Step 2: Split the Pipeline
Here is the pattern I use in every project — don't run the entire LLM pipeline in WebGPU. Split at the GEMM boundary.
**WebGPU handles:** QKV projections, attention score computation, feed-forward GEMM — dense matrix multiplies on quantized weights.
**CPU threads handle:** tokenization, embedding lookups, layer norm, residual connections, sampling — memory-bound or sequential ops.
```kotlin
class HybridLLMEngine(private val webView: WebView) {
    suspend fun generateToken(inputIds: IntArray): Int {
        // CPU: embedding lookup is memory-bound, no GPU win
        val embeddings = cpuEmbeddingLookup(inputIds)
        // GPU: dense matmuls run inside the WebView's WebGPU context;
        // evaluateJavascriptSuspend and toJSArrayBuffer are app-side helpers
        val hiddenState = webView.evaluateJavascriptSuspend(
            "runTransformerBlock(${embeddings.toJSArrayBuffer()})"
        )
        // CPU: sampling over logits is cheap and sequential
        return cpuSampleFromLogits(hiddenState)
    }
}
```
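The `toJSArrayBuffer()` helper above isn't shown; here is a minimal sketch of one way to implement it (the name `floatsToBase64` and the encoding scheme are my assumptions): pack the `FloatArray` into little-endian bytes and Base64-encode it, so the JS side can decode it into an `ArrayBuffer`.

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.Base64

// Packs floats as little-endian bytes, then Base64-encodes them so the
// payload can be embedded in a JS call string and decoded GPU-side.
fun floatsToBase64(values: FloatArray): String {
    val buf = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
    values.forEach { buf.putFloat(it) }
    return Base64.getEncoder().encodeToString(buf.array())
}
```

Base64 adds ~33% size overhead per crossing; for large tensors it's worth checking whether your `androidx.webkit` version supports `ArrayBuffer` message payloads via `WebMessageCompat`, which avoids string encoding entirely.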
## Step 3: Write the Compute Shader
Here is the minimal setup to get this working — a WGSL compute shader for quantized INT4 × FP16 matrix multiplication:
```wgsl
// Bindings (names assumed): packed INT4 weights, activations, output.
@group(0) @binding(0) var<storage, read> weights: array<u32>;
@group(0) @binding(1) var<storage, read> activations: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

// Matrix dims, specialized at pipeline creation. True f16 activations
// need the shader-f16 feature; f32 is used here for portability.
override K: u32 = 4096u;
override N: u32 = 4096u;

@compute @workgroup_size(8, 8, 1)
fn matmul_q4_f16(
    @builtin(global_invocation_id) gid: vec3<u32>
) {
    let row = gid.x;
    let col = gid.y;
    var acc: f32 = 0.0;
    // Each u32 packs eight 4-bit weights; dequantDotProduct (a helper
    // defined elsewhere in the module) unpacks them and accumulates
    // against the matching eight activations of column `col`.
    for (var k: u32 = 0u; k < K / 8u; k = k + 1u) {
        let packed = weights[row * (K / 8u) + k];
        acc += dequantDotProduct(packed, k * 8u, col);
    }
    output[row * N + col] = acc;
}
```
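Before trusting GPU output, it helps to validate it against a CPU reference. Here is a sketch of the dequant-dot-product logic in Kotlin; the packing scheme is an assumption (one `u32` holds eight 4-bit weights, dequantized as `(nibble - 8) * scale`), so adapt it to match whatever layout your quantizer emits.

```kotlin
// CPU reference for the shader's dequant dot product, for validating
// GPU results. Assumed scheme: eight 4-bit weights per u32, stored
// low-nibble first, dequantized as (nibble - 8) * scale.
fun dequantDotProductRef(packed: UInt, scale: Float, inputs: FloatArray): Float {
    var acc = 0.0f
    for (i in 0 until 8) {
        val nibble = ((packed shr (i * 4)) and 0xFu).toInt()
        acc += (nibble - 8) * scale * inputs[i]
    }
    return acc
}
```

Run a handful of rows through both paths after any shader change; quantized GEMM bugs tend to produce plausible-looking but subtly wrong logits.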
## Step 4: Tune Workgroup Sizes
Workgroup size is the single biggest performance lever. Mobile GPUs differ from desktop — Adreno operates on 64-wide waves, Mali on 16-wide warps.
- Start with `@workgroup_size(8, 8, 1)` — 64 threads, aligns with Adreno.
- Profile with `@workgroup_size(4, 4, 1)` — 16 threads, better for Mali.
- Query adapter limits at runtime and select the appropriate shader variant.
I've seen 2–3x differences on the same device just from workgroup sizing. Ship at least two variants and select based on `GPUAdapterInfo`.
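The variant selection can be a simple lookup keyed on the vendor string the JS side reads from `GPUAdapterInfo`. A sketch, where the vendor substrings and the mapping itself are assumptions to be confirmed by on-device profiling:

```kotlin
// Picks a workgroup (x, y) for the shader variant based on the GPU
// vendor string. Mapping is a starting point, not a profiled result.
fun workgroupForVendor(vendor: String): Pair<Int, Int> {
    val v = vendor.lowercase()
    return when {
        "arm" in v || "mali" in v -> 4 to 4  // 16 threads: Mali's 16-wide warps
        else -> 8 to 8                       // 64 threads: Adreno's 64-wide waves
    }
}
```

Compile both shader variants up front and select at pipeline-creation time; recompiling WGSL mid-session stalls inference.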
## Step 5: Minimize Bridge Crossings
The JS-to-native bridge is your bottleneck. Run all transformer layers in a single WebGPU dispatch — never bounce back to native between layers.
```kotlin
// Bad: cross the bridge once per layer (12 round trips for a 12-layer model)
// for (layer in 0 until 12) webView.evaluateJavascript("runLayer(inputBuffer, $layer)", null)

// Good: single call, all layers stay GPU-side
webView.evaluateJavascript("runAllLayers(inputBuffer, 12)", null)
```
Use `GPUBuffer` with `MAP_READ` only on the final output. Intermediate buffers should be `STORAGE` only — never mapped, never crossing the bridge.
## Gotchas
- **The docs don't mention this, but** workgroup size defaults are almost never optimal on mobile. Always profile per GPU family — skipping this step leaves 2–3x performance on the table.
- **Model size vs. VRAM.** Most mobile GPUs cap around 1–3 GB shared memory. INT4 quantization in the 1–4B parameter range is the sweet spot.
- **WebView version gaps.** Devices on Android below 12 or with a pre-121 WebView won't have WebGPU. Feature-detect before committing to this path.
- **Sub-50ms latency targets.** The JS bridge adds measurable overhead. If you need sub-50ms per token, this architecture may not be the right fit.
- **Run `nnapi-check` first.** If fewer than 20% of ops fall back to CPU on your target devices, NNAPI might still win. Audit before you build.
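That audit threshold can be computed from whatever per-op delegation report your NNAPI inspection tooling produces. A sketch, assuming you can map each op to a boolean (true = ran on an accelerator):

```kotlin
// Computes the CPU-fallback ratio from per-op delegation results.
// Below ~0.2, NNAPI may still be the better choice per the rule above.
fun cpuFallbackRatio(delegated: List<Boolean>): Double =
    if (delegated.isEmpty()) 0.0
    else delegated.count { !it }.toDouble() / delegated.size
```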
## Wrapping Up
Here is the takeaway that will save you hours: predictable GPU execution beats unpredictable fallback-to-CPU every time for LLM workloads where each token generation involves hundreds of GEMM operations. Audit your NNAPI operator coverage, split at the GEMM boundary, tune your workgroups per GPU family, and batch all layers into a single dispatch. That's the hybrid pipeline that actually ships.