If you have ever tried to push a transformer to a longer context and watched your GPU run out of memory, you have met the real bottleneck in attention. It is not the number of multiplications. It is the giant matrix that attention wants to write to memory. Flash Attention is the trick that makes that matrix disappear while computing the exact same answer.
Where the memory goes
Self-attention is one line: for queries Q, keys K and values V, the output is softmax(QKᵀ / √d) · V. Every query scores itself against every key, those scores become weights through a softmax, and the weights blend the values.
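To make the formula concrete, here is a minimal NumPy sketch of the naive version, for a single head and no masking. The function name `naive_attention` and the tiny shapes are my own choices for illustration, not anything from a library.

```python
import numpy as np

def naive_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V with the full score matrix held in memory."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                         # (N, N): every query vs every key
    scores = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)        # each row now sums to 1
    return weights @ V, weights                           # blend the values

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out, weights = naive_attention(Q, K, V)
print(out.shape, weights.shape)  # (6, 4) (6, 6)
```

Notice that `weights` is N × N even though the output is only N × d. That intermediate is the thing the rest of this post is about.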
The problem hides in QKᵀ. It has one entry for every pair of tokens, so its shape is N × N for a sequence of length N. That is quadratic. Go from 1,000 tokens to 2,000 and the score matrix does not double, it quadruples. At 8,192 tokens you are looking at roughly 67 million numbers, about 128 MB at 16-bit precision, and that is per attention head, per layer. The scores get written to memory, read back for the softmax, written again as weights, and read once more to multiply by V. For a big model that traffic is enormous, and it grows with the square of the sequence length while the rest of the model's activations grow only linearly.
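The arithmetic is easy to check yourself. This small sketch assumes 2 bytes per entry (fp16 or bf16), which is my assumption for the size figure; the helper name is made up for the example.

```python
def score_matrix_bytes(n_tokens, bytes_per_element=2):
    """Bytes needed for one N x N score matrix (one head, one layer)."""
    return n_tokens * n_tokens * bytes_per_element

for n in (1_000, 2_000, 8_192):
    entries = n * n
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:>5}: {entries:>11,} entries, {mib:8.1f} MiB")
```

Doubling N from 1,000 to 2,000 multiplies the entry count by four, and 8,192 tokens gives 67,108,864 entries, which is exactly 128 MiB at 2 bytes each.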
Why moving data is the real cost
Here is the part most people skip. A GPU has two kinds of memory. On-chip SRAM is tiny, on the order of 100 to 250 KB per streaming multiprocessor on recent data-center GPUs, but ridiculously fast. HBM is the big pool of "GPU RAM," tens of gigabytes, but roughly ten times slower to touch. Attention does relatively little arithmetic for each byte it shuffles around, which means it spends most of its time waiting on HBM rather than doing math. In systems language, attention is memory-bound.
That single fact changes the whole optimization target. If you are memory-bound, doing fewer multiplications barely helps. What helps is moving less data. So the goal becomes: never write that N × N matrix to slow memory at all.
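A back-of-the-envelope count shows why. The sketch below tallies FLOPs and HBM bytes for unfused attention, assuming the four passes over the N × N matrix described earlier, 2 bytes per element, and ignoring the softmax's own elementwise work. The accelerator figures are hypothetical round numbers I picked for illustration, not the spec of any real GPU.

```python
def naive_attention_traffic(n, d, bytes_per_element=2):
    """Rough FLOP and HBM byte counts for unfused attention (one head)."""
    flops = 4 * n * n * d        # Q @ K^T and weights @ V cost 2*n*n*d each
    score_elements = 4 * n * n   # write scores, read them, write weights, read them
    io_elements = 4 * n * d      # read Q, K, V and write the output
    total_bytes = (score_elements + io_elements) * bytes_per_element
    return {
        "flops": flops,
        "bytes": total_bytes,
        "score_share": score_elements / (score_elements + io_elements),
        "flops_per_byte": flops / total_bytes,
    }

stats = naive_attention_traffic(n=8192, d=64)

# Hypothetical accelerator: round numbers for illustration only.
peak_flops_per_s = 300e12
peak_bytes_per_s = 2e12
balance_point = peak_flops_per_s / peak_bytes_per_s  # FLOP/byte needed to keep the math units busy

print(f"share of traffic from the N x N matrix: {stats['score_share']:.1%}")
print(f"arithmetic intensity: {stats['flops_per_byte']:.1f} FLOP/byte")
print(f"hardware balance point: {balance_point:.0f} FLOP/byte")
```

Under these assumptions more than 99% of the bytes moved belong to the N × N matrix, and the arithmetic intensity sits well below the balance point, which is what "memory-bound" means in practice.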
Tiling plus a streaming softmax
Flash Attention splits Q, K and V into small blocks that fit inside fast SRAM. It loops over query blocks on the outside, and for each one it streams over the key/value blocks on the inside. (That is the loop order FlashAttention-2 uses; the original algorithm nests the two loops the other way around, with the same result.) At any instant only a small query tile, a small key/value tile, and one block of scores are on chip. The full matrix is consumed block by block and thrown away. It is never assembled.
The obstacle is the softmax. Normally softmax needs the whole row at once, because you subtract the row maximum for numerical safety, exponentiate, then divide by the total. If the scores arrive a block at a time, you cannot see the full row. The fix is the online softmax: carry a running maximum m and a running sum l, and update them as each block arrives.
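Here is a minimal pure-Python sketch of that idea on a single row of scores. The function names and the block size are mine; the point is only that m and l built up block by block match what a whole-row softmax would use. The final list comprehension revisits the row just to make the result easy to compare.

```python
import math

def reference_softmax(row):
    """Ordinary softmax that sees the whole row at once."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax(row, block_size):
    """Softmax statistics accumulated one block at a time."""
    m = -math.inf  # running maximum
    l = 0.0        # running sum of exp(x - m)
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return [math.exp(x - m) / l for x in row], m, l

row = [1.0, 3.0, -2.0, 7.0, 0.5, 4.0]
probs, m, l = online_softmax(row, block_size=2)
ref = reference_softmax(row)
print(max(abs(a - b) for a, b in zip(probs, ref)))
```

On the first block the running maximum is negative infinity, so the rescale factor is exp(-inf) = 0 and the empty running sum simply drops out.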
When a new block reveals a larger maximum than you had seen, every earlier exponential was computed against a smaller max and is now scaled wrong. So you multiply the running sum and the running output accumulator by one correction factor, exp(m_old − m_new), before adding the new block's contribution. Because the numerator and the denominator get the same correction, their final ratio is exact. That cheap rescale, applied once per block, is the heart of the whole algorithm.
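Putting the tiling and the rescale together, a NumPy sketch of the forward pass might look like the following. This is an illustration of the bookkeeping under my own naming and block sizes, not the actual CUDA kernel: NumPy has no notion of SRAM, so the "tiles" are just slices.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: builds the full N x N matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, q_block=4, kv_block=3):
    """Same function, computed tile by tile with an online softmax."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((N, V.shape[-1]))
    for qs in range(0, N, q_block):                    # outer loop: query tiles
        q = Q[qs:qs + q_block]
        m = np.full(q.shape[0], -np.inf)               # running row maxima
        l = np.zeros(q.shape[0])                       # running row sums
        acc = np.zeros((q.shape[0], V.shape[-1]))      # unnormalised output
        for ks in range(0, K.shape[0], kv_block):      # inner loop: key/value tiles
            s = q @ K[ks:ks + kv_block].T * scale      # one tile of scores
            m_new = np.maximum(m, s.max(axis=-1))
            correction = np.exp(m - m_new)             # exp(m_old - m_new)
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=-1)        # rescale the denominator
            acc = acc * correction[:, None] + p @ V[ks:ks + kv_block]  # and the numerator
            m = m_new
        out[qs:qs + q_block] = acc / l[:, None]        # normalise once at the end
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
diff = np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max()
print(f"largest difference: {diff:.2e}")
```

The largest array the tiled version ever holds for scores is `q_block` × `kv_block`, regardless of N, and the block sizes do not need to divide the sequence length evenly.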
The backward pass gets the same treatment. Instead of storing the N × N matrix for gradients, Flash Attention saves only the tiny per-row statistics and recomputes each score tile on the fly during backprop. Recomputation costs some extra FLOPs, but since the operation is memory-bound and that arithmetic is close to free, it is a clear win. Attention's training memory stays linear in sequence length.
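To see why a per-row statistic is enough, here is a small NumPy sketch of the recomputation step only. It assumes the saved statistic is the row-wise log-sum-exp of the scores, which is one common choice and my assumption here; the gradient math itself is left out, and the helper name is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 8, 4
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
scale = 1.0 / np.sqrt(d)

# "Forward": keep one number per row. The full matrix is formed here
# only so there is a reference to compare against.
S = Q @ K.T * scale
m = S.max(axis=-1)
logsumexp = m + np.log(np.exp(S - m[:, None]).sum(axis=-1))  # shape (N,)
P_full = np.exp(S - logsumexp[:, None])                      # reference softmax weights

def recompute_weight_tile(Q, K, logsumexp, rows, cols):
    """Rebuild one tile of softmax weights from Q, K and the saved row statistic."""
    s_tile = Q[rows] @ K[cols].T / np.sqrt(Q.shape[-1])
    return np.exp(s_tile - logsumexp[rows, None])

tile = recompute_weight_tile(Q, K, logsumexp, rows=slice(4, 8), cols=slice(0, 4))
print(np.abs(tile - P_full[4:8, 0:4]).max())
```

Storing `logsumexp` costs N numbers instead of N², and any tile of the weight matrix can be regenerated from it on demand.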
Exact, not approximate
This is the point worth repeating. Flash Attention is not sparse attention, not a low-rank shortcut, not a lossy approximation. It computes the identical mathematical function as standard attention, just in an order that touches slow memory far less. The output matches the naive version to floating-point dust. I put a small numeric check in the interactive page: naive full-matrix attention versus tiled online-softmax attention on the same random inputs, and the largest difference lands around 1e-16.
So the memory drops from O(N²) to O(N), the reads and writes to HBM collapse, and the wall-clock time falls hard because the thing was memory-bound to begin with. The number of FLOPs barely changes.
Why it mattered
Making attention linear in memory is a big reason context windows jumped from a couple thousand tokens to tens and hundreds of thousands. Training got cheaper on the same hardware, long-prompt inference became practical, and because the change is exact and drop-in, it was adopted almost everywhere. FlashAttention-2 improved the parallelism for roughly double the throughput, and FlashAttention-3 targets newer hardware with asynchronous execution and low-precision paths, all preserving the same exactness guarantee.
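In practice you rarely write the kernel. In PyTorch, F.scaled_dot_product_attention can dispatch to a Flash Attention kernel automatically when the device, dtype and arguments qualify. In Hugging Face Transformers, you pass attn_implementation="flash_attention_2" to from_pretrained, which needs the flash-attn package installed and a supported GPU. One flag, same answer, a lot less memory.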
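As a rough sketch of the PyTorch side, the call looks like this. The shapes are arbitrary, and on a CPU with float32 tensors, as here, PyTorch will most likely pick a non-Flash backend; the example only shows the entry point and that it agrees with the explicit formula.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 2, 16, 8
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Fused entry point: PyTorch chooses the backend for this device and dtype.
fused = F.scaled_dot_product_attention(q, k, v)

# Explicit full-matrix reference for comparison.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
reference = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(fused, reference, atol=1e-5))
```

Which kernel actually runs depends on your hardware, dtype and PyTorch version, so treat this as a usage sketch rather than proof that Flash Attention was selected.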
Play with the tiling, the online softmax, and the exactness check here: https://dev48v.infy.uk/ai/days/day25-flash-attention.html