DEV Community

TildAlice

Posted on • Originally published at tildalice.io

Why JAX Feels Faster Than PyTorch (and When It Isn't)

The JIT Wall

Run this PyTorch code and time it:

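import time

import torch

def matmul_chain(x):
    # 100 dependent matmuls; values overflow to inf/nan after a few
    # iterations, but only the timing matters here
    for _ in range(100):
        x = x @ x
    return x

x = torch.randn(512, 512, device='cuda')
matmul_chain(x)           # warm-up: absorbs CUDA context and cuBLAS initialization
torch.cuda.synchronize()  # make sure the warm-up has finished before timing

start = time.perf_counter()
y = matmul_chain(x)
torch.cuda.synchronize()  # CUDA kernels launch asynchronously; wait for them
print(f"PyTorch: {time.perf_counter() - start:.4f}s")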

Now the JAX equivalent:

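import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul_chain(x):
    # The Python loop runs once, at trace time, producing 100 matmuls for XLA
    for _ in range(100):
        x = x @ x
    return x

x = jax.random.normal(jax.random.PRNGKey(0), (512, 512))
matmul_chain(x).block_until_ready()  # first call traces and compiles; keep it out of the timing

start = time.perf_counter()
y = matmul_chain(x).block_until_ready()  # JAX dispatches asynchronously; wait for the result
print(f"JAX: {time.perf_counter() - start:.4f}s")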

On my RTX 3090, PyTorch takes ~0.0087s. JAX clocks in at ~0.0031s.

That's nearly 3x faster. But remove the @jax.jit decorator and rerun the JAX version, and you'll see ~0.0095s, actually slower than PyTorch. The speedup isn't magic. It's XLA, JAX's compiler: under @jax.jit, the Python loop runs only once, while JAX traces the function, and XLA compiles the resulting 100 matmuls into a single optimized program that is launched with one dispatch instead of 100 separate Python-level calls. PyTorch 2.0 added torch.compile(), which uses TorchInductor to pull off similar tricks, but JAX had JIT compilation baked in from day one.
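A quick way to see what jax.jit does with that loop: the minimal sketch below adds a hypothetical trace_count counter, purely for illustration, to show that the Python body runs once per input shape during tracing, and that later calls with the same shape reuse the compiled program. It uses identity matrices so the values stay finite.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only here to observe when Python runs


@jax.jit
def matmul_chain(x):
    global trace_count
    # Python side effects run only while JAX traces the function,
    # not when the compiled XLA program executes.
    trace_count += 1
    for _ in range(100):  # unrolled at trace time into 100 matmuls
        x = x @ x
    return x


x = jnp.eye(4)  # identity keeps values finite: I @ I == I
y1 = matmul_chain(x).block_until_ready()  # first call: trace + compile
y2 = matmul_chain(x).block_until_ready()  # same shape/dtype: cached executable
print(trace_count)  # 1

# A new input shape triggers a fresh trace and compile.
z = matmul_chain(jnp.eye(8)).block_until_ready()
print(trace_count)  # 2
```

This is also why the jit-less version is slower: without the decorator, every iteration of the loop is dispatched from Python on its own.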

The performance gap you feel when using JAX comes down to how often the compiler can help you and how much compilation overhead you're willing to tolerate.
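To put a number on that overhead, here is a minimal sketch, assuming a recent JAX release with the ahead-of-time jax.jit(...).lower(...).compile() API. The helper name time_compile_and_run is hypothetical. It times tracing plus XLA compilation separately from a warm run of the compiled program.

```python
import time

import jax


def matmul_chain(x):
    # Values overflow to inf/nan after a few iterations; only timing matters
    for _ in range(100):
        x = x @ x
    return x


def time_compile_and_run(fn, x):
    """Hypothetical helper: split a jitted call into compile and run time."""
    jitted = jax.jit(fn)

    start = time.perf_counter()
    compiled = jitted.lower(x).compile()  # trace + XLA compile, no execution
    compile_s = time.perf_counter() - start

    compiled(x).block_until_ready()  # untimed run to absorb first-execution costs
    start = time.perf_counter()
    compiled(x).block_until_ready()  # wait for the asynchronous dispatch to finish
    run_s = time.perf_counter() - start
    return compile_s, run_s


x = jax.random.normal(jax.random.PRNGKey(0), (512, 512))
compile_s, run_s = time_compile_and_run(matmul_chain, x)
print(f"compile: {compile_s:.4f}s, run: {run_s:.4f}s")
```

Compilation is paid once per input shape and dtype, so it amortizes over many calls in a training loop; a script that calls the function only once pays the full cost.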


Continue reading the full article on TildAlice
