DEV Community

Hiroshi Toyama


Why TPUs Aren't Popular (Even Though They're Cheaper Per Token)

If you only look at the spec sheet, the TPU story is overwhelming: lower cost-per-token, dramatically better watts-per-token, deterministic latency. Trainium tells the same story. And yet a large share of the industry — including most of the inference traffic behind consumer chat UIs like ChatGPT — still runs on NVIDIA. The gap between "cheaper on paper" and "what people actually deploy" is not a marketing failure. It's an architectural tax that systolic-array silicon charges you in code, pipelines, and org structure. This post is about where that tax comes from and why only a handful of companies can afford to pay it.

The one architectural fact that explains everything: static shapes

NVIDIA GPUs are SIMT (Single Instruction, Multiple Threads) processors. Threads are scheduled dynamically at runtime and memory is allocated on demand, so kernels accept whatever tensor shapes arrive. TPUs and AWS Trainium are not GPUs: their compute core is a systolic array, a grid of multiply-accumulate units wired directly to their neighbors, fed by an ahead-of-time compiler (XLA for TPU, the Neuron compiler for Trainium).

A systolic array hits peak utilization only when the shape of the data flowing through it is fixed at compile time. Weights are loaded once and stay stationary in the processing elements; activations slide through like a bucket brigade. Change the sequence length by even one token, or the batch size by one request, and the data routes and memory addresses have to be recomputed, which means the compiler has to generate a new binary.

That single constraint is the source of every downstream pain. Here's what it forces on you at inference time:

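| Runtime input                   | NVIDIA (dynamic)              | TPU / Trainium (static)                         |
|---------------------------------|-------------------------------|-------------------------------------------------|
| Larger than the compiled bucket | Handled by dynamic allocation | Shape-mismatch crash                            |
| Smaller than the bucket         | Handled with no waste         | JIT recompile stall (minutes) or zero-pad waste |
| New, unseen length              | Just runs                     | New binary must exist, or it stalls             |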

So before any token reaches the chip, you need an answer to: "what shape is this, and which precompiled binary does it route to?" On NVIDIA you never ask that question.
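That routing question can be sketched as a lookup against a fixed bucket list. This is a minimal sketch, assuming a hypothetical set of precompiled sequence-length buckets (`BUCKETS`) and a hypothetical `BucketRouter` class; neither comes from XLA or Neuron. Each request goes to the smallest bucket that fits, and anything longer than the largest bucket is rejected up front instead of reaching the chip and triggering a compile or a crash.

```python
import bisect

# Hypothetical bucket set: one precompiled binary per sequence length.
BUCKETS = [512, 1024, 2048, 4096, 8192]


class BucketRouter:
    """Maps a request length to the smallest precompiled bucket that fits."""

    def __init__(self, buckets):
        self.buckets = sorted(buckets)

    def route(self, seq_len):
        if seq_len <= 0:
            raise ValueError("sequence length must be positive")
        # bisect_left finds the first bucket >= seq_len.
        i = bisect.bisect_left(self.buckets, seq_len)
        if i == len(self.buckets):
            # No compiled shape is big enough: on static silicon this would be a
            # shape mismatch, so refuse (or chunk) before it reaches the chip.
            raise ValueError(
                f"{seq_len} tokens exceeds the largest bucket ({self.buckets[-1]})"
            )
        return self.buckets[i]
```

With these buckets, `route(100)` returns 512 and `route(1025)` returns 2048. The gap between the bucket and the real length is padding that still has to be masked out or filled by packing.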

The dynamic vs. static analogy: Python vs. Java

The cleanest mental model: NVIDIA is Python, TPU/Trainium is Java.

  • NVIDIA = Python. Dynamic typing ≈ dynamic shapes. The runtime absorbs chaos. You throw a 100-token prompt or a 50,000-token prompt at the same forward pass and it just works, "good enough" fast, with no compile step in your face.
  • TPU/Trainium = Java. Static typing ≈ static shapes. Nothing runs until it's compiled to a fixed binary (NEFF for Neuron, an XLA executable for TPU). In exchange for boilerplate and rigid discipline, you get extreme execution efficiency — once everything fits the contract.

AMD's Instinct line (CDNA, ROCm) sits firmly on the NVIDIA/Python side: SIMT, dynamic shapes, PagedAttention support, and a HIPIFY toolchain whose whole purpose is to translate your existing CUDA source into HIP with minimal hand edits. The static/dynamic split is the real fault line, not the vendor logos.

What "handle dynamic input on static hardware" actually costs you in code

Suppose three users hit your endpoint at once: 3,000 / 4,000 / 1,000 tokens. On NVIDIA you don't pad and you don't build a mask. You concatenate them into one flat 8,000-token buffer and hand FlashAttention a cu_seqlens index marking the boundaries:

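# NVIDIA: variable-length attention. No padding, no mask matrix.
from flash_attn import flash_attn_varlen_func

# q, k, v: [total_tokens, n_heads, head_dim], all users concatenated (8,000 rows).
# cu_seqlens_q / cu_seqlens_k: int32 cumulative lengths [0, 3000, 7000, 8000].
# max_seqlen_q / max_seqlen_k: the longest single sequence in the batch (4000).
outputs = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
)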

The kernel reads the boundary index and keeps each user's attention inside that user's own segment. No wasted FLOPs on cross-user attention. The code is "just the model logic."
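The boundary index is just a running sum of the per-user lengths. A small pure-Python sketch (the helper names are hypothetical); with the real `flash_attn` package you would pass the same values as an int32 tensor on the GPU rather than a Python list.

```python
from itertools import accumulate


def cu_seqlens_from_lengths(lengths):
    """Boundary index for a flat varlen buffer: entry i is where sequence i starts,
    and the final entry is the total token count."""
    return [0] + list(accumulate(lengths))


def sequence_slice(cu_seqlens, i):
    """Row range that belongs to sequence i inside the flat buffer."""
    return slice(cu_seqlens[i], cu_seqlens[i + 1])


lengths = [3000, 4000, 1000]  # the three users from the example
cu_seqlens = cu_seqlens_from_lengths(lengths)
max_seqlen = max(lengths)
```

For the three users above this yields `[0, 3000, 7000, 8000]` and a `max_seqlen` of 4000, and user B's tokens are rows 3000 to 6999 of the flat buffer.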

On a TPU you can't reshape the systolic array, so you do the opposite: force everything into one fixed [batch, STATIC_SEQ_LEN] rectangle and use math to erase the parts you don't want computed.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

class StaticShapeAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

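    def forward(self, x, attention_mask):
        # x is ALWAYS [batch, STATIC_SEQ_LEN, d_model]. The shape never varies.
        # attention_mask broadcasts to [batch, n_heads, STATIC_SEQ_LEN, STATIC_SEQ_LEN]:
        # 1 = may attend, 0 = padding or another user's segment.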
        b, s, _ = x.size()
        q = self.q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        k = self.k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        # The systolic array DID compute every cell, including padding and
        # other users' regions. We retroactively delete them: e^(-1e9) -> 0.
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)

        ctx = torch.matmul(attn, v).transpose(1, 2).contiguous().view(b, s, -1)
        return self.out(ctx)

Two things about running this on XLA are pure consequences of static silicon:

  1. xm.mark_step() is the real execution trigger. That import torch_xla at the top isn't decoration. Unlike CUDA's eager mode, calling model(x) on XLA only accumulates a lazy graph. Nothing runs on the chip until mark_step() (called in your serving loop, not inside forward) compiles the accumulated graph into one fixed binary, or reuses a cached binary for a graph it has already seen, and runs it. New shape → new compile. (Recent PyTorch/XLA releases add an eager mode that hides this, but the underlying compile-per-shape model is unchanged.)
  2. masked_fill(..., -1e9) is a hack, not an optimization. NVIDIA's varlen path skips the cross-user multiplications entirely. The systolic array can't skip — it must multiply every cell of the rectangle, including the zeros, and then you mathematically null them out in softmax afterward. You burn the watts, then throw the result away.
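To see why the -1e9 fill is "compute, then discard", here is the same masking on a single score row in plain Python. `masked_softmax` is a hypothetical helper standing in for the `masked_fill` + `F.softmax` pair in StaticShapeAttention. The masked scores were already produced and already cost work; the fill only makes their exponentials underflow to exactly zero.

```python
import math

MASK_FILL = -1e9  # same sentinel value as masked_fill(..., -1e9)


def masked_softmax(scores, mask):
    """Softmax over one row of attention scores with masked positions filled first.

    The scores for masked positions were already computed; filling them with -1e9
    only makes exp() underflow to 0.0 so they drop out of the weighted sum.
    """
    filled = [s if keep else MASK_FILL for s, keep in zip(scores, mask)]
    peak = max(filled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in filled]
    total = sum(exps)
    return [e / total for e in exps]


# Positions 2 and 3 belong to padding or another user: computed, then discarded.
probs = masked_softmax([2.0, 1.0, 5.0, 3.0], [1, 1, 0, 0])
```

The result is identical to a softmax over just `[2.0, 1.0]`, padded with two zeros. The large score at position 2 had no effect on the output, but it was still multiplied into existence.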

The "smallest input" trap

The crash-on-overflow case is intuitive: feed 1,025 tokens into a binary compiled for 1,024 and you get a shape mismatch. The nastier case is underflow, a 100-token request hitting a 1,024-token bucket:

  • Let it through: XLA sees a new shape and triggers a JIT recompile. In production that's a multi-minute freeze. Stall.
  • Pad to 1,024: the array dutifully runs 0 × 0 + 0 across ~90% of its cells, consuming full power to compute nothing. Utilization collapses.
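The underflow numbers can be checked with a little arithmetic. This sketch uses a hypothetical `padding_waste` helper. Token-wise work (projections, MLP) scales linearly with length, while the full attention score matrix scales with its square, so the attention view is even worse than the ~90% figure.

```python
def padding_waste(seq_len, bucket):
    """Fraction of a padded bucket's work that touches no real token.

    linear:    token-wise layers, useful share is n / B
    attention: full [B x B] score matrix, useful share is (n / B) ** 2
    """
    if not 0 < seq_len <= bucket:
        raise ValueError("sequence must fit in the bucket")
    used = seq_len / bucket
    linear = 1 - used
    attention = 1 - used ** 2
    return linear, attention
```

For a 100-token request in a 1,024-token bucket this gives about 90.2% wasted token-wise work and about 99.0% of the attention scores spent on padding.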

The escape hatch is packing: instead of one user per bucket, tile multiple users' requests into a fixed rectangle like Tetris, and generate a segment-ID mask so attention can't bleed across users.

Fixed bucket [ 8192 tokens ]
├─ User A query (3000)
├─ User B query (4000)
├─ User C query (1000)
└─ padding      (192)   <-- the only waste
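Building that segment-ID mask can be sketched in a few lines. The helpers are hypothetical, labeling requests 1..k with 0 for padding is an assumed convention, and a causal (decoder-style) mask is assumed. A real serving stack would build this as a tensor op on device rather than as nested Python lists, since the dense mask is O(n²).

```python
def segment_ids(lengths, bucket):
    """Label every slot of a packed bucket: 1..k for the k requests, 0 for padding."""
    ids = []
    for seg, n in enumerate(lengths, start=1):
        ids.extend([seg] * n)
    if len(ids) > bucket:
        raise ValueError("requests do not fit in the bucket")
    ids.extend([0] * (bucket - len(ids)))
    return ids


def packed_attention_mask(ids, causal=True):
    """mask[i][j] is True when query slot i may attend to key slot j.

    Attention is allowed only within the same non-zero segment (no cross-user
    bleed) and, when causal, only to earlier or equal positions. Padding rows
    attend nowhere; their outputs are thrown away.
    """
    n = len(ids)
    return [
        [bool(ids[i] != 0 and ids[i] == ids[j] and (not causal or j <= i)) for j in range(n)]
        for i in range(n)
    ]
```

For the bucket in the diagram, `segment_ids([3000, 4000, 1000], 8192)` leaves exactly 192 slots labeled 0, and the mask built from it is what gets fed to `masked_fill` as `attention_mask`.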

It helps to be concrete about what "the rectangle" physically is. When you compile with BATCH_SIZE = 4, STATIC_SEQ_LEN = 8192, XLA reserves one contiguous [4, 8192] static region in the TPU's HBM: not four independent "rooms," but one big sheet the compiler hard-wires the array routes for. A single user rarely fills even one 8,192-token lane, so the serving layer packs multiple users across the four lanes at once:

[ One TPU processor: one static [4 x 8192] sheet ]

lane[0] (8192): [ A(2000) + B(5000) + C(1000) + pad(192) ]
lane[1] (8192): [ D(8000)                      + pad(192) ]
lane[2] (8192): [ E(3000) + F(3000) + G(2100)  + pad(92)  ]
lane[3] (8192): [ H(4000) + I(4000)            + pad(192) ]

Physically there are 4 lanes (32K of space); logically the proxy just crammed 9 ragged users (A–I) into them. From the application side it looks like one TPU is concurrently servicing many small requests in parallel — but underneath it's one rigid sheet with a segment mask drawn over it. The reason the hardware wants one fat sheet instead of pre-carved small rooms is pure systolic-array physics: the bigger the matrix, the higher the array's fill rate and the fewer idle cycles between feeds.
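The article doesn't specify how the proxy decides which requests share a lane. As one plausible choice (a swapped-in technique, not Google's or Anthropic's documented algorithm), first-fit decreasing, a classic bin-packing heuristic, fits the picture. The function name and signature are hypothetical, and a production proxy would be written in Go or C++ rather than Python.

```python
def pack_first_fit_decreasing(requests, n_lanes, lane_len):
    """Greedy first-fit-decreasing packing of (name, length) requests into lanes.

    Returns (lanes, leftovers, fill): the names placed in each lane, the names
    that did not fit this round (they would wait for the next batch), and the
    fraction of the whole [n_lanes x lane_len] sheet holding real tokens.
    """
    lanes = [[] for _ in range(n_lanes)]
    free = [lane_len] * n_lanes
    leftovers = []
    # Placing the longest requests first leaves small ones to fill the gaps.
    for name, length in sorted(requests, key=lambda r: r[1], reverse=True):
        for i in range(n_lanes):
            if length <= free[i]:
                lanes[i].append(name)
                free[i] -= length
                break
        else:  # no lane had room
            leftovers.append(name)
    fill = 1 - sum(free) / (n_lanes * lane_len)
    return lanes, leftovers, fill
```

On the nine requests from the lane diagram (A through I, 32,100 tokens) this places everything in the four 8,192-token lanes at about 98% fill, although the lane assignment differs from the hand-drawn one.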

Done right, MFU (Model FLOPs Utilization) can climb toward the 50–60% band that well-tuned large-model workloads reach on this hardware (PyTorch/XLA reports ~53% MFU for Llama 2 70B training on TPU), versus the single digits a naive one-user-per-bucket scheme collapses to. 100% is a ceiling nobody touches; the point is that packing recovers most of the loss. But notice what you just built: a high-throughput Go/C++ proxy in front of the cluster whose only job is to catch ragged input and pack it into rectangles in real time. On NVIDIA, that entire layer does not exist.
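To make the MFU comparison concrete, here is the usual back-of-the-envelope estimate, assuming a dense decoder-only transformer where a forward pass costs about 2 FLOPs per parameter per token (about 6 for training). The helper and the numbers in the check are hypothetical, and attention's length-dependent FLOPs are ignored.

```python
def model_flops_utilization(n_params, tokens_per_sec, peak_flops_per_sec, training=False):
    """MFU under the dense-transformer approximation.

    Forward pass ~ 2 * N FLOPs per token (a multiply and an add per weight);
    training adds the backward pass for ~ 6 * N. Attention FLOPs are ignored.
    tokens_per_sec should count real tokens only, never padding.
    """
    flops_per_token = (6 if training else 2) * n_params
    return flops_per_token * tokens_per_sec / peak_flops_per_sec
```

Only real tokens belong in `tokens_per_sec`. Padded slots burn the same FLOPs but add nothing to the numerator, which is why padding waste shows up directly as lost MFU.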

It's not one function — the whole pipeline forks

People assume torch_xla abstracts the hardware away because xm.xla_device() transparently targets both TPU and Trainium (thanks to the shared OpenXLA/PJRT runtime — libtpu.so for TPU, libneuronpjrt.so for Neuron). That's true for model.to(device) and basic ops. It is emphatically not true for the parts that matter.

The forward signature itself diverges:

# NVIDIA forward: ragged data + boundary index. Length is arbitrary every call.
def forward(self, input_ids, cu_seqlens, max_seqlen):
    return self.flash_attn_func(input_ids, cu_seqlens, max_seqlen)

# Static forward: fixed rectangle + a mask matrix you must build yourself.
def forward(self, input_ids, attention_mask):  # input_ids is [batch, FixedSeqLen]
    return self.static_attn_func(input_ids, attention_mask)
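Bridging the two signatures means someone has to turn the ragged (flat buffer, cu_seqlens) input into the static (input_ids, attention_mask) pair. A minimal sketch with a hypothetical `ragged_to_static` helper and an assumed pad id of 0; it puts one request per row (no packing) and right-pads each row to the fixed length.

```python
PAD_ID = 0  # hypothetical pad token id; use your tokenizer's real one


def ragged_to_static(flat_ids, cu_seqlens, fixed_len, pad_id=PAD_ID):
    """Turn NVIDIA-style (flat buffer, cu_seqlens) input into the static-path pair
    (input_ids [batch, fixed_len], attention_mask [batch, fixed_len]).

    One request per row, right-padded; the mask is 1 for real tokens, 0 for padding.
    """
    input_ids, attention_mask = [], []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        row = list(flat_ids[start:end])
        if len(row) > fixed_len:
            raise ValueError(f"request of {len(row)} tokens exceeds FixedSeqLen={fixed_len}")
        pad = fixed_len - len(row)
        input_ids.append(row + [pad_id] * pad)
        attention_mask.append([1] * len(row) + [0] * pad)
    return input_ids, attention_mask
```

Inside a model like StaticShapeAttention, this per-token mask would still need to be expanded (for example to [batch, 1, 1, FixedSeqLen]) so that it broadcasts against the [batch, heads, seq, seq] score matrix.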

And it cascades all the way down:

| Component         | NVIDIA pipeline                   | Trainium pipeline                                      |
|-------------------|-----------------------------------|--------------------------------------------------------|
| Inference engine  | vLLM (CUDA), TensorRT-LLM         | NxD / vllm-neuron                                      |
| Custom kernels    | Triton, CUDA C++ (FlashAttention) | NKI (Neuron Kernel Interface), rewritten from scratch |
| Base image        | nvcr.io/nvidia/pytorch            | AWS Neuron DLC                                         |
| CI build artifact | weights + CUDA/Triton binaries    | weights + NEFF static binaries per bucket              |
| Deploy target     | g5 / p5 instances                 | trn1 / inf2 instances                                  |
| Monitoring        | nvidia-smi, DCGM exporter         | neuron-top, Neuron exporter                            |

Two completely parallel worlds. Your CUDA container, your eval scripts, your autoscaling triggers — none of it carries over. vLLM's hardware-plugin mechanism gives you "one skin" at the business-logic layer, but the engine underneath is 100% separate code with separate bugs.

Precision makes it worse

The data-type story isn't symmetric either. BF16 (which Google's TPU pioneered) is stable on both sides — its FP32-range exponent survives the -1e9 mask values without going NaN. But FP8, the current throughput play, favors NVIDIA: FP8 attention scores swing hard and need dynamic scaling at runtime to avoid clipping. A static compiler has to bake in a fixed scale factor at compile time, so on TPU/Trainium aggressive FP8 attention risks clipping that degrades model quality. "Just switch to FP8" is a one-liner on NVIDIA and a research project on static silicon.
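The static-versus-dynamic scaling problem can be shown without any FP8 hardware. This sketch assumes the FP8 E4M3 format (largest finite value 448) and uses hypothetical calibration and runtime numbers. It models only clipping: a scale frozen from a calibration max saturates a runtime outlier, while a scale recomputed from the live tensor does not. Rounding to the 8-bit grid is deliberately left out.

```python
E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


def scale_for(amax):
    """Per-tensor scale that maps an absolute max onto the FP8 range."""
    return E4M3_MAX / amax


def quantize_clip(values, scale):
    """Scale into the FP8 range, saturate, and scale back.
    Only clipping is modeled; rounding to 8-bit precision is omitted."""
    return [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) / scale for v in values]


calibration_amax = 20.0  # hypothetical: largest |score| seen when the binary was compiled
static_scale = scale_for(calibration_amax)

runtime_scores = [3.0, -12.0, 35.0]  # hypothetical runtime row with an outlier (35 > 20)
dynamic_scale = scale_for(max(abs(v) for v in runtime_scores))

static_out = quantize_clip(runtime_scores, static_scale)
dynamic_out = quantize_clip(runtime_scores, dynamic_scale)
```

Here the static path reports 20.0 for a score that was really 35.0, so the softmax sees a flattened outlier; the dynamic path keeps it.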

The hidden cost: your org chart breaks

This is the part that kills adoption and nobody puts on a slide. On NVIDIA there's a clean abstraction boundary:

[ AI engineer / data scientist ]
   architecture, hyperparams, eval
        │
        ▼  boundary: Hugging Face weights / standard PyTorch
        │
[ MLOps / LLMOps engineer ]
   drop into vLLM, configure PagedAttention, scale out

The data scientist never thinks about memory layout. The MLOps engineer never reads the attention math. They ship artifacts across a clean interface.

On TPU that wall disappears, because model structure is directly coupled to physical constraints:

  • The packing scheme (MLOps) and the segment-mask logic inside forward (AI engineer) are two halves of one design. Change the batching strategy and the math has to change in lockstep. You cannot split that across a spec doc.
  • An AI engineer casually adding an if branch or changing layer count alters the compiled graph topology — and triggers JIT stalls or OOM in production. Debugging that requires dumping the XLA HLO graph, which pulls the AI engineer into an "infra" incident.
  • "BF16 → FP8 for 2x throughput" (MLOps) collides head-on with "FP8 static scaling causes hallucinations on certain tasks" (data scientist). On NVIDIA the runtime negotiates this for you. On TPU the two humans have to negotiate it face to face.

The organizations furthest along on TPU/Trainium — Google's Gemini team (custom silicon end to end), Anthropic's Claude team, and increasingly Meta, which began renting Google TPUs in 2026 to test Llama on both training and inference — lean away from the horizontal "data science dept / infra dept" split entirely. They run a single vertically-integrated team of people fluent in both the attention math and the compiler internals. Most companies cannot staff that, and the projects that try to keep the old division of labor die in a pile of compile errors and OOMs.

So why does anyone use them? Because the input is locked

The whole calculus flips when you control the input channel so the shapes are predictable. Two clean examples:

  • Google / YouTube summaries. The exact internal pipeline isn't public, but the shape is forced by the constraints: Google doesn't re-watch the video. At upload time, an async batch job (on spare TPU cycles) runs ASR and stores timestamped text in storage like Bigtable. When you ask for a summary, the exact text length is already known down to the token — so the router picks a just-right bucket, packing waste is near zero, and a light model like Gemini Flash scans pre-computed text. The "summarize a 2-hour video instantly" magic is really "scan a tiny text index that was built months ago for nearly free."
  • Anthropic / Claude Code. A CLI coding agent has an almost fully determined input: repo structure, tool definitions, git diffs, system prompt. The first ~90% of the context is invariant, which is exactly what static compilation and prompt caching love. Anthropic in fact serves Claude across a mix of Trainium, TPU, and NVIDIA — matching workloads to the most suitable chip — and runs Trainium fleets at scale (neuronx-distributed); a high-throughput Go/C++ packing proxy is the natural front-end for the static path, though Anthropic hasn't published the exact per-product split. Claude Code is — read cynically — close to the perfect input-locking channel that makes a Java-style chip worth the pain. Long-context workloads help too: a 200K-token prefill packs many buckets back-to-back, so the relative padding waste shrinks toward zero — the static array's weakness fades exactly where Claude is strongest.
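The long-context claim is easy to check with a rough calculation. This sketch assumes a hypothetical 8,192-token prefill bucket and counts token slots only (it ignores how attention across chunks is wired). Only the final chunk is partially filled, so waste is capped at one bucket however long the prompt grows.

```python
import math


def chunked_prefill_waste(prompt_len, bucket):
    """Split a prompt into back-to-back fixed buckets.

    Returns (number of chunks, fraction of slots that are padding).
    Only the last chunk can be partially filled.
    """
    n_chunks = math.ceil(prompt_len / bucket)
    padded = n_chunks * bucket
    return n_chunks, (padded - prompt_len) / padded
```

A 200,000-token prompt needs 25 chunks and wastes about 2.3% of the slots, while a 1,000-token prompt in the same bucket wastes about 88%.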

The inverse is just as logical, and it explains why the chat UIs lean hardest on dynamic SIMT hardware. ChatGPT and Claude.ai's web frontends accept arbitrary text, surprise image uploads, and topic switches mid-conversation. The system can't predict the shape until the user hits send. That chaos is precisely what dynamic SIMT + PagedAttention were built for.

Takeaways

  • TPUs aren't unpopular because they're slow or expensive — they're cheaper per token. They're unpopular because cheapness is conditional on a discipline most teams can't enforce: every tensor shape fixed at compile time.
  • The cost moved, it didn't vanish. Static silicon pushes all the uncertainty out of the hardware and onto your software (packing, masking, bucket routing) and your people (collapsed dev/ops boundary). You trade CapEx (silicon, power) for OpEx (elite engineers maintaining hack layers).
  • The decision rule is about the channel, not the chip. If you own the input (a CLI, a fixed business workflow, your own storage pipeline), TPU/Trainium are a weapon. If your input is a free-form chat box or a third-party API integration, NVIDIA (or AMD) is the only sane choice, and reaching for TPU or Trainium on sticker price alone is how MFU quietly collapses to single digits.

The spec sheet was never lying about cost-per-token. It just wasn't pricing in the engineers, the forked pipeline, and the org redesign you have to buy first.
