DEV Community

jidonglab


Grouped-Query Attention: The KV Cache Math Behind Long Context

Take a 70B-class transformer with 80 layers and 64 attention heads. At 128K tokens of context, the multi-head version of that model wants roughly 320 GB just for its KV cache. That is more than four 80 GB H100s hold, before you load a single weight. Switch the same model to Grouped-Query Attention (GQA) with 8 key/value heads and that number drops to 40 GB. The depth and the 64 query heads stay the same, quality stays close to multi-head attention on most tasks, and the cache shrinks to one-eighth. That ratio is the single most important number in long-context serving, and most people who use these models have never seen the formula that produces it.

This post is the formula, the trade-off, and what it does to your batch size.

TL;DR

  • The KV cache stores per-token key and value vectors so attention never recomputes them. Its size grows linearly with sequence length and batch size, and at long context and high concurrency it dominates memory, not the weights.
  • Grouped-Query Attention (GQA) shares one K/V head across a group of query heads, cutting cache size by the group factor — typically 4x to 8x — with little quality loss.
  • The exact size is 2 · n_layers · n_kv_heads · head_dim · seq_len · batch · dtype_bytes. Only n_kv_heads changes between MHA, GQA, and MQA.
  • A smaller KV cache directly buys larger batches, longer context, or both — it is the lever behind most "we support 128K/200K context" claims.
  • Stack GQA with KV-cache quantization (fp8/int8) for another ~2x, but watch accuracy on long-range retrieval.

Why does the KV cache exist at all?

Autoregressive decoding generates one token at a time, and every new token attends to all previous tokens. The keys and values for those previous tokens never change once computed. Recomputing them at every step would mean re-running the key and value projections over the whole prefix for each new token. Total projection work would then grow quadratically with sequence length instead of linearly. So you cache them.

For each token, at each layer, you store its key vector and its value vector. On the next step, the new token's query attends against the whole stored stack. That stack is the KV cache, and it is pure inference-time state. Training processes each sequence in one parallel forward pass, so nothing has to persist between steps the way it does during decoding.
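The claim that cached keys and values are exactly what recomputation would produce can be checked with a toy. This is a minimal pure-Python sketch of single-head causal attention. The dimensions, the random stand-in "weights" `W_q`, `W_k`, `W_v`, and the token vectors are all hypothetical and are not taken from any real model. The sketch decodes six tokens twice, once with a KV cache and once re-projecting the prefix at every step, and counts the K/V projections each approach performs.

```python
import math
import random

# Toy single-head causal attention in pure Python. Dimensions and random
# "weights" are hypothetical, chosen only to illustrate the KV cache.

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # Scaled dot-product attention of one query over all stored keys/values.
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]   # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d, n_tokens = 4, 6
W_q, W_k, W_v = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
tokens = rand_matrix(n_tokens, d)   # stand-in hidden states, one row per token

# With a KV cache: each step projects K and V for the newest token only.
k_cache, v_cache = [], []
cached_out, cached_proj = [], 0
for x in tokens:
    k_cache.append(matvec(W_k, x))
    v_cache.append(matvec(W_v, x))
    cached_proj += 2
    cached_out.append(attend(matvec(W_q, x), k_cache, v_cache))

# Without a cache: each step re-projects K and V for the whole prefix.
fresh_out, fresh_proj = [], 0
for t in range(1, n_tokens + 1):
    keys = [matvec(W_k, x) for x in tokens[:t]]
    values = [matvec(W_v, x) for x in tokens[:t]]
    fresh_proj += 2 * t
    fresh_out.append(attend(matvec(W_q, tokens[t - 1]), keys, values))

print(cached_out == fresh_out)   # True: identical outputs either way
print(cached_proj, fresh_proj)   # 12 42: linear vs quadratic projection work
```

The attention step still reads every cached key at every step. The cache removes the repeated projections, not the attention over the prefix, which is why the cache itself becomes the memory cost.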

The cost: the cache grows by one slot per token, per layer, per KV head. At 128K tokens, the MHA cache from the 70B example in the introduction (~320 GB) is more than double the ~140 GB that the model's bf16 weights occupy. Long-context inference is a memory-bandwidth and memory-capacity problem, and the KV cache is what fills that memory.
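To see where the cache actually overtakes the weights, solve for the context length at which one request's cache equals the weight footprint. This is a rough sketch. It assumes exactly 70e9 parameters stored in bf16 and reuses the Llama-2-70B-shaped dimensions from this post (80 layers, head_dim 128). The helper names `kv_bytes_per_token` and `crossover_tokens` are introduced here for illustration.

```python
# Context length at which a request's KV cache matches the weight footprint.
# Assumptions: exactly 70e9 parameters in bf16 (2 bytes each) and the
# Llama-2-70B-shaped dimensions used in this post.

def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

def crossover_tokens(weight_bytes, per_token, batch=1):
    # Smallest context where cache >= weights (integer ceiling division).
    return -(-weight_bytes // (per_token * batch))

weights = 70_000_000_000 * 2   # ~140 GB of bf16 weights

for name, kv_heads in [("MHA", 64), ("GQA-8", 8)]:
    per_tok = kv_bytes_per_token(80, kv_heads, 128)
    print(f"{name:6s} cache passes weights at ~{crossover_tokens(weights, per_tok):,} tokens")
# MHA: ~53,406 tokens; GQA-8: ~427,247 tokens (batch 1)
```

At batch 1, a GQA-8 cache stays below the weights well past 128K. Passing `batch=4` pulls the GQA-8 crossover down to about 107K tokens. This is why concurrency, not only context length, decides when the cache dominates.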

What is the exact KV cache size formula?

The size in bytes is:

kv_bytes = 2 · n_layers · n_kv_heads · head_dim · seq_len · batch · dtype_bytes

The leading 2 is for K and V. n_kv_heads is the number of key/value heads — this is the only term GQA changes. Here it is as runnable code:

def kv_cache_bytes(n_layers, n_kv_heads, head_dim,
                   seq_len, batch=1, dtype_bytes=2):
    # dtype_bytes: fp16/bf16 = 2, fp8/int8 = 1
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

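GB = 1024**3  # binary gigabyte (GiB); the 320/40/5 figures below use this unit

# A 70B-class config (Llama-2-70B-shaped): 80 layers,
# 64 query heads, head_dim 128. MHA would use 64 KV heads.
# "128K" context taken as 128 * 1024 = 131,072 tokens.
cfg = dict(n_layers=80, head_dim=128, seq_len=128 * 1024, dtype_bytes=2)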

mha = kv_cache_bytes(n_kv_heads=64, **cfg)   # multi-head
gqa = kv_cache_bytes(n_kv_heads=8,  **cfg)   # grouped, group=8
mqa = kv_cache_bytes(n_kv_heads=1,  **cfg)   # multi-query

print(f"MHA: {mha/GB:6.1f} GB")   # ~320 GB
print(f"GQA: {gqa/GB:6.1f} GB")   # ~ 40 GB
print(f"MQA: {mqa/GB:6.1f} GB")   # ~  5 GB

The per-token cost falls straight out: divide by seq_len. For the MHA config that is about 2.5 MB per token of context. With GQA-8 it is ~320 KB per token. That per-token number is the one to internalize, because you multiply it by context length and batch size to size a deployment.
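The per-token figure can also be written as its own helper, which is just `kv_cache_bytes` with `seq_len` and `batch` set aside. This is a small sketch. `kv_bytes_per_token` is a name introduced here, and the 32K-context, batch-4 sizing at the end is an arbitrary example, not a recommendation.

```python
# Per-token KV cost: the seq_len- and batch-free part of the formula.
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

KiB = 1024

mha_tok = kv_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128)
gqa_tok = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128)

print(f"MHA:   {mha_tok / KiB:,.0f} KiB/token")   # 2,560 KiB = 2.5 MiB
print(f"GQA-8: {gqa_tok / KiB:,.0f} KiB/token")   # 320 KiB

# Sizing a deployment is then one multiplication (example numbers):
context, batch = 32 * 1024, 4
print(f"GQA-8, 32K x 4: {gqa_tok * context * batch / 1024**3:.0f} GiB")   # 40 GiB
```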

What does Grouped-Query Attention actually change?

GQA changes how many distinct key/value projections exist. In standard Multi-Head Attention (MHA), every one of the H query heads has its own K and V head — n_kv_heads = H. In Multi-Query Attention (MQA), all query heads share a single K/V head — n_kv_heads = 1. GQA is the middle ground: query heads are split into G groups, and each group shares one K/V head — n_kv_heads = G.

The queries stay full-width. You still have 64 query heads computing 64 distinct attention patterns. What shrinks is the cached state, because 8 query heads now read from the same K/V vectors instead of each carrying their own.

That asymmetry is why GQA works as well as it does. The expressive part of attention, what each query looks for, is untouched. What you share is the representation of the past, and it turns out a group of related query heads can attend against a common key/value basis with minimal loss. MQA pushes this to the extreme of one shared head and tends to lose more quality and train less stably. GQA-8 recovers nearly all of MHA's quality at a fraction of the cache. That is why GQA, not MQA, has become the common default in recent open-weight architectures. Llama 2 70B, the Llama 3 family, and Mistral 7B all use it.
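Here is a toy, single-decoding-step sketch of the sharing described in this section. It assumes the contiguous layout used by Llama-style implementations, where query head `h` reads KV head `h // group_size`; treat that mapping as an assumption for any specific model. The sizes (8 query heads, 2 KV heads, head_dim 4) and the random vectors are hypothetical. Only `n_kv_heads` lists are stored. Expanding each KV head across its group yields an MHA-shaped cache that produces identical outputs.

```python
import math
import random

def kv_head_for(q_head, n_q_heads, n_kv_heads):
    # Consecutive query heads share one KV head: with 64 query heads and
    # 8 KV heads, query heads 0-7 read KV head 0, 8-15 read KV head 1, ...
    assert n_q_heads % n_kv_heads == 0
    return q_head // (n_q_heads // n_kv_heads)

def attend(q, keys, values):
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]

def gqa_step(queries, k_cache, v_cache):
    # queries: n_q_heads vectors for the newest token.
    # k_cache / v_cache: n_kv_heads lists of seq_len vectors each; this is
    # the only per-token state that gets stored.
    n_q, n_kv = len(queries), len(k_cache)
    out = []
    for h, q in enumerate(queries):
        g = kv_head_for(h, n_q, n_kv)
        out.append(attend(q, k_cache[g], v_cache[g]))
    return out

random.seed(0)
n_q_heads, n_kv_heads, head_dim, seq_len = 8, 2, 4, 5

def rand_vec():
    return [random.gauss(0, 1) for _ in range(head_dim)]

queries = [rand_vec() for _ in range(n_q_heads)]
k_cache = [[rand_vec() for _ in range(seq_len)] for _ in range(n_kv_heads)]
v_cache = [[rand_vec() for _ in range(seq_len)] for _ in range(n_kv_heads)]
gqa_out = gqa_step(queries, k_cache, v_cache)

# Expand each KV head to its whole group: an MHA-shaped cache, 4x larger here.
group = n_q_heads // n_kv_heads
k_full = [k_cache[h // group] for h in range(n_q_heads)]
v_full = [v_cache[h // group] for h in range(n_q_heads)]
mha_out = gqa_step(queries, k_full, v_full)   # n_kv_heads == n_q_heads

print(gqa_out == mha_out)                                  # True
print([kv_head_for(h, 64, 8) for h in (0, 7, 8, 63)])      # [0, 0, 1, 7]
```

The equality is the point. A GQA model behaves like an MHA model whose K/V heads are tied within each group, so the cache only needs one copy per group.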

Why does the KV cache decide your max batch size?

Because on a fixed GPU, weights are fixed cost and the KV cache is per-request cost. Your memory budget is roughly:

VRAM = weights + (kv_per_token · context · batch) + activations + overhead

Weights and overhead are constant. Every concurrent request adds its full KV cache. So your maximum batch — and therefore your throughput — is set by how much memory is left after weights, divided by the per-request cache.

Make that concrete. Suppose 120 GB is free for cache after loading weights on a node. At 32K context:

  • MHA at ~2.5 MB/token → ~80 GB per request → you fit one request.
  • GQA-8 at ~320 KB/token → ~10 GB per request → you fit ~12 requests.

Twelve times the batch from one architectural choice. Higher batch means the expensive weight reads are amortized across more tokens per memory load, which is exactly where throughput on memory-bound decoding comes from. GQA does not just save memory; it converts that memory into concurrency.
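The batch arithmetic above fits in a few lines. This sketch takes the 120 GB (treated as GiB) free budget and 32K context from the example and ignores allocator overhead and activations. The helper names are introduced here for illustration.

```python
# How many concurrent requests fit once the weights are loaded?
# Budget and context are the example from the text; allocator overhead,
# activations, and fragmentation are ignored.

def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

def max_batch(free_bytes, per_token, context):
    return free_bytes // (per_token * context)

GiB = 1024**3
free = 120 * GiB
context = 32 * 1024

for name, kv_heads in [("MHA", 64), ("GQA-8", 8)]:
    per_tok = kv_bytes_per_token(80, kv_heads, 128)
    print(f"{name:6s} {per_tok * context / GiB:3.0f} GiB/request -> "
          f"batch {max_batch(free, per_tok, context)}")
# MHA: 80 GiB/request -> batch 1; GQA-8: 10 GiB/request -> batch 12
```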

Where does the cache memory actually go — and why fragmentation matters

Allocating one contiguous block per request for the maximum possible length wastes enormous memory, because most requests are shorter and you can't predict the final length. This is the problem PagedAttention (the idea behind vLLM) solves. It splits the KV cache into fixed-size blocks and maps them through a per-sequence block table, like virtual memory pages. Requests grow block by block, and blocks can be shared across sequences (e.g. a common prompt prefix). The share of KV memory that is wasted drops from roughly 60–80% in naive contiguous allocators to a few percent. That waste counts reserved-but-unused slots and fragmentation together.

So real-world KV capacity is the formula above times an allocator-efficiency factor. A good paged allocator is the difference between the theoretical batch size and the one you actually get. If you serve long context and your effective batch is far below what kv_cache_bytes predicts, fragmentation — not the formula — is usually the culprit.
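The block-table idea can be sketched in a few lines. This is a hypothetical toy in the spirit of PagedAttention, not vLLM's implementation. It has no prefix sharing and no copy-on-write, and it tracks token slots rather than real K/V tensors. The request lengths, block size 16, and 4096-token maximum are made-up numbers.

```python
class ToyPagedKVCache:
    """Hypothetical block-table allocator in the spirit of PagedAttention.

    Tracks token slots only; no prefix sharing or copy-on-write."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))   # physical block ids
        self.block_tables = {}                       # seq_id -> [block ids]
        self.lengths = {}                            # seq_id -> tokens stored

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                 # last block full, or none yet
            if not self.free_blocks:
                raise MemoryError("out of KV cache blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1

    def slot(self, seq_id, pos):
        # Logical token position -> (physical block id, offset inside block).
        return self.block_tables[seq_id][pos // self.block_size], pos % self.block_size

    def release(self, seq_id):
        # A finished request hands its blocks straight back to the free list.
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        del self.lengths[seq_id]

    def reserved_slots(self):
        return sum(len(t) for t in self.block_tables.values()) * self.block_size

# Made-up workload: three requests, block size 16, max allowed length 4096.
lengths = {"a": 100, "b": 700, "c": 2000}
cache = ToyPagedKVCache(num_blocks=1024, block_size=16)
for seq_id, n in lengths.items():
    for _ in range(n):
        cache.append_token(seq_id)

used = sum(lengths.values())     # 2800 slots actually filled
paged = cache.reserved_slots()   # 2816: only each last partial block is wasted
naive = len(lengths) * 4096      # 12288: contiguous max-length reservation
print(f"paged waste {1 - used / paged:.1%}, naive waste {1 - used / naive:.1%}")
# paged waste 0.6%, naive waste 77.2%
```

The exact percentages depend entirely on the invented lengths. The structural point is that paged waste is bounded by one partial block per sequence, while max-length reservation wastes everything a request does not end up using.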

How does GQA stack with KV-cache quantization?

GQA and quantization are orthogonal levers, and they multiply. GQA cuts n_kv_heads; KV-cache quantization cuts dtype_bytes. Storing K/V in fp8 or int8 instead of bf16 halves the cache again:

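# cfg already sets dtype_bytes=2, so override it instead of passing it twice
# (passing dtype_bytes=1 alongside **cfg raises a TypeError).
gqa_fp8 = kv_cache_bytes(n_kv_heads=8, **{**cfg, "dtype_bytes": 1})
print(f"GQA + fp8 KV: {gqa_fp8/GB:.1f} GB")  # ~20 GB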

The caveat: KV quantization is not free on long-range tasks. Keys are more sensitive than values, and aggressive int4 KV quant shows up first as degraded needle-in-a-haystack retrieval and long-context reasoning, even when short-context benchmarks look fine. The practical recipe: GQA always, fp8 KV cache when you need the headroom and have measured retrieval at your target context length, and reserve int4 KV for cases where you've confirmed the task tolerates it. Measure on long-context evals, not 2K-token sanity checks — that's where the loss hides.
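The sketch below shows what an 8-bit KV cache can mean at the byte level and why a single large channel hurts. It uses symmetric int8 quantization with one scale per cached vector. That layout is an assumption: real fp8/int8 KV kernels differ in scale granularity, and fp8 stores a small float rather than a rounded integer. The outlier is planted by hand. One commonly cited reason keys are harder to quantize is that they contain a few large-magnitude channels like this one.

```python
import random

# Symmetric int8 with one scale per cached vector (a hypothetical layout).

def quantize_int8(row):
    scale = max(abs(x) for x in row) / 127 or 1.0    # avoid a zero scale
    return [max(-127, min(127, round(x / scale))) for x in row], scale

def dequantize_int8(q, scale):
    return [v * scale for v in q]

def max_error(row):
    q, scale = quantize_int8(row)
    return max(abs(a - b) for a, b in zip(row, dequantize_int8(q, scale)))

random.seed(0)
head_dim = 128
key = [random.gauss(0, 1) for _ in range(head_dim)]
err_plain = max_error(key)

key_outlier = key.copy()
key_outlier[5] = 20.0   # one large channel inflates the shared scale
err_outlier = max_error(key_outlier)

bf16_bytes = 2 * head_dim
int8_bytes = head_dim + 2   # int8 values plus one 2-byte scale (assumed fp16)
print(f"compression: {bf16_bytes / int8_bytes:.2f}x")   # 1.97x
print(f"max error: {err_plain:.4f} plain, {err_outlier:.4f} with an outlier")
```

The scale overhead is why "halves the cache" is really "just under 2x". The outlier case shows how one channel coarsens the step size for every other channel in the vector.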

When would you not want GQA?

GQA is a near-universal win for inference, but it is an architectural decision baked in at pretraining — you can't bolt it onto an MHA model without conversion and fine-tuning. A few honest caveats:

  • Workloads with very short contexts, tiny batches, and latency-bound single streams barely touch the KV cache, so GQA's memory win is muted there. The reason to still use it is throughput under load, not single-request latency.
  • Group size is a real knob. Small groups (2 query heads per KV head, i.e. 32 KV heads on a 64-head model) keep more quality and save less. Groups of 8 (8 KV heads) save more and risk slightly more loss on attention-heavy tasks. The sweet spot is empirical and model-dependent.
  • For models you don't train, you simply inherit whatever the architects chose. The value of the formula then is diagnostic: it tells you why your context limit and batch ceiling are where they are, and whether quantization can move them.
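To see the range of the group-size knob on the memory side, sweep the number of KV heads for the 64-query-head, 80-layer, head_dim-128 example at 128K context. `kv_cache_bytes` is redefined here so the block runs on its own. The savings are exactly linear in KV heads; the quality side has to be measured, and nothing here says anything about it.

```python
# KV cache at 128K context (131,072 tokens) as the number of KV heads varies.

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

GiB = 1024**3
n_q_heads = 64
sizes = {}
for n_kv_heads in (64, 16, 8, 4, 2, 1):   # MHA ... GQA ... MQA
    sizes[n_kv_heads] = kv_cache_bytes(80, n_kv_heads, 128, 128 * 1024) / GiB
    group = n_q_heads // n_kv_heads
    print(f"{n_kv_heads:2d} KV heads (groups of {group:2d} query heads): "
          f"{sizes[n_kv_heads]:5.0f} GiB")
# 320, 80, 40, 20, 10, 5 GiB
```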

The one number to remember

Per-token KV cost — 2 · n_layers · n_kv_heads · head_dim · dtype_bytes — is the quantity that governs long-context serving. Memorize your model's value in KB/token and you can estimate, in your head, the max context and batch any GPU will hold.
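As a back-of-the-envelope example, here is the per-token number turned into a maximum context for a given memory budget and batch size. The 40 GiB free-for-cache figure is hypothetical, so plug in your own. The dimensions are the GQA-8, bf16 example from this post.

```python
# Longest context per request that a KV memory budget can hold.
# The 40 GiB budget is a hypothetical example.

def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

def max_context(free_bytes, per_token, batch=1):
    return free_bytes // (per_token * batch)

GiB = 1024**3
free = 40 * GiB
per_tok = kv_bytes_per_token(80, 8, 128)   # GQA-8, bf16: 320 KiB/token

for batch in (1, 4, 16):
    print(f"batch {batch:2d}: up to {max_context(free, per_tok, batch):,} tokens each")
# batch 1: 131,072; batch 4: 32,768; batch 16: 8,192
```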

So what is Grouped-Query Attention doing for long context?

Grouped-Query Attention shrinks the KV cache, the per-token key/value state that decoding keeps to avoid recomputation. It does this by sharing one key/value head across a group of query heads, typically cutting cache size 4x to 8x with minimal quality loss. At long context and high concurrency the KV cache, not the weights, dominates memory, so that reduction is what lets a model serve 128K-token windows and large concurrent batches on a fixed GPU budget. The size is exactly 2 · n_layers · n_kv_heads · head_dim · seq_len · batch · dtype_bytes, and GQA touches only n_kv_heads. Stack it with fp8 KV-cache quantization for another ~2x, validate on long-context retrieval, and you have the full memory budget for production long-context inference.
