Ahmet Barış Günaydın

I fused 1,500 GPU dispatches into one. Here's what happened.

Every ML framework does GPU computation the same way: send a task to the GPU, wait, send the next one, wait, repeat. For a 1,500-step simulation, that's 22,500 separate GPU commands per generation.

I tried something different. I wrote a WebGPU compute shader that runs the entire 1,500-step simulation in a single GPU dispatch. No round-trips. No waiting. The GPU just loops internally.
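To make the dispatch math concrete, here's a back-of-the-envelope sketch in Python. The figure of 15 kernel launches per step is an assumption derived from the totals above (22,500 commands over 1,500 steps), not a measured number:

```python
STEPS = 1500
TOTAL_EAGER_DISPATCHES = 22_500

# Implied kernel launches per simulation step (assumption:
# back-derived from the totals above, not measured directly).
kernels_per_step = TOTAL_EAGER_DISPATCHES // STEPS  # 15

# Eager-style execution: the host submits one GPU command
# per kernel, per step, per generation.
eager_dispatches = STEPS * kernels_per_step

# Fused execution: the loop lives inside the shader, so the
# host submits a single dispatch per generation.
fused_dispatches = 1

print(eager_dispatches, fused_dispatches)  # 22500 1
```

Every one of those 22,500 eager commands pays the same fixed submission cost; fusion amortizes it to one.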

The results (same hardware, no tricks)

On the same Apple M2 Pro:

  • WebGPU (Chrome): 46.2 gen/s
  • PyTorch MPS: 0.29 gen/s
  • That's a 159x speedup.

On embarrassingly parallel workloads (Rastrigin), they're basically tied (1.06x). The advantage is specific to sequential workloads — simulations, RL rollouts, trading strategies — where each step depends on the previous one.

Why can't PyTorch just do this?

I tested torch.compile with the Inductor backend. It tries to unroll the loop into a single computation graph:

| Timesteps | Result |
| --- | --- |
| 500 | Works: 2x speedup, 25 s compile time |
| 1,000 | RecursionError |
| 5,000 | OOM-killed after 30 min |

The compiler crashes because it tries to represent the entire loop as a static graph. WebGPU's approach is different — the shader contains an actual for loop that runs on the GPU. Simple, but it works.
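As a loose analogy (plain Python, not torch internals): a tracer that unrolls a loop records one graph node per iteration, so the recorded structure's depth grows linearly with the timestep count, and any recursive walk over it eventually exceeds Python's recursion limit — the same failure mode as the RecursionError above:

```python
def unrolled_trace(steps):
    # Mimic a tracing compiler that records one graph node per
    # loop iteration instead of emitting a single loop construct.
    node = "x0"
    for t in range(steps):
        node = ("step", t, node)  # nesting depth grows with steps
    return node

def trace_depth(node):
    # Walk iteratively; a naive recursive walk would blow past
    # Python's default recursion limit (usually ~1000 frames)
    # well before 1,500 levels.
    depth = 0
    while isinstance(node, tuple):
        depth += 1
        node = node[2]
    return depth

print(trace_depth(unrolled_trace(1500)))  # 1500
```

A shader-side `for` loop sidesteps this entirely: the graph stays constant-size no matter how many timesteps run.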

JAX gets closer but not all the way

JAX with lax.scan + vmap on a Tesla T4 achieves 6.43 gen/s on the same financial simulation — 13x over PyTorch CUDA on the same T4. XLA does fuse the loop. But it still ends up 7.2x slower than the hand-fused WebGPU shader, likely because the XLA kernel still has per-step overhead internally (register spills, memory traffic).
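The carry pattern that `lax.scan` fuses is easy to sketch in plain Python (a toy analogue, not JAX itself; the workload and numbers below are hypothetical): a carry threads through sequential steps, which is exactly the dependency chain that blocks naive parallelization.

```python
def scan(step, init, xs):
    # Pure-Python analogue of jax.lax.scan: thread a carry
    # through sequential steps, collecting one output per step.
    carry, ys = init, []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys

# Toy sequential workload (hypothetical numbers): each state
# depends on the previous one, so steps can't be reordered.
def grow(state, shock):
    new_state = state * 1.01 + shock
    return new_state, new_state

final, path = scan(grow, 100.0, [0.0] * 5)
print(round(final, 6))  # 105.101005
```

XLA compiles this pattern into one kernel containing a real loop; the remaining gap versus the hand-written shader is what the per-step overhead hypothesis above is trying to explain.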

At shorter episodes (L=500, Acrobot), JAX nearly closes the gap (1.29x). The fusion advantage scales with episode length.

The browser overhead is real but small

I ran the exact same WGSL shader through wgpu-native (Rust, no browser). Native Metal: 326.5 gen/s. Chrome WebGPU: 170.3 gen/s. That's a 1.92x browser tax (48% overhead).

But here's the weird part: PyTorch MPS (160.5 gen/s) is slower than WebGPU in Chrome (170.3 gen/s) on parallel workloads. The browser's overhead is smaller than PyTorch's framework overhead.

Try it yourself

I built a benchmark site where you can test your GPU in the browser. No install, no account:

gpubench.dev

~300 people have run it so far — Apple, AMD, NVIDIA, Intel GPUs all working in Chrome. The data is starting to paint an interesting picture of WebGPU performance across real hardware.

The paper

Full methodology, 10 tables, same-hardware comparisons, ablation study, all numbers verified:

doi.org/10.5281/zenodo.19335214

Code (WGSL shaders, Puppeteer benchmarks, Python baselines):

github.com/abgnydn/webgpu-kernel-fusion

What I'd love to know

Has anyone else used WebGPU compute shaders for non-graphics workloads? I'm curious what other sequential problems would benefit from this fusion pattern — RL environments, Monte Carlo simulations, agent-based models?


Acknowledgment: Drafting assistance by Claude. All experiments, benchmarks, and code by the author.
