How transformer inference actually works under the hood — and why KV cache is the single most important optimization keeping your LLM from crawling.
If you've ever wondered why LLMs respond fast even on long prompts — the answer is KV cache. But most explanations stop at "it stores keys and values." This goes deeper.
What You'll Learn
By the end of this article you'll understand:
- Why autoregressive LLM generation is expensive by design
- What attention actually computes — and why recomputing it is wasteful
- The difference between prefill and decode phases
- How KV cache grows and when it becomes your GPU's worst enemy
- How vLLM's PagedAttention solved memory fragmentation
- What's coming next: quantized cache, sliding window, speculative decoding
Introduction: Why LLM Inference is Expensive
Let's start with an uncomfortable truth.
When you send a prompt to GPT-4 or Claude and watch that first token appear, your GPU has just burned through billions of floating-point operations before producing a single character. And then, for every subsequent token in the response, it does it again.
Not a small version. The full computation. Attention over the entire sequence. Every time.
Without optimization, a 7B parameter model generating a 200-token response would recompute attention across the full growing sequence 200 times. For a 70B model on a context of 4,096 tokens, that's not slow — it's practically unusable in production.
This is the core economics problem of LLM inference: autoregressive generation is inherently sequential and expensive. You can't parallelize generation across output tokens the way you parallelize training across a batch. Each new token depends on every token that came before it.
KV cache is the engineering solution that makes modern LLM inference economically viable. It's not magic — it's a deliberate memory-compute tradeoff. Understanding it deeply is the difference between an ML engineer who deploys models and one who optimizes them.
What Happens During Token Generation
Before we talk about caching, let's understand what we're caching from.
LLMs generate text one token at a time, left to right. Each generation step:
- Takes the full input prompt + all previously generated tokens
- Runs a complete forward pass through the transformer
- Produces a probability distribution over the vocabulary
- Samples the next token from that distribution
Then the new token is appended to the sequence, and the process repeats.
Here's that loop as naive Python, with caching explicitly switched off:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model.eval()

input_ids = tokenizer.encode("The capital of France is", return_tensors="pt")
generated = input_ids

for _ in range(50):  # generate 50 tokens
    with torch.no_grad():
        # Full forward pass over the whole sequence at every step (expensive!)
        outputs = model(generated, use_cache=False)
    logits = outputs.logits[:, -1, :]  # logits at the last position
    # Greedy pick; keepdim gives shape [batch, 1] so it can be concatenated directly
    next_token = torch.argmax(logits, dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)

print(tokenizer.decode(generated[0]))
Notice the problem: by the last step, generated holds the prompt plus 49 generated tokens, and we run a full transformer forward pass over all of them just to predict one more. The cost of each step grows with sequence length, so the total cost of generating n tokens grows roughly quadratically.
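To put a number on that growth, here is a minimal sketch that counts how many token positions the model has to process in total, with and without reusing earlier work. The function names are illustrative (not from any library), and the count ignores the additional quadratic cost inside attention itself.

```python
def positions_processed_naive(prompt_len: int, new_tokens: int) -> int:
    """Total token positions run through the model when every step
    reprocesses the whole sequence, as in the loop above."""
    return sum(prompt_len + step for step in range(new_tokens))


def positions_processed_incremental(prompt_len: int, new_tokens: int) -> int:
    """Total positions if the prompt is processed once and every later
    step only processes the single newest token."""
    return prompt_len + max(new_tokens - 1, 0)


naive = positions_processed_naive(prompt_len=6, new_tokens=50)
incremental = positions_processed_incremental(prompt_len=6, new_tokens=50)
print(naive, incremental)
```

For a 6-token prompt and 50 generated tokens, the naive loop touches 1,525 positions where an incremental scheme would touch 55. The gap widens quickly as prompts and responses get longer.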
The question is: why do we need to reprocess all previous tokens every time?
Transformer Attention: A Quick Refresher
To understand why KV cache exists, you need to understand what attention actually computes.
The core of transformer inference is multi-head self-attention. For each layer and each token position, attention computes three projections:
- Q (Query): "What am I looking for?"
- K (Key): "What do I contain?"
- V (Value): "What do I actually carry?"
The attention output for position i is:
Attention(Q_i, K, V) = softmax(Q_i · K^T / √d_k) · V
In words: token i asks a question (Q), broadcasts it across all token keys (K) to get attention weights, then uses those weights to take a weighted sum of all values (V).
For a sequence of n tokens, each token attends to itself and every earlier token, so the score matrix has on the order of n² entries. That is O(n²) in time and, if the matrix is materialized, in memory, which is why long contexts hurt so much.
Here's the key insight:
Q changes at every step. But K and V for already-processed tokens do NOT change.
Once a token has been processed by the transformer, its Key and Value projections at every layer are fixed. Because of causal masking, they depend only on that token and the tokens before it, never on future tokens.
This is the mathematical justification for KV cache.
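A small numerical check makes the claim concrete. The sketch below is a toy, single-layer, single-head attention in plain Python with made-up 2-dimensional weights. It compares re-projecting every token at each step against projecting only the newest token and reusing stored K/V. All names and numbers here are illustrative assumptions, not taken from any real model.

```python
import math


def project(x, weight):
    """Multiply a vector x [d_in] by a matrix weight [d_in][d_out]."""
    return [sum(x[i] * weight[i][j] for i in range(len(x)))
            for j in range(len(weight[0]))]


def attend(query, keys, values):
    """softmax(q . K^T / sqrt(d_k)) . V for a single query vector."""
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


# Toy projection matrices and token embeddings (made-up numbers).
W_Q = [[0.2, -0.1], [0.4, 0.3]]
W_K = [[0.5, 0.1], [-0.2, 0.7]]
W_V = [[1.0, 0.0], [0.3, 0.9]]
TOKENS = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.8, -0.4]]


def newest_output_recompute(tokens):
    """Re-project K and V for every token, then attend from the newest one."""
    keys = [project(x, W_K) for x in tokens]
    values = [project(x, W_V) for x in tokens]
    return attend(project(tokens[-1], W_Q), keys, values)


key_cache, value_cache = [], []


def newest_output_cached(new_token):
    """Project only the new token and reuse everything already cached."""
    key_cache.append(project(new_token, W_K))
    value_cache.append(project(new_token, W_V))
    return attend(project(new_token, W_Q), key_cache, value_cache)


results_match = []
for step in range(len(TOKENS)):
    cached = newest_output_cached(TOKENS[step])
    recomputed = newest_output_recompute(TOKENS[: step + 1])
    results_match.append(
        all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(cached, recomputed))
    )
print(results_match)
```

Both paths produce the same output at every step; the cached path simply skips the projections it has already done. In a multi-layer model the same argument applies layer by layer, because causal masking keeps each position's hidden state independent of later tokens.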
Why Recomputing Attention is Inefficient
Let's make this concrete with numbers.
A standard LLaMA-2 7B model has:
- 32 transformer layers
- 32 attention heads
- Hidden dimension of 4096
- KV head dimension of 128
For a single token, the K and V projections at one layer are each a vector of size 128 per head. Across 32 heads and 32 layers, storing the KV state for one token costs:
2 (K and V) × 32 (layers) × 32 (heads) × 128 (head_dim) × 2 bytes (fp16)
= 524,288 bytes ≈ 0.5 MB per token
Now imagine a prompt of 2,048 tokens:
2,048 tokens × 0.5 MB = 1 GB of KV state
Without caching, every decode step recomputes that entire 1 GB of KV state from scratch. Over 200 decode steps, that adds up to regenerating at least 200 GB of K/V tensors just to produce a few hundred words.
The recompute path also saturates memory bandwidth. On an A100 (2 TB/s bandwidth), even just reading model weights once per step for a 13B model takes:
13B params × 2 bytes/param = 26 GB per forward pass
26 GB / 2000 GB/s = ~13ms per token → ~77 tokens/sec ceiling
Any redundant recomputation cuts directly into this budget.
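The same back-of-the-envelope estimate can be wrapped in a small helper. This is a sketch under the simplifying assumption used above: a decode step is limited only by how fast bytes can be streamed from GPU memory. The function name is hypothetical.

```python
def decode_ceiling(bytes_read_per_step: float, bandwidth_bytes_per_s: float):
    """Return (milliseconds per token, tokens per second) if a decode step
    is limited purely by memory bandwidth."""
    seconds = bytes_read_per_step / bandwidth_bytes_per_s
    return seconds * 1e3, 1.0 / seconds


weights_bytes = 13e9 * 2   # 13B parameters stored in fp16
a100_bandwidth = 2e12      # ~2 TB/s, the figure assumed in the text
ms_per_token, tokens_per_s = decode_ceiling(weights_bytes, a100_bandwidth)
print(f"{ms_per_token:.1f} ms/token, {tokens_per_s:.0f} tokens/s")
```

Real throughput lands below this ceiling because kernels never reach peak bandwidth and the KV cache has to be read as well.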
What KV Cache Stores
The KV cache stores the Key and Value tensors for all previously processed tokens, so they don't need to be recomputed on subsequent decode steps.
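Conceptually, each layer keeps two tensors, K and V, of shape [batch, seq_len, num_kv_heads, head_dim], and both grow by one position along seq_len at every decode step.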
At each decode step, the model:
- Computes Q only for the new token
- Computes K and V for the new token
- Appends new K/V to the cache
- Runs attention using new Q against all cached K/V
- Returns the output logits for the new token
Attention with KV cache:
import math
import torch

def attention_with_cache(query_new, key_cache, value_cache, key_new, value_new):
    # Shapes: query_new [batch, heads, head_dim];
    # caches [batch, seq, heads, head_dim]; key_new/value_new [batch, 1, heads, head_dim]
    d_k = query_new.shape[-1]
    # Append new K/V to cache along the sequence dimension
    keys = torch.cat([key_cache, key_new], dim=1)
    values = torch.cat([value_cache, value_new], dim=1)
    # New query attends to ALL keys (cached + new)
    scores = torch.einsum("bhd,bshd->bhs", query_new, keys) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    output = torch.einsum("bhs,bshd->bhd", weights, values)
    return output, keys, values  # return updated cache
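Result: K and V for earlier tokens are never recomputed. Each step projects only the new token (a constant cost) and scores one query against the n cached keys, so per-step attention work is O(n) instead of the O(n²) of a full recompute.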
Prefill Phase vs. Decode Phase
Production LLM inference has two fundamentally different operating modes.
Prefill Phase
When you first submit a prompt, the model processes all prompt tokens in parallel:
Prompt: [T_0, T_1, T_2, ..., T_2047]
↓ ↓ ↓ ↓
[All processed simultaneously via batch matrix ops]
↓
[Full KV cache populated for all 2048 positions]
↓
[First output token generated]
Prefill is compute-bound — large matrix multiplications across all positions at once. GPUs excel here.
Prefill time = your Time to First Token (TTFT). Users feel this as the delay before the first character appears.
Decode Phase
After prefill, each new token is generated one at a time:
Step 1: New token T_2048
→ Compute Q, K, V for T_2048 only
→ Append K_2048, V_2048 to cache
→ Attend over positions 0..2048
→ Sample T_2049
Step 2: New token T_2049
→ Cache now has 2050 entries
→ Attend over positions 0..2049
→ Sample T_2050
Decode is memory-bandwidth-bound — constantly reading the growing KV cache from GPU HBM. Compute cores sit largely idle waiting for memory reads.
Decode step memory reads (LLaMA 7B, 2048 cache):
Model weights: ~14 GB
KV cache: ~1 GB (and growing)
Total: ~15 GB
At 2 TB/s bandwidth: ~7.5ms/token → ~133 tokens/sec theoretical max
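One way to see why the two phases behave so differently is arithmetic intensity: FLOPs performed per byte of weights read. The sketch below is a rough roofline check, assuming a forward pass costs about 2 FLOPs per parameter per token, fp16 weights are read once per pass, and an A100 delivers roughly 312 TFLOP/s of dense fp16 compute against 2 TB/s of bandwidth. KV cache reads and attention FLOPs are ignored, and the function names are illustrative.

```python
def arithmetic_intensity(tokens_per_pass: int, dtype_bytes: int = 2) -> float:
    """FLOPs per byte of weights read:
    (2 * params * tokens) / (params * dtype_bytes)."""
    return 2 * tokens_per_pass / dtype_bytes


def bound_by(tokens_per_pass: int, peak_flops: float = 312e12,
             bandwidth: float = 2e12) -> str:
    """Classify a pass as compute-bound or memory-bandwidth-bound."""
    ridge = peak_flops / bandwidth  # FLOPs/byte at which the GPU is balanced
    if arithmetic_intensity(tokens_per_pass) >= ridge:
        return "compute"
    return "memory bandwidth"


print(bound_by(2048))  # prefill over a 2048-token prompt
print(bound_by(1))     # one decode step for a single sequence
```

Under these assumptions the balance point sits at about 156 FLOPs per byte. Prefill over thousands of tokens is far above it, while a single-sequence decode step is far below. Batching B sequences per decode step raises the intensity to roughly B, which is why decode throughput depends so heavily on batch size.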
GPU Memory and KV Cache Growth
Here's where things get expensive fast.
KV cache size =
batch_size × seq_len × num_layers × num_kv_heads × head_dim × 2 (K and V) × dtype_bytes
Real example — LLaMA-2 13B, batch size 8, 4K context:
batch_size = 8
seq_len = 4096
num_layers = 40
num_kv_heads = 40
head_dim = 128
dtype_bytes = 2 # fp16
kv_cache_bytes = (
    batch_size * seq_len * num_layers
    * num_kv_heads * head_dim * 2 * dtype_bytes
)
print(f"KV Cache: {kv_cache_bytes / 1e9:.2f} GB")
# KV Cache: 26.84 GB
Now scale to batch 16 at 8K context:
batch_size = 16
seq_len = 8192
# KV Cache: 107.37 GB ← won't fit on a single A100
A100-80GB Memory Budget (LLaMA-2 13B, fp16)
┌─────────────────────────────────────────────┐
│ Model Weights ~26 GB │
├─────────────────────────────────────────────┤
│ KV Cache ~30–40 GB │ ← the battleground
├─────────────────────────────────────────────┤
│ Activations / Other ~5–10 GB │
├─────────────────────────────────────────────┤
│ CUDA Runtime / Misc ~2–3 GB │
└─────────────────────────────────────────────┘
The KV cache is the main dynamic component. It grows as sequences get longer, shrinks as sequences end, and fragments if not managed carefully. Model weights are fixed at load time, and the remaining buffers are small by comparison.
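Turning the size formula around gives a quick capacity estimate: given a KV cache budget, how many full-length sequences fit at once? A minimal sketch, assuming every sequence uses the full context length and the budget is taken from the memory diagram above (helper names are hypothetical):

```python
def kv_bytes_per_sequence(seq_len, num_layers, num_kv_heads, head_dim,
                          dtype_bytes=2):
    """Bytes of K and V held for one sequence at full length."""
    return seq_len * num_layers * num_kv_heads * head_dim * 2 * dtype_bytes


def max_concurrent_sequences(kv_budget_bytes, **model_shape):
    """How many full-length sequences fit in the KV cache budget."""
    return int(kv_budget_bytes // kv_bytes_per_sequence(**model_shape))


llama2_13b_4k = dict(seq_len=4096, num_layers=40, num_kv_heads=40, head_dim=128)
print(max_concurrent_sequences(40e9, **llama2_13b_4k))
print(max_concurrent_sequences(30e9, **llama2_13b_4k))
```

With a 30–40 GB budget, that works out to only 8 to 11 concurrent 4K-token sequences for this model, which is why the cache, not the weights, usually decides the achievable batch size.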
Continuous Batching and PagedAttention
The Problem with Static Batching
Traditional inference servers used static batching: wait for N requests, then run until all of them finish. But requests finish at different times. If 7 out of 8 requests finish at step 50 while one runs to step 500, 87.5% of the batch slots sit idle for the remaining 450 steps.
Continuous Batching
Modern engines batch at the iteration level, not the request level:
Time → T0 T1 T2 T3 T4 T5 T6 T7
Slot 0: [A A A A ←done→ E E E ]
Slot 1: [B B ←done→ D D D D ]
Slot 2: [C C C C C C ←done→ F ]
As soon as request A finishes, slot 0 immediately starts request E. GPU utilization stays high.
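The difference between the two scheduling styles can be sketched with a toy simulation. It assumes each request is described only by the number of decode steps it needs, ignores prefill entirely, and uses request lengths loosely based on the diagram above. It is an illustration of the idea, not how any particular engine implements its scheduler.

```python
from collections import deque


def static_batching_steps(lengths, slots):
    """Each batch of `slots` requests runs until its longest member finishes."""
    steps = 0
    for start in range(0, len(lengths), slots):
        steps += max(lengths[start:start + slots])
    return steps


def continuous_batching_steps(lengths, slots):
    """Free slots are refilled from the queue before every decode iteration."""
    queue = deque(lengths)
    running = []  # remaining decode steps for each active request
    steps = 0
    while queue or running:
        while queue and len(running) < slots:
            running.append(queue.popleft())
        # One iteration: every active request emits one token.
        running = [left - 1 for left in running if left - 1 > 0]
        steps += 1
    return steps


request_lengths = [4, 2, 6, 4, 3, 1]  # requests A..F
print(static_batching_steps(request_lengths, slots=3))
print(continuous_batching_steps(request_lengths, slots=3))
```

With three slots, the static scheduler needs 10 iterations for these six requests while the iteration-level scheduler needs 7, because no slot waits for the slowest request in its batch.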
But this creates a new problem: KV cache memory fragmentation. Different requests have different lengths. You can't pre-allocate contiguous memory blocks without wasting huge amounts.
PagedAttention (vLLM)
vLLM's PagedAttention (2023) solved this with a virtual memory analogy borrowed from OS paging.
Instead of contiguous KV blocks per request, memory is split into fixed-size pages (typically 16 tokens each). A block table maps logical pages to physical GPU memory:
Logical view (Sequence A):
[Page 0 → Page 1 → Page 2 → Page 3]
Physical GPU memory:
Block 7: tokens 0–15 of Sequence A
Block 2: tokens 16–31 of Sequence A
Block 15: tokens 32–47 of Sequence A
Block 3: tokens 48–63 of Sequence A
Block table:
Seq A: { logical 0 → physical 7,
logical 1 → physical 2,
logical 2 → physical 15,
logical 3 → physical 3 }
Benefits:
- Minimal fragmentation: blocks are allocated one at a time as the sequence grows, so the only waste is the unfilled tail of each sequence's last block
- Prefix sharing — identical system prompt KV pages shared across thousands of requests (copy-on-write)
- Swapping — pages can be evicted to CPU memory under pressure
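The vLLM authors report that earlier serving systems wasted 60–80% of KV cache memory on fragmentation and over-reservation, while PagedAttention keeps waste under 4%.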
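# Simplified block table concept
class BlockTable:
    def __init__(self, block_size=16, num_blocks=1000):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_data = {}  # physical_block_id → KV tensor
        self.tables = {}      # seq_id → list of physical ids, indexed by logical block

    def allocate_block(self, seq_id):
        # Map the sequence's next logical block to any free physical block
        if not self.free_blocks:
            raise MemoryError("no free KV blocks")
        physical = self.free_blocks.pop()
        self.tables.setdefault(seq_id, []).append(physical)
        return physical

    def get_kv(self, seq_id, logical_block_idx):
        physical = self.tables[seq_id][logical_block_idx]
        return self.block_data[physical]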
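To see where the savings come from, the sketch below compares reserving a full max-length slab per request with allocating 16-token blocks on demand. The sequence lengths and the 2,048-token reservation are made-up example values, and the helper names are hypothetical.

```python
import math


def waste_preallocated(seq_lens, max_len):
    """Fraction of reserved slots left empty when every request
    reserves max_len token slots up front."""
    reserved = len(seq_lens) * max_len
    return 1 - sum(seq_lens) / reserved


def waste_paged(seq_lens, block_size=16):
    """Fraction of reserved slots left empty when blocks are handed out
    on demand; only each sequence's last block can be partly empty."""
    reserved = sum(math.ceil(n / block_size) * block_size for n in seq_lens)
    return 1 - sum(seq_lens) / reserved


lengths = [100, 700, 30]
print(f"preallocated: {waste_preallocated(lengths, max_len=2048):.1%}")
print(f"paged:        {waste_paged(lengths):.1%}")
```

With paging, the waste per sequence is bounded by one block (at most block_size - 1 token slots), regardless of how far the actual length falls short of the maximum.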
How vLLM and Inference Engines Optimize KV Cache
vLLM in Practice
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=2,       # split across 2 GPUs
    gpu_memory_utilization=0.90,  # leave 10% headroom
    max_model_len=4096,
)
outputs = llm.generate(
    ["Explain KV cache to an ML engineer"],
    SamplingParams(temperature=0.7, max_tokens=512),
)
vLLM ships: PagedAttention, continuous batching, Flash Attention v2, tensor parallelism, and prefix caching out of the box.
Flash Attention
Flash Attention avoids the expensive O(n²) memory allocation by tiling computations to fit in SRAM:
Standard attention:
Q, K, V → HBM → compute N×N matrix → HBM → output
Memory: O(n²)
Flash Attention:
Q, K, V → SRAM tiles → fused kernel → output
Never materializes full N×N matrix in HBM
Memory: O(n) Speed: 2–4× faster
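The trick that lets a tiled kernel avoid the full score matrix is an online softmax: keep a running maximum, a running normalizer, and a running weighted sum, and rescale them whenever a new tile raises the maximum. The plain-Python sketch below shows that idea for a single query vector. It illustrates the numerics only and is not the actual Flash Attention kernel; the tile size and sample data are arbitrary.

```python
import math


def naive_attention(query, keys, values):
    """Reference: materialize every score, then softmax and weight the values."""
    scale = 1 / math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) * scale for key in keys]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def tiled_attention(query, keys, values, tile=2):
    """Process keys/values tile by tile, never holding all scores at once."""
    scale = 1 / math.sqrt(len(query))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(q * k for q, k in zip(query, key)) * scale for key in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (0.0 on the first tile).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for score, value in zip(scores, v_tile):
            weight = math.exp(score - new_max)
            running_sum += weight
            acc = [a + weight * x for a, x in zip(acc, value)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.3, -1.2, 0.7]
K = [[0.1, 0.2, 0.3], [1.0, -0.5, 0.2], [-0.7, 0.4, 0.9],
     [0.0, 0.0, 1.5], [2.0, 1.0, -1.0]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 4.0], [0.5, 0.5]]

reference = naive_attention(q, K, V)
max_diff = max(
    abs(a - b)
    for tile in (1, 2, 3, 5)
    for a, b in zip(tiled_attention(q, K, V, tile=tile), reference)
)
print(max_diff < 1e-12)
```

The tiled version only ever holds one tile of scores, so its working memory is independent of sequence length, while the result matches the reference up to floating-point rounding.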
Prefix Caching
If thousands of users share the same system prompt, their KV caches for that prefix are identical. Compute it once, share it everywhere:
Without prefix caching:
Request 1: compute KV for [SYSTEM: 2048 tokens] + [query A]
Request 2: compute KV for [SYSTEM: 2048 tokens] + [query B]
Request 3: compute KV for [SYSTEM: 2048 tokens] + [query C]
With prefix caching:
Request 1: compute + store KV for [SYSTEM]
Requests 2, 3, ...: load cached KV, compute only [query N]
TTFT reduction: 50–90% on system-prompt-heavy workloads.
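A toy version of the bookkeeping looks like this. The class below is a hypothetical sketch that keys the cache on the exact prefix tokens and stores a placeholder string instead of real KV tensors. Production engines typically match prefixes at block granularity rather than on the whole prefix.

```python
class PrefixCache:
    """Toy prefix cache: maps a prompt prefix to a stand-in for its KV blocks."""

    def __init__(self):
        self._store = {}
        self.tokens_computed = 0  # prefill work actually performed

    def prefill(self, prefix_tokens, query_tokens):
        key = tuple(prefix_tokens)
        if key not in self._store:
            # Cache miss: pay for the prefix once and remember the result.
            self.tokens_computed += len(prefix_tokens)
            self._store[key] = f"kv-for-{len(prefix_tokens)}-tokens"  # placeholder
        # The request-specific suffix always has to be computed.
        self.tokens_computed += len(query_tokens)
        return self._store[key]


system_prompt = list(range(2048))
cache = PrefixCache()
for query in ([1, 2, 3], [4, 5], [6, 7, 8, 9]):
    cache.prefill(system_prompt, query)
print(cache.tokens_computed)
```

Three requests cost 2,057 prefill tokens here instead of 6,153 without sharing, and the saving grows with every additional request that reuses the same system prompt.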
KV Cache Challenges in Long-Context Models
Modern models like GPT-4o, Claude 3.5, and Gemini 1.5 Pro support 128K–1M token contexts. This creates KV cache problems at a completely different scale.
Memory at 128K Context
# LLaMA-3.1 70B with GQA (8 KV heads), 128K context, batch size 1
seq_len = 131072 # 128K tokens
num_layers = 80
num_kv_heads = 8 # Grouped Query Attention
head_dim = 128
dtype_bytes = 2
kv_cache_bytes = seq_len * num_layers * num_kv_heads * head_dim * 2 * dtype_bytes
print(f"KV Cache: {kv_cache_bytes / 1e9:.2f} GB")
# KV Cache: 42.95 GB ← for ONE sequence
4 concurrent requests need about 172 GB of KV cache: more than two 80 GB H100s' worth of memory, before counting roughly 140 GB of fp16 weights.
Decode Bandwidth at 128K
KV cache read per decode step: ~43 GB
H100 HBM bandwidth: 3.35 TB/s
Time reading KV cache: 43 / 3350 ≈ 12.8ms per token
→ Maximum ~78 tokens/sec, memory-bandwidth limited
That ceiling gets worse with every new token added.
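What Long-Context Models Use
Long-context deployments lean on the techniques covered in this article: Grouped Query Attention to shrink the number of KV heads (as in the example above), plus the KV cache quantization, sliding window attention, and MLA-style compression described under Future Optimizations.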
Real-World Production Problems
OOM Crashes Under Load
Without PagedAttention, a single unexpectedly long sequence in a batch can OOM the entire batch. Modern engines handle this with eviction policies and graceful degradation — but this requires careful tuning of gpu_memory_utilization and per-request max_tokens limits.
Prefill-Decode Disaggregation
At scale, teams separate prefill and decode onto different hardware:
- Prefill servers: compute-heavy, large batches, tuned for throughput and TTFT
- Decode servers: bandwidth-heavy, latency-sensitive, tuned for time per output token
The KV cache is computed on a prefill server, then transferred to a decode server over a fast interconnect such as InfiniBand or NVLink. Research systems such as DistServe and Splitwise describe this design.
Batching Heterogeneous Request Lengths
In a naively padded batch, a 100-token request batched with a 4096-token request is padded to the maximum length, wasting KV cache memory and compute. Solutions:
- Bucketing — group requests by length ± Δ
- Dynamic padding — pad to nearest power of 2
- Chunked prefill — break long prefills into fixed-length chunks
Cache Eviction Under Memory Pressure
When GPU memory is exhausted, KV pages must be evicted. Common policies:
- LRU — evict sequences not accessed recently
- Priority-based: keep high-SLA requests in GPU, swap background jobs to CPU
- Recompute on miss — evict aggressively, recompute KV from prompt if needed
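As an illustration of the first policy, here is a minimal LRU sketch that tracks how many KV blocks each sequence holds and evicts the least recently used sequences when a new one does not fit. The class and method names are hypothetical, and a real engine would swap or recompute the evicted blocks rather than just dropping them.

```python
from collections import OrderedDict


class LRUKVPool:
    """Toy pool of KV blocks with whole-sequence LRU eviction."""

    def __init__(self, capacity_blocks):
        self.capacity_blocks = capacity_blocks
        self.used_blocks = 0
        self._seqs = OrderedDict()  # seq_id -> blocks held, oldest access first

    def touch(self, seq_id):
        """Mark a sequence as recently used (for example, it just decoded)."""
        self._seqs.move_to_end(seq_id)

    def admit(self, seq_id, blocks):
        """Reserve blocks for a new sequence, evicting LRU sequences if needed."""
        evicted = []
        while self.used_blocks + blocks > self.capacity_blocks and self._seqs:
            victim, freed = self._seqs.popitem(last=False)
            self.used_blocks -= freed
            evicted.append(victim)
        if self.used_blocks + blocks > self.capacity_blocks:
            raise MemoryError("sequence does not fit even in an empty pool")
        self._seqs[seq_id] = blocks
        self.used_blocks += blocks
        return evicted


pool = LRUKVPool(capacity_blocks=10)
pool.admit("a", 4)
pool.admit("b", 4)
pool.touch("a")               # "a" was just used, so "b" is now the oldest
evicted = pool.admit("c", 4)
print(evicted)
```

Here admitting "c" evicts "b", because "a" was touched more recently. A priority-based policy would replace the access order with an explicit ranking.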
Future Optimizations
KV Cache Quantization
Standard KV: FP16 → 2 bytes/element
INT8 KV: INT8 → 1 byte/element (2× memory reduction)
INT4 KV: INT4 → 0.5 bytes/element (4× memory reduction)
Research (KVQuant, KIVI) shows INT8 quantization has minimal perplexity impact. INT4 is viable for values (keys are more sensitive to quantization).
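import torch

# INT8 KV cache via per-token symmetric quantization (one scale per head vector)
def quantize_kv(tensor, dtype=torch.int8):
    # Scale so the largest magnitude along head_dim maps to 127;
    # the clamp avoids division by zero for all-zero vectors
    scale = tensor.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127
    quantized = torch.round(tensor.float() / scale).clamp(-127, 127).to(dtype)
    return quantized, scale

def dequantize_kv(quantized, scale):
    return quantized.float() * scale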
Sliding Window Attention
Standard: token_i attends to tokens [0, i] → O(seq_len) KV cache
Sliding window: token_i attends to tokens [i-W, i] → O(W) KV cache
Mistral 7B uses W=4096 with a rolling buffer, so KV cache memory is capped at W positions per layer regardless of sequence length.
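A rolling buffer can be sketched in a few lines: the entry for position i is written to slot i mod W, overwriting the entry that just fell out of the window. The class below is an illustrative toy that stores strings in place of K/V tensors.

```python
class RollingKVBuffer:
    """Fixed-size cache: position i lives in slot i % window."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window
        self.num_seen = 0

    def append(self, kv_entry):
        # Overwrites the oldest entry once the buffer is full.
        self.slots[self.num_seen % self.window] = kv_entry
        self.num_seen += 1

    def visible(self):
        """Cached entries in chronological order (at most `window` of them)."""
        start = max(self.num_seen - self.window, 0)
        return [self.slots[pos % self.window]
                for pos in range(start, self.num_seen)]


buf = RollingKVBuffer(window=4)
for position in range(10):
    buf.append(f"kv{position}")
print(buf.visible())
```

After ten tokens with a window of four, only the entries for positions 6 to 9 remain, and the buffer never grows beyond four slots.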
Speculative Decoding
Without speculation:
Step 1 → token 1
Step 2 → token 2
...5 steps → 5 tokens
With speculation:
Draft model proposes 5 candidate tokens (5 cheap passes of a small model)
Target model verifies all 5 in 1 parallel pass
Accept the matching prefix, often 3–5 tokens, for the cost of 1 target pass
→ 2–3× speedup is commonly reported
Implementations: vLLM speculative decoding, Medusa, EAGLE.
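The accept-or-correct logic can be sketched with toy "models" that are just functions from a token sequence to the next token. This is the greedy variant, where a draft token is accepted only if it exactly matches what the target would have produced. Sampling-based implementations use a probabilistic acceptance rule instead. Everything here (function names, the toy models) is an illustrative assumption.

```python
def greedy_decode(next_token, prompt, num_new):
    """Plain autoregressive decoding: one model call per token."""
    seq = list(prompt)
    for _ in range(num_new):
        seq.append(next_token(seq))
    return seq


def speculative_decode(target_next, draft_next, prompt, num_new, k=4):
    """Greedy speculative decoding. Returns (sequence, target verification passes)."""
    seq = list(prompt)
    goal = len(prompt) + num_new
    target_passes = 0
    while len(seq) < goal:
        # 1) Draft model proposes k tokens autoregressively (cheap).
        proposal = []
        for _ in range(k):
            proposal.append(draft_next(seq + proposal))
        # 2) Target checks the proposals. A real engine scores all k+1
        #    positions in one batched forward pass; here that pass is
        #    simulated position by position and counted once.
        target_passes += 1
        accepted = []
        for token in proposal:
            expected = target_next(seq + accepted)
            if token == expected:
                accepted.append(token)
            else:
                accepted.append(expected)  # first mismatch: take the target's token
                break
        else:
            accepted.append(target_next(seq + accepted))  # bonus token
        seq.extend(accepted)
    return seq[:goal], target_passes


def target_model(seq):
    return (seq[-1] + 1) % 10            # toy target: count upward mod 10


def draft_model(seq):
    return 9 if seq[-1] == 2 else (seq[-1] + 1) % 10  # wrong after a 2


baseline = greedy_decode(target_model, [0], 12)
speculated, passes = speculative_decode(target_model, draft_model, [0], 12, k=4)
print(speculated == baseline, passes)
```

The output is identical to decoding with the target alone, but the target runs 3 verification passes instead of 12 sequential steps. How much that saves in practice depends on how often the draft agrees with the target and how cheap the draft is.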
Multi-head Latent Attention (MLA)
DeepSeek-V2's MLA caches a low-rank compressed representation instead of full K/V tensors, then decompresses on-the-fly during decode:
Standard KV: Store K [d_k] + V [d_v] per token per layer
MLA: Store c [d_c] (compressed), decompress when needed
d_c << d_k + d_v → DeepSeek-V2 reports a 93.3% smaller KV cache than DeepSeek 67B
Final Summary
┌─────────────────────────────────────────────────────────────┐
│ LLM INFERENCE PIPELINE │
│ │
│ INPUT PROMPT │
│ ↓ │
│ ┌─────────────┐ │
│ │ PREFILL │ All prompt tokens processed in parallel │
│ │ PHASE │ Compute-bound → sets your TTFT │
│ └──────┬──────┘ Full KV cache populated │
│ ↓ │
│ ┌─────────────┐ │
│ │ DECODE │ One token per step, autoregressive │
│ │ PHASE │ Memory-bandwidth-bound │
│ │ [NEW TOK] │ Q for new token only │
│ │ │ K/V appended to cache │
│ │ │ Attend over all cached K/V │
│ └──────┬──────┘ → Sample next token │
│ └──→ Repeat until EOS or max_tokens │
└─────────────────────────────────────────────────────────────┘
What KV cache does:
Stores computed Key and Value tensors for all processed tokens, eliminating recomputation at each decode step. Per step, K/V projection work drops from O(n) to O(1), and attention drops from O(n²) to O(n): one new query scored against the cached keys.
Why it's a bottleneck:
Grows linearly with sequence length. At 128K+ contexts, consumes tens of GBs per request. Decode becomes memory-bandwidth-limited.
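How modern engines solve it:
PagedAttention for fragmentation, continuous batching for utilization, Flash Attention for attention memory traffic, and prefix caching for shared prompts.
Where it's going:
KV cache quantization, sliding window attention, speculative decoding, and compressed caches such as MLA.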
LLM inference is ultimately a resource allocation problem. The GPU has a fixed memory budget, a fixed bandwidth, and a fixed compute budget. KV cache sits at the intersection of all three.
The engineers pushing the frontier on inference performance aren't just running models faster — they're rethinking what needs to be stored, what can be shared, what can be compressed, and what can be recomputed.
If you're building production LLM systems, KV cache isn't a detail. It's the design constraint everything else bends around.



