Most mainstream ML frameworks drive the GPU the same way: send a kernel to the GPU, wait, send the next one, wait, repeat. For a 1,500-step simulation, that's 22,500 separate GPU commands per generation.
I tried something different. I wrote a WebGPU compute shader that runs the entire 1,500-step simulation in a single GPU dispatch. No round-trips. No waiting. The GPU just loops internally.
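The dispatch count is simple arithmetic: 22,500 commands over 1,500 steps works out to roughly 15 GPU commands per step (that per-step figure is my inference from the two numbers above, not a measured value). A toy sketch of the two submission patterns — no real GPU involved, the point is only how many host-to-GPU submissions each one makes:

```python
def run_per_step(steps: int, kernels_per_step: int) -> int:
    """Framework style: the host submits every kernel individually."""
    dispatches = 0
    for _ in range(steps):
        for _ in range(kernels_per_step):
            dispatches += 1  # one queue submission (and sync) per kernel
    return dispatches

def run_fused(steps: int) -> int:
    """Fused style: one dispatch; the loop lives inside the shader."""
    return 1  # single submission; the GPU iterates `steps` times internally

print(run_per_step(1_500, 15))  # 22500 submissions per generation
print(run_fused(1_500))         # 1 submission
```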
The results (same hardware, no tricks)
On the same Apple M2 Pro:
- WebGPU (Chrome): 46.2 gen/s
- PyTorch MPS: 0.29 gen/s
- That's a 159x speedup.
On embarrassingly parallel workloads (Rastrigin), they're basically tied (1.06x). The advantage is specific to sequential workloads — simulations, RL rollouts, trading strategies — where each step depends on the previous one.
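To make the distinction concrete, here is a minimal NumPy sketch (illustrative only, not the benchmark code): Rastrigin is a pure elementwise evaluation, so every candidate can be scored at once, while a simulated price path has a step-to-step data dependency that forces sequential execution.

```python
import numpy as np

def rastrigin(x: np.ndarray) -> np.ndarray:
    """Embarrassingly parallel: each candidate row is scored independently."""
    return 10 * x.shape[-1] + np.sum(x**2 - 10 * np.cos(2 * np.pi * x), axis=-1)

def simulate_path(steps: int, mu=0.0005, sigma=0.01, seed=0) -> float:
    """Sequential: the price at step t depends on the price at step t-1."""
    rng = np.random.default_rng(seed)
    price = 100.0
    for _ in range(steps):  # cannot be parallelized across the time axis
        price *= 1.0 + mu + sigma * rng.standard_normal()
    return price

pop = np.zeros((4, 8))  # population of 4 candidates, 8 dimensions each
print(rastrigin(pop))   # all zeros: the global optimum of Rastrigin
print(simulate_path(1_500))
```

The parallel case maps naturally onto one wide GPU launch; the sequential case is where dispatch-per-step overhead dominates.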
Why can't PyTorch just do this?
I tested torch.compile with the Inductor backend. It tries to unroll the loop into a single computation graph:
| Timesteps | Result |
|---|---|
| 500 | Works, 2x speedup, 25s compile |
| 1,000 | RecursionError |
| 5,000 | OOM killed after 30 min |
The compiler crashes because it tries to represent the entire loop as a static graph. WebGPU's approach is different — the shader contains an actual for loop that runs on the GPU. Simple, but it works.
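The failure mode is easy to reproduce in miniature. This is a purely illustrative analogy in plain Python (it says nothing about Inductor's internals): if each loop iteration becomes one more layer of a nested expression, any recursive traversal of that expression blows past CPython's default recursion limit of 1,000 somewhere between 500 and 1,500 levels — the same window where torch.compile went from working to crashing.

```python
import sys
sys.setrecursionlimit(1000)  # CPython's default, pinned for reproducibility

def build_unrolled(steps: int):
    """Mimic loop unrolling: one nested node per iteration."""
    expr = "state0"
    for t in range(steps):
        expr = ("step", t, expr)  # each iteration wraps the previous graph
    return expr

def depth(expr) -> int:
    """Recursive graph walk, like a naive tracer or printer."""
    return 1 + depth(expr[2]) if isinstance(expr, tuple) else 0

print(depth(build_unrolled(500)))  # 500 levels: fine
try:
    depth(build_unrolled(1_500))   # blows the stack, like the 1,000-step run
except RecursionError:
    print("RecursionError")
```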
JAX gets closer but not all the way
JAX with lax.scan + vmap on a Tesla T4 achieves 6.43 gen/s on the same financial simulation — 13x over PyTorch CUDA on the same T4. XLA does fuse the loop. But it still ends up 7.2x slower than the hand-fused WebGPU shader, likely because the XLA kernel still has per-step overhead internally (register spills, memory traffic).
At shorter episodes (L=500, Acrobot), JAX nearly closes the gap (1.29x). The fusion advantage scales with episode length.
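For readers unfamiliar with the pattern: lax.scan folds a step function over a time axis while keeping the whole loop inside one compiled XLA program, and vmap vectorizes it across a population. A minimal pure-Python analogue of the scan contract (this is not JAX, just the shape of the API):

```python
def scan(step, carry, xs):
    """Analogue of jax.lax.scan: thread `carry` through each step,
    collecting one output per input element."""
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys

# Example: a running balance, like one strategy's equity curve.
def step(balance, ret):
    balance = balance * (1.0 + ret)
    return balance, balance  # (new carry, per-step output)

final, curve = scan(step, 100.0, [0.01, -0.02, 0.03])
print(final, curve)
```

In real JAX the same structure compiles to a single on-device loop, which is why it lands much closer to the fused shader than eager PyTorch does.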
The browser overhead is real but small
I ran the exact same WGSL shader through wgpu-native (Rust, no browser). Native Metal: 326.5 gen/s. Chrome WebGPU: 170.3 gen/s. That's a 1.92x browser tax (48% overhead).
But here's the weird part: PyTorch MPS (160.5 gen/s) is slower than WebGPU in Chrome (170.3 gen/s) on parallel workloads. The browser's overhead is smaller than PyTorch's framework overhead.
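The ratios quoted above follow directly from the measured throughputs — a quick sanity check using only the numbers in this post:

```python
native_metal = 326.5  # gen/s, wgpu-native (Rust, no browser)
chrome       = 170.3  # gen/s, Chrome WebGPU
pytorch_mps  = 160.5  # gen/s, PyTorch MPS

browser_tax = native_metal / chrome      # ~1.92x
overhead    = 1 - chrome / native_metal  # ~48% of native throughput lost
vs_mps      = chrome / pytorch_mps       # Chrome still ahead of MPS

print(round(browser_tax, 2), f"{overhead:.0%}", round(vs_mps, 2))
```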
Try it yourself
I built a benchmark site where you can test your GPU in the browser. No install, no account.
~300 people have run it so far — Apple, AMD, NVIDIA, Intel GPUs all working in Chrome. The data is starting to paint an interesting picture of WebGPU performance across real hardware.
The paper
Full methodology, 10 tables, same-hardware comparisons, ablation study, all numbers verified:
doi.org/10.5281/zenodo.19335214
Code (WGSL shaders, Puppeteer benchmarks, Python baselines):
github.com/abgnydn/webgpu-kernel-fusion
What I'd love to know
Has anyone else used WebGPU compute shaders for non-graphics workloads? I'm curious what other sequential problems would benefit from this fusion pattern — RL environments, Monte Carlo simulations, agent-based models?
Acknowledgment: Drafting assistance by Claude. All experiments, benchmarks, and code by the author.