DEV Community

Discussion on: I Rebuilt Karpathy's NanoChat in JAX. Here's What XLA Gets Right and What It Gets Dead Wrong.

Mykola Kondratiuk (@itskondrat)

JAX hits the wall where inference matters. No vLLM support is a real constraint once you're serving, not a footnote. XLA compilation wins for TPU training throughput, but I'd reach for PyTorch the moment I need production serving.

Omotayo Aina (Google Developer Experts)

Hi @itskondrat, correct: vLLM does not support JAX on GPU, so that constraint holds for the majority of production deployments. On TPU the picture is different: vLLM now supports both JAX and PyTorch through the tpu-inference plugin, which is where JetStream's core functionality landed after Google archived that project in February 2026. On TPU, JAX still provides the performance underneath; it is just exposed through the vLLM interface. For production serving on GPU, there is no JAX equivalent to vLLM today.

nanochat-jax ships a FastAPI server with SSE streaming and a chat UI. It is sized for demos and generation-quality checks, not production traffic. Weights headed for real GPU serving should be converted to a PyTorch checkpoint and served through vLLM or TGI.
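A minimal sketch of the first step of that conversion, assuming the JAX weights live in a Flax-style nested dict (with `kernel`, `embedding`, `scale`, and `bias` leaves) and that the target PyTorch modules reuse the same dotted names. `flatten_params` and `to_torch_layout` are hypothetical helpers, not part of nanochat-jax. The one non-obvious detail is layout: a Flax `Dense` kernel is stored as `(in_features, out_features)`, while `torch.nn.Linear.weight` is `(out_features, in_features)`, so 2-D kernels have to be transposed.

```python
import numpy as np


def flatten_params(tree, prefix=""):
    """Flatten a nested dict of arrays into {"a.b.c": ndarray}.

    Assumes a Flax-style params pytree made only of dicts and array leaves.
    np.asarray also accepts jax.Array leaves (copies them to host memory).
    """
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = np.asarray(value)
    return flat


def to_torch_layout(flat):
    """Rename Flax leaf names to PyTorch ones and fix Dense kernel layout.

    Hypothetical mapping: assumes the PyTorch model uses the same module
    names as the Flax tree; only the leaf names and shapes are changed.
    """
    state = {}
    for key, arr in flat.items():
        prefix, _, leaf = key.rpartition(".")
        if leaf == "kernel" and arr.ndim == 2:
            # Flax Dense: (in, out). torch.nn.Linear.weight: (out, in).
            new_leaf, arr = "weight", arr.T
        elif leaf in ("embedding", "scale"):
            # Embed tables and norm scales keep their shape; only the name changes.
            new_leaf = "weight"
        else:
            # e.g. "bias" keeps both its name and its shape.
            new_leaf = leaf
        new_key = f"{prefix}.{new_leaf}" if prefix else new_leaf
        # Checkpoint writers such as safetensors expect C-contiguous buffers.
        state[new_key] = np.ascontiguousarray(arr)
    return state


# Usage on a tiny made-up tree:
params = {
    "h0": {"mlp": {"kernel": np.zeros((4, 8)), "bias": np.zeros(8)}},
    "wte": {"embedding": np.zeros((10, 4))},
}
state = to_torch_layout(flatten_params(params))
# state["h0.mlp.weight"].shape == (8, 4); state["wte.weight"].shape == (10, 4)
```

From there, `torch.save({k: torch.from_numpy(v) for k, v in state.items()}, path)` or `safetensors.numpy.save_file(state, path)` writes the checkpoint. Loading it into vLLM or TGI additionally requires renaming keys to match a Hugging Face architecture those servers support, plus a matching config; that mapping is model-specific and is not shown here.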

Training and scaling research is where JAX earns its place in this project. On GPU inference, the tooling gap is real.