Attention doesn't weigh all words equally. That single insight breaks open the most misunderstood mechanism in modern AI. Every time GPT-4 finishes your sentence, Claude writes code, or Gemini generates an image caption, the same eight-step computation runs billions of times—and most developers have no idea what's happening inside it. This article walks through the exact math, the real implementation tricks, and the one optimization that made today's 200K-token context windows possible.
Key Facts Most People Don't Know
- The original 2017 Transformer used 8 parallel attention heads in each layer, but GPT-3 uses 96 heads per layer with each head operating on only 128 dimensions instead of the full 12,288.
- Scaled dot-product attention divides by the square root of the key dimension (√d_k) specifically because, without it, dot products grow large in magnitude and push softmax into regions where gradients are extremely small (below 0.0001).
- In BERT's implementation, attention masks are applied by adding -10000 to masked positions before softmax rather than multiplying probabilities by zero afterward, because this keeps the exponential calculation numerically stable.
Step 1: Project Into Query, Key, and Value Vectors
Every token entering a Transformer layer arrives as an embedding vector of dimension d_model—512 in the original 2017 paper, 12,288 in GPT-3. The first operation multiplies this vector by three separate learned weight matrices: W_Q, W_K, and W_V. These produce the Query (Q), Key (K), and Value (V) vectors.
Think of it like a library. The Query is what you're looking for, the Key is the label on each book, and the Value is the actual content. The attention mechanism compares Queries against Keys to decide which Values matter most.
The projection matrices are learned during training and are different for every layer. In GPT-3's 96-layer architecture, that means 96 × 3 = 288 separate projection matrices; each head's slice is 12,288 × 128 parameters (12,288 × 12,288 for the full per-layer matrix across all 96 heads), totaling billions of parameters just for attention projections.
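Here is a minimal sketch of Step 1 in PyTorch, assuming the 2017 paper's d_model of 512. The weights are randomly initialized for illustration; in a real model W_Q, W_K, and W_V are learned.

```python
import torch

d_model = 512                      # embedding width from the 2017 paper
x = torch.randn(10, d_model)       # 10 tokens, each a d_model-dim embedding

# Three independent projections (illustrative, randomly initialized here)
W_Q = torch.randn(d_model, d_model) / d_model**0.5
W_K = torch.randn(d_model, d_model) / d_model**0.5
W_V = torch.randn(d_model, d_model) / d_model**0.5

Q = x @ W_Q   # what each token is looking for
K = x @ W_K   # the "label" each token advertises
V = x @ W_V   # the content each token carries

print(Q.shape, K.shape, V.shape)   # all (10, 512)
```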
Step 2: Split Into Multiple Heads
Here's where most explanations stop making sense. Instead of computing attention over the full d_model dimensions, the projections are split into h parallel heads. The original Transformer used 8 heads of 64 dimensions each (512 ÷ 8 = 64). GPT-3 uses 96 heads of 128 dimensions (12,288 ÷ 96 = 128).
Why split? Because a single attention head learns one type of relationship. One head might track subject-verb agreement. Another might track spatial proximity. A third might learn positional patterns. By running h heads in parallel, the model captures h different relationship types simultaneously—and each head only needs to process a small dimension, keeping computation tractable.
"The original 2017 Transformer used 8 parallel attention heads in each layer, but GPT-3 uses 96 heads per layer with each head operating on only 128 dimensions instead of the full 12,288."
The total computation is the same either way (O(N² · d_model)), but multi-head attention gives the model far more representational power per FLOP.
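A quick sketch of the split, assuming PyTorch and the original 8-head, 512-dimension configuration; the tensor names and values are illustrative.

```python
import torch

n_heads, d_model = 8, 512
d_k = d_model // n_heads           # 64 dimensions per head
N = 10                             # sequence length

Q = torch.randn(N, d_model)        # projected queries from Step 1

# Split the last dimension into (heads, d_k) and move heads to the front:
# (N, d_model) -> (N, h, d_k) -> (h, N, d_k)
Q_heads = Q.view(N, n_heads, d_k).transpose(0, 1)
print(Q_heads.shape)               # (8, 10, 64)
```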
Step 3: Compute the Attention Score Matrix
For each head, the attention scores are computed by multiplying the Query matrix Q with the transpose of the Key matrix K:
Score = Q × K^T
This produces an N × N matrix where N is the sequence length. Cell (i, j) represents how much token i should attend to token j. A high score means token i finds token j relevant; a low score means it doesn't.
This N × N matrix is the reason attention scales quadratically with sequence length. A 2,000-token sequence produces a 2,000 × 2,000 = 4 million–element attention matrix. A 128,000-token sequence produces 16.4 billion elements. This quadratic cost is the fundamental bottleneck of the Transformer architecture—and we'll see how FlashAttention solves it later.
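In code, the score matrix for one head is a single matrix multiply. The sketch below uses random tensors purely for illustration.

```python
import torch

N, d_k = 10, 64
Q = torch.randn(N, d_k)            # one head's queries
K = torch.randn(N, d_k)            # one head's keys

scores = Q @ K.T                   # (N, N): entry (i, j) = relevance of token j to token i
print(scores.shape)                # (10, 10)
```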
Step 4: Scale by √d_k
The scores are divided element-wise by the square root of the key dimension: √d_k. For d_k = 64, that's √64 = 8. For d_k = 128, it's √128 ≈ 11.31.
This isn't a minor normalization hack—it's critical for training stability. When d_k is large, dot products grow in magnitude proportionally to √d_k. Without scaling, these large values push the softmax function into regions where gradients are vanishingly small (below 0.0001), making training effectively impossible.
This scaling was one of the key insights in the original "Attention Is All You Need" paper: without it, deeper Transformers simply don't converge.
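You can see the effect numerically. The sketch below (PyTorch, random vectors) shows that unscaled dot products have a standard deviation near √d_k, while the scaled ones stay near 1, keeping softmax in a well-behaved range.

```python
import torch

d_k = 128
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)          # unscaled dot products
scaled = raw / d_k**0.5            # divide by sqrt(d_k) ≈ 11.31

print(raw.std())                   # ≈ sqrt(128) ≈ 11.3: large enough to saturate softmax
print(scaled.std())                # ≈ 1.0: gradients stay healthy
```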
Step 5: Apply the Attention Mask
Before softmax, a mask is applied. In a decoder (like GPT), this is a causal mask that prevents token i from attending to any token j > i (the future). In an encoder (like BERT), padding tokens are masked so the model doesn't waste attention on padding.
The implementation detail matters: masks are applied by adding -10,000 (practically -∞) to the masked positions before softmax, not by multiplying by zero after it. BERT's implementation does exactly this because it keeps the exponential calculation numerically stable: exp(-10,000 + score) underflows cleanly to zero, so masked positions receive essentially no weight and the unmasked weights still sum to 1. If you instead zeroed positions out after softmax, the masked tokens would already have claimed probability mass during normalization, leaving a row that no longer sums to 1.
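A minimal causal-mask sketch in PyTorch, using -10,000 as the additive mask value (as BERT does; -inf also works). Tensors are random and purely illustrative.

```python
import torch

N = 5
scores = torch.randn(N, N)

# Causal mask: token i may only attend to tokens j <= i.
# Add a large negative number to every future position before softmax.
causal = torch.triu(torch.ones(N, N), diagonal=1).bool()
masked = scores.masked_fill(causal, -1e4)

weights = torch.softmax(masked, dim=-1)
print(weights[0])                  # row 0 attends only to token 0: [1, ~0, ~0, ~0, ~0]
```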
Step 6: Apply Softmax Row by Row
Softmax converts each row of the (masked, scaled) score matrix into a probability distribution that sums to 1.0:
Attention_weights = softmax(Scaled_Scores)
Each row now tells you exactly how token i distributes its attention across all other tokens. A value of 0.85 in position (i, j) means token i allocates 85% of its attention budget to token j.
Softmax is applied row-wise because each token independently decides where to look. Token 5's attention distribution has no influence on token 12's distribution—they're independent probability distributions.
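In code, the row-wise behavior is just a softmax over the last dimension. The sketch below (PyTorch, illustrative tensors) shows that each row normalizes independently.

```python
import torch

scaled_scores = torch.randn(4, 4)               # already masked and scaled
weights = torch.softmax(scaled_scores, dim=-1)  # dim=-1 -> one distribution per row

print(weights.sum(dim=-1))   # every row sums to 1.0
# Changing row 2's scores leaves every other row's distribution untouched,
# because each token normalizes its attention independently.
```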
Step 7: Multiply by the Value Matrix
The softmax probabilities are multiplied by the Value matrix V:
Output = Attention_weights × V
Each output token is a weighted sum of all Value vectors, weighted by the attention probabilities. If token 5 attends 85% to token 3 and 15% to token 8, the output for token 5 is 0.85 × V₃ + 0.15 × V₈.
This is the actual information mixing step. Everything before this—Queries, Keys, scores, softmax—was just computing how much each token should attend to every other token. This step actually moves the information.
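The weighted-sum example above, written out with illustrative PyTorch tensors:

```python
import torch

d_k = 4
V = torch.randn(10, d_k)                   # value vectors for 10 tokens

# Token 5's attention row from the example above: 85% on token 3, 15% on token 8
weights_row = torch.zeros(10)
weights_row[3], weights_row[8] = 0.85, 0.15

out_5 = weights_row @ V                    # weighted sum of all value vectors
print(torch.allclose(out_5, 0.85 * V[3] + 0.15 * V[8]))   # True
```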
Step 8: Concatenate Heads and Project
The outputs from all h heads are concatenated back into d_model dimensions and passed through a final linear projection W_O:
Final_Output = Concat(head_1, ..., head_h) × W_O
This projection mixes information across heads. Without it, each head's output would remain isolated—the model would have h independent "opinions" that never talk to each other. W_O is what lets the model synthesize different attention patterns into a single coherent representation.
The result then passes through a residual connection (adding the original input), layer normalization, and a feedforward network before moving to the next Transformer layer.
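Putting all eight steps together, here is a compact single-sequence sketch in PyTorch. It is illustrative rather than an optimized implementation: weights are random, and there is no batching, dropout, residual connection, or layer norm.

```python
import torch

def multi_head_attention(x, n_heads, W_Q, W_K, W_V, W_O, causal=True):
    """All eight steps in one pass (single sequence, illustrative weights)."""
    N, d_model = x.shape
    d_k = d_model // n_heads

    # Steps 1-2: project, then split into heads -> (h, N, d_k)
    Q = (x @ W_Q).view(N, n_heads, d_k).transpose(0, 1)
    K = (x @ W_K).view(N, n_heads, d_k).transpose(0, 1)
    V = (x @ W_V).view(N, n_heads, d_k).transpose(0, 1)

    # Steps 3-4: scores, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / d_k**0.5           # (h, N, N)

    # Step 5: causal mask (decoder-style)
    if causal:
        mask = torch.triu(torch.ones(N, N), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))

    # Steps 6-7: row-wise softmax, then mix the values
    out = torch.softmax(scores, dim=-1) @ V                # (h, N, d_k)

    # Step 8: concatenate heads and apply the output projection
    out = out.transpose(0, 1).reshape(N, d_model)
    return out @ W_O

d_model, n_heads, N = 512, 8, 10
x = torch.randn(N, d_model)
W = [torch.randn(d_model, d_model) / d_model**0.5 for _ in range(4)]
y = multi_head_attention(x, n_heads, *W)
print(y.shape)                                             # (10, 512)
```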
The Memory Bottleneck: Why Attention Breaks at Scale
That N × N attention matrix from Step 3 is the Achilles' heel of Transformers. For a 128K-token context window, the attention matrix requires 128,000 × 128,000 × 2 bytes (float16) = 32 GB of GPU memory per layer per head. With 96 heads and 96 layers, that's absurd—far beyond any GPU's capacity.
FlashAttention, introduced in 2022 by Tri Dao, reduces memory usage from O(N²) to O(N) by computing attention in blocks of 128 × 128 tokens and never materializing the full attention matrix in GPU HBM. Instead of loading the entire N × N matrix into high-bandwidth memory, FlashAttention streams blocks through SRAM (on-chip cache), computes partial softmax results, and writes only the final output back to HBM.
FlashAttention 2 (2023) and FlashAttention 3 (2024) pushed this further by parallelizing across sequence length dimensions and overlapping softmax with memory loads. The practical impact: models that previously required 8× A100 GPUs for inference can now run on 2× with FlashAttention enabled.
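In practice you rarely call FlashAttention directly. On PyTorch 2.x with a supported CUDA GPU, torch.nn.functional.scaled_dot_product_attention dispatches to a fused FlashAttention-style kernel when one is available, so the N × N matrix is never materialized; a minimal sketch (shapes and dtypes are illustrative):

```python
import torch
import torch.nn.functional as F

# Shapes: (batch, heads, sequence, head_dim)
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)

# PyTorch picks the fastest available backend (a FlashAttention kernel on
# supported GPUs), streaming blocks through on-chip memory.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # (1, 8, 2048, 64)
```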
Multi-Query Attention: The KV Cache Shortcut
Modern models like PaLM (540B), Falcon (180B), and Llama 3 use a variant called multi-query attention (or grouped-query attention in Llama's case). Instead of giving each query head its own key and value head, multi-query attention shares a single KV head across all query heads.
The payoff, as seen in PaLM and Falcon, is an 8-32x reduction in KV cache size while losing less than 1% accuracy on most benchmarks. This is crucial for inference: the KV cache is often the dominant memory cost at long context lengths, and shrinking it by 32× directly translates to serving more users per GPU.
Llama 3 uses grouped-query attention (GQA), a middle ground where 8 query heads share 1 KV head (8-group GQA). This captures most of the memory savings while preserving more quality than pure multi-query attention.
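Here is a sketch of the grouped-query pattern in PyTorch, using illustrative head counts that match the 8-to-1 ratio above. Real implementations fold the sharing into the attention kernel instead of materializing repeated KV heads, but the math is the same.

```python
import torch

n_q_heads, n_kv_heads, N, d_k = 64, 8, 16, 64    # 8 query heads share each KV head

Q = torch.randn(n_q_heads, N, d_k)
K = torch.randn(n_kv_heads, N, d_k)              # only n_kv_heads KVs are cached
V = torch.randn(n_kv_heads, N, d_k)

group = n_q_heads // n_kv_heads                  # 8 query heads per KV head
K = K.repeat_interleave(group, dim=0)            # broadcast each KV head to its group
V = V.repeat_interleave(group, dim=0)

scores = Q @ K.transpose(-2, -1) / d_k**0.5
out = torch.softmax(scores, dim=-1) @ V
print(out.shape)                                 # (64, 16, 64); the KV cache is 8x smaller
```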
Why This Matters for Developers
Understanding attention internals isn't academic—it directly impacts how you build with LLMs:
- Context window cost scales quadratically. Doubling the context roughly quadruples the attention compute. Use RAG or chunking instead of stuffing 200K tokens.
- KV cache size determines how many concurrent users you can serve. GQA models (Llama 3, Mistral) are cheaper to deploy than MHA models (GPT-3); see the sizing sketch after this list.
- FlashAttention support is a hardware requirement. If your GPU doesn't support it, long-context inference will be 3-8× slower and memory-hungry.
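A back-of-envelope calculation makes the KV cache point concrete. The numbers below are illustrative, roughly shaped like a 70B-class GQA model (80 layers, 128-dimension heads, 32K context), with float16 storage assumed.

```python
# Back-of-envelope KV cache size per sequence (float16 = 2 bytes per element)
def kv_cache_bytes(n_layers, n_kv_heads, d_head, context_len, bytes_per_el=2):
    # 2 tensors (K and V) per layer, each (context_len, n_kv_heads * d_head)
    return 2 * n_layers * context_len * n_kv_heads * d_head * bytes_per_el

mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, d_head=128, context_len=32_000)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8,  d_head=128, context_len=32_000)

print(f"MHA: {mha / 1e9:.1f} GB per sequence")   # ~83.9 GB
print(f"GQA: {gqa / 1e9:.1f} GB per sequence")   # ~10.5 GB, 8x smaller
```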
The eight-step attention computation is the engine inside every Transformer. Understanding it means understanding the limits and possibilities of every AI model built today.
And the hidden bottleneck that once made attention impossibly slow at scale, the quadratic attention matrix, is exactly why algorithms like FlashAttention matter as much as the eight steps themselves.