DEV Community

Tech_Nuggets


Flash Attention: what it does and why it matters


You have a single H100 with 80 GB of VRAM. Llama 3.1 70B needs about 140 GB in FP16, so it only fits at 4-bit quantization (roughly 35–40 GB of weights), and whatever is left has to hold the KV cache for a long-context workload. The model is fast enough at 8K context, so you push it to 32K for a RAG pipeline. It's still fine. Then you push it to 128K for a document-summary task, and suddenly the attention layer alone is spending 3 seconds per forward pass, 85% of which is spent moving data between HBM and SRAM rather than doing math. The CUDA kernel occupancy graph tells the story: green compute bars are tiny, grey memory-stall bars are huge. The GPU is bandwidth-bound, and vanilla attention is the cause.

Flash Attention is the algorithm that fixes this by restructuring the attention computation itself — not approximate, not sparse, not quantized, just IO-aware. Here is what it does, how the three versions differ, and where it stops helping.

Why this matters in practice

The attention mechanism is the core of every transformer: compute a similarity matrix S = Q K^T / √d, normalize it row-wise with softmax P = softmax(S), and use it as weights over the values, O = P V. The problem is that for sequence length N and head dimension d, S and P are N×N matrices, and writing them to GPU HBM (high-bandwidth memory) and reading them back, not the matrix multiplies themselves, is the bottleneck.
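To make that concrete, here is a minimal NumPy sketch of the unfused computation for a single head, assuming the usual 1/√d scaling. The helper name is hypothetical; the point is that S and P are full N×N arrays, which a GPU implementation would have to write to HBM and read back.

```python
import numpy as np

def naive_attention(Q, K, V, causal=False):
    """Unfused single-head attention: builds the full N x N matrices S and P."""
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)                      # (N, N) similarity matrix
    if causal:
        # query i may only attend to keys j <= i
        S = np.where(np.tril(np.ones((N, N), dtype=bool)), S, -np.inf)
    P = np.exp(S - S.max(axis=1, keepdims=True))  # subtract the row max for stability
    P = P / P.sum(axis=1, keepdims=True)          # row-wise softmax, still (N, N)
    return P @ V, P                               # O is (N, d)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
O, P = naive_attention(Q, K, V, causal=True)
print(O.shape, P.shape)  # (256, 64) (256, 256)
```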

For N = 32K, a single head's S has 32,768² ≈ 1.07 billion entries, about 2.1 GB in FP16 (its size does not depend on d). At HBM bandwidths of roughly 2 TB/s (A100 80GB, H100 PCIe) to 3.35 TB/s (H100 SXM), writing that matrix out and reading it back costs about 1.3–2.1 ms per head per layer, and P costs the same again. Across 80 layers and both forward and backward passes, that is well over 100 ms per step for a single head, and real models have dozens of heads per layer. None of it is useful ALU work; it is just memory shuffling. At 128K context, a single head's S grows 16× to about 34 GB in FP16, and the memory wall dominates.
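A back-of-the-envelope check of those numbers, counting one FP16 score matrix for one head and exactly one write plus one read at peak bandwidth. The bandwidth figures are approximate vendor peaks, and the helper names are hypothetical.

```python
def score_matrix_gb(seq_len, bytes_per_elem=2):
    """Size in GB (1e9 bytes) of one head's N x N score matrix S; FP16 by default."""
    return seq_len * seq_len * bytes_per_elem / 1e9

def round_trip_ms(gigabytes, bandwidth_tb_s):
    """Milliseconds to write a buffer to HBM and read it back once at peak bandwidth.
    1 GB at 1 TB/s takes 1 ms, so the result is 2 * GB / (TB/s)."""
    return 2 * gigabytes / bandwidth_tb_s

for n in (32_768, 131_072):
    gb = score_matrix_gb(n)
    print(f"N={n}: S = {gb:.2f} GB, round trip "
          f"{round_trip_ms(gb, 2.0):.2f} ms at 2.0 TB/s, "
          f"{round_trip_ms(gb, 3.35):.2f} ms at 3.35 TB/s")
```

Real kernels never hit peak bandwidth, so these are lower bounds on the time spent moving S alone.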

Flash Attention eliminates almost all of the intermediate HBM traffic by tiling the Q, K, V matrices into blocks that fit in on-chip SRAM (up to 192 KB of combined L1/shared memory per SM on A100, 256 KB on H100), performing the entire softmax and weighted sum inside SRAM, and only writing the final output O back to HBM. The result: 2–4× faster attention for typical long-context workloads, up to 10× for very long sequences, with output that is mathematically exact in FP16/BF16 (equal to standard attention up to floating-point rounding) and a small additional error in FP8.
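How big can the tiles be? Algorithm 1 of the original FlashAttention paper derives them from the SRAM size M (in elements): K/V blocks get ⌈M/4d⌉ rows and Q blocks get min(⌈M/4d⌉, d) rows. A sketch of that rule with a hypothetical helper name, fed with the per-SM figures above. This is optimistic, since not all of that storage is usable as shared memory, and production kernels use hand-tuned, power-of-two tile shapes instead.

```python
import math

def fa1_block_rows(sram_bytes, head_dim, bytes_per_elem=2):
    """Tile heights from Algorithm 1 of the FlashAttention paper, with M = SRAM size in elements:
    K/V block rows = ceil(M / 4d), Q block rows = min(ceil(M / 4d), d)."""
    m = sram_bytes // bytes_per_elem
    kv_rows = math.ceil(m / (4 * head_dim))
    q_rows = min(kv_rows, head_dim)
    return q_rows, kv_rows

print(fa1_block_rows(192 * 1024, 128))  # (128, 192) for A100's 192 KB, d = 128
print(fa1_block_rows(256 * 1024, 128))  # (128, 256) for H100's 256 KB, d = 128
```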

How the algorithm works

The core insight is that softmax can be computed incrementally from running statistics. You don't need the full N×N matrix: you can process Q, K, V in blocks, compute a local softmax within each block, keep a running row maximum and a running sum of exponentials, and rescale earlier partial results whenever the maximum grows. The merged result is exactly the full softmax.
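The same trick on a plain vector, as a NumPy sketch (hypothetical helper): the softmax statistics are accumulated chunk by chunk, and the running sum is rescaled whenever a larger maximum appears, so no chunk ever needs to see the whole row at once. For a plain vector the final division still reads x again; the attention version avoids that by rescaling its output accumulator instead.

```python
import numpy as np

def online_softmax(x, chunk=4):
    """Softmax of a 1-D array whose statistics are gathered in one streaming pass,
    keeping only a running max m and a running sum l of exp(x - m)."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), chunk):
        block = x[start:start + chunk]
        m_new = max(m, block.max())
        # rescale the old sum to the new max, then add this chunk's contribution
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return np.exp(x - m) / l   # final normalization uses the merged statistics

rng = np.random.default_rng(0)
x = rng.standard_normal(10) * 5
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(online_softmax(x), reference))  # True
```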

flowchart LR
    subgraph HBM["HBM (main memory)"]
        Q["Q (N × d)"]
        K["K (N × d)"]
        V["V (N × d)"]
        O["O (N × d)"]
    end
    subgraph SRAM["SRAM (on-chip, ~192 KB per SM on A100)"]
        Qi["Q_block (Bc × d)"]
        Kj["K_block (Br × d)"]
        Vj["V_block (Br × d)"]
        Sij["S_block (Bc × Br)"]
        Pij["P_block (Bc × Br)"]
        Oi["O_block accumulator"]
        mi["Row max<br/>m_i"]
        li["Row sum<br/>ℓ_i"]
    end
    Q -->|tile| Qi
    K -->|tile| Kj
    V -->|tile| Vj
    Qi --> Sij
    Kj --> Sij
    Sij --> mi
    Sij --> Pij
    mi --> Pij
    Pij --> li
    Pij --> Oi
    Vj --> Oi
    Oi -.->|write| O

The algorithm for each attention head proceeds as follows:

  1. Divide Q into blocks of size Bc that fit in SRAM alongside one block each of K and V.
  2. Divide K and V into blocks of size Br.
  3. For each Q block i, initialize the running max m_i = -∞, the running sum ℓ_i = 0, and the output O_i = 0; then for each K/V block j:
    • Load Q_i, K_j, and V_j into SRAM.
    • Compute S_ij = Q_i K_j^T / √d in SRAM.
    • Compute block statistics: m_ij = rowmax(S_ij), P_ij = exp(S_ij - m_ij), ℓ_ij = rowsum(P_ij).
    • Update the running max: m_i_new = max(m_i, m_ij).
    • Update the running sum: ℓ_i_new = exp(m_i - m_i_new) · ℓ_i + exp(m_ij - m_i_new) · ℓ_ij.
    • Rescale and accumulate the output: O_i = (ℓ_i · exp(m_i - m_i_new) · O_i + exp(m_ij - m_i_new) · P_ij V_j) / ℓ_i_new, then set m_i = m_i_new and ℓ_i = ℓ_i_new.
  4. Write the final O_i back to HBM after all K/V blocks have been processed.
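A NumPy sketch of those steps for one head. It is non-causal, float64, and runs on the CPU, so it shows the arithmetic rather than the memory savings. q_block and kv_block are our names for the Q-block and K/V-block heights. In the output update, the previous accumulator is rescaled by ℓ_i · exp(m_i_old - m_i_new), the new tile's P_ij V_j by exp(m_ij - m_i_new), and both are divided by the new running sum.

```python
import numpy as np

def flash_attention_forward(Q, K, V, q_block=32, kv_block=32):
    """Single-head, non-causal attention computed tile by tile with online softmax.
    Only (q_block x kv_block) score tiles are formed; the N x N matrix never is."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    for i in range(0, N, q_block):
        Qi = Q[i:i + q_block]
        Oi = np.zeros_like(Qi)              # normalized output accumulator
        mi = np.full(len(Qi), -np.inf)      # running row max
        li = np.zeros(len(Qi))              # running row sum of exponentials
        for j in range(0, N, kv_block):
            Kj, Vj = K[j:j + kv_block], V[j:j + kv_block]
            Sij = (Qi @ Kj.T) * scale       # one score tile
            mij = Sij.max(axis=1)
            Pij = np.exp(Sij - mij[:, None])
            lij = Pij.sum(axis=1)
            mi_new = np.maximum(mi, mij)
            alpha = np.exp(mi - mi_new)     # correction for everything seen so far
            beta = np.exp(mij - mi_new)     # correction for this tile
            li_new = alpha * li + beta * lij
            Oi = ((alpha * li)[:, None] * Oi + beta[:, None] * (Pij @ Vj)) / li_new[:, None]
            mi, li = mi_new, li_new
        O[i:i + q_block] = Oi               # single write of the finished rows
    return O

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))
S = Q @ K.T / np.sqrt(16)
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(np.allclose(flash_attention_forward(Q, K, V), P @ V))  # True
```

N = 100 is deliberately not a multiple of the block size, so the last tiles are partial, which the slicing handles without special cases.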

The critical property: the output is mathematically identical to vanilla attention, because softmax over the full sequence is exactly reconstructed from the block-level statistics. In FP16/BF16 the two agree up to floating-point rounding, since the summation order differs. The algorithm does not approximate; it rearranges.

Flash Attention 1 → 2 → 3

| Feature | Vanilla | Flash Attn v1 | Flash Attn v2 | Flash Attn v3 |
|---|---|---|---|---|
| Paper | N/A | Dao et al., 2022 (NeurIPS) | Dao, 2023 | Shah et al., 2024 |
| GPU target | Any | Turing, Ampere (A100) | Ampere, Ada, Hopper | Hopper (H100/H200) |
| HBM accesses (asymptotic) | Θ(Nd + N²) | Θ(N²d²/M) | Θ(N²d²/M) | Θ(N²d²/M) |
| Forward speed vs vanilla (approx.) | 1× (baseline) | 2–3× | 3–4× | 4–6× |
| Backward speed vs vanilla (approx.) | 1× (baseline) | 2–3× | 4–5× | 6–8× |
| Data types | FP32/FP16/BF16 | FP16/BF16 (BF16 on Ampere) | FP16/BF16 | FP16/BF16/FP8 |
| Core technique | None | Tiling + recomputation | Better parallelism and work partitioning | Async WGMMA/TMA + FP8 |
| CUDA features used | Standard | MMA (Tensor Cores) | MMA + better occupancy | WGMMA + TMA async copy |
| Open source | N/A | ✓ (Dao-AILab) | ✓ (Dao-AILab) | ✓ (Dao-AILab) |

Flash Attention v1 (NeurIPS 2022, the paper that started it): Introduced the tiling scheme and proved the IO-complexity result: Θ(N²d²/M) HBM accesses for an SRAM of size M, versus Θ(Nd + N²) for standard attention, a large reduction because for typical d (64–128) d² is many times smaller than M. It also showed that the algorithm computes exact attention, not an approximation. The forward pass was 2–3× faster than a standard PyTorch attention implementation on A100s. The backward pass uses the same tiling and recomputes S and P block by block in SRAM from Q, K, V and the saved softmax statistics, instead of storing the N×N matrices from the forward pass.
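A small NumPy sketch of that recomputation trick. It assumes the forward pass keeps one scalar per query row, the logsumexp L = m + log ℓ (which is what the FA2 code saves for backward). Any tile of P can then be rebuilt as exp(S_ij - L_i) without touching the rest of the row. The helper names are hypothetical, and the forward here is unfused for brevity.

```python
import numpy as np

def forward_with_lse(Q, K, V):
    """Forward pass that keeps only O and one statistic per query row:
    the logsumexp L = m + log(l) of that row's scores (unfused here for brevity)."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    L = m + np.log(np.exp(S - m[:, None]).sum(axis=1))
    return np.exp(S - L[:, None]) @ V, L

def recompute_p_tile(Qi, Kj, Li):
    """Rebuild one tile of P from Q, K and the saved L, as a backward pass would."""
    Sij = Qi @ Kj.T / np.sqrt(Qi.shape[1])
    return np.exp(Sij - Li[:, None])

rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal((64, 16)) for _ in range(3))
O, L = forward_with_lse(Q, K, V)
P_tile = recompute_p_tile(Q[:16], K[16:48], L[:16])   # rows 0-15, columns 16-47 of P
```

Storing L costs O(N) memory instead of the O(N²) needed to keep P, at the price of recomputing one matmul and one exp per tile.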

Flash Attention v2 (Dao, 2023): Redesigned the work partitioning. v1 parallelizes over batch and heads, and each thread block loops over K/V blocks in the outer loop and Q blocks in the inner loop. v2 swaps the loops so each Q block is processed independently, which lets it also parallelize over the sequence dimension and keeps the GPU busy when batch × heads is small. It additionally cuts non-matmul FLOPs (for example, by deferring the final softmax normalization) and splits work between warps so that they exchange less data through shared memory. The paper reports roughly 2× speedup over v1, and v2 is the kernel behind PyTorch's SDPA flash backend (since PyTorch 2.2) and Hugging Face Transformers' attn_implementation="flash_attention_2" option.
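Why parallelizing over Q blocks is safe: each block of query rows needs all of K and V but nothing from the other query rows, so blocks can go to separate workers and simply be concatenated. A NumPy sketch in which a thread pool stands in for the GPU's thread blocks; the helper names are hypothetical, and the inner K/V tiling is omitted for brevity.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def attention_rows(Q_block, K, V):
    """Exact attention output for a block of query rows; uses no other query rows."""
    S = Q_block @ K.T / np.sqrt(Q_block.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def attention_split_over_q(Q, K, V, rows_per_block=32, workers=4):
    """Give each Q block to its own worker, the way FA2 gives each Q block its own thread block."""
    blocks = [Q[i:i + rows_per_block] for i in range(0, len(Q), rows_per_block)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        outputs = list(pool.map(lambda qb: attention_rows(qb, K, V), blocks))
    return np.concatenate(outputs, axis=0)   # map() preserves block order

rng = np.random.default_rng(4)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))
print(np.allclose(attention_split_over_q(Q, K, V), attention_rows(Q, K, V)))  # True
```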

Flash Attention v3 (Shah et al., 2024, Hopper-specific): Uses the H100's WGMMA (warp-group matrix multiply-accumulate) instructions and the TMA (Tensor Memory Accelerator) for asynchronous copies. Producer warps fetch the next K and V tiles from HBM into shared memory while consumer warps compute attention on the current ones, and the softmax of one warpgroup is overlapped with the matrix multiplies of another. The FP8 path targets the H100's FP8 Tensor Cores, which have twice the dense FP16 throughput (1.97 PFLOPS vs 989 TFLOPS), and limits quantization error with per-block scaling and incoherent processing (a randomized Hadamard transform applied to Q and K). The paper reports 1.5–2.0× speedup over FA2 in FP16/BF16, which puts v3 at roughly 4–6× over vanilla attention and makes it the natural choice on Hopper GPUs for long sequences.
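The idea behind incoherent processing fits in a few lines. Multiplying Q and K by the same orthogonal matrix leaves Q K^T unchanged, but spreads an outlier channel across all dimensions, which lowers the largest magnitude that sets the FP8 scale. A NumPy illustration with a dense random orthogonal matrix standing in for the paper's randomized Hadamard transform; no FP8 rounding is simulated here.

```python
import numpy as np

rng = np.random.default_rng(3)
N, d = 128, 64
Q = rng.standard_normal((N, d))
K = rng.standard_normal((N, d))
Q[:, 0] += 100.0          # one outlier channel, the case that quantizes badly

# Random orthogonal matrix from the QR decomposition of a Gaussian matrix
H, _ = np.linalg.qr(rng.standard_normal((d, d)))
Q_rot, K_rot = Q @ H, K @ H

# Scores are unchanged because H @ H.T == I ...
print(np.allclose(Q_rot @ K_rot.T, Q @ K.T))  # True
# ... but the largest magnitude, which determines the quantization scale, shrinks
print(np.abs(Q).max(), np.abs(Q_rot).max())
```

A Hadamard transform achieves the same invariance in O(d log d) per row instead of a dense O(d²) matmul, which is why it is used in practice.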

Using it in practice

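Flash Attention 3 is included in the flash-attn PyPI package (v3.1.2 as of May 2026). Installation is a single line, though the package compiles CUDA extensions against your installed PyTorch, so the project README recommends disabling build isolation:

pip install flash-attn --no-build-isolation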

The API is straightforward once the package is installed. The main entry points are functions, not a module that auto-patches your model:

import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 32, 4096, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 32, 4096, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 32, 4096, 128, dtype=torch.bfloat16, device="cuda")

# (batch, heads, seqlen, headdim) → (batch, seqlen, heads, headdim)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()

out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
# out shape: (1, 4096, 32, 128) — same as input layout

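For most users, the easiest path is PyTorch's torch.nn.functional.scaled_dot_product_attention (SDPA), which bundles its own copy of the FlashAttention-2 kernel and dispatches to it automatically when the input dtype, layout, head dimension, and GPU support it. SDPA expects (batch, heads, seqlen, headdim) tensors, unlike flash_attn_func. To pin it to the flash backend, so that an unsupported input raises an error instead of silently falling back, use the torch.nn.attention.sdpa_kernel context manager (PyTorch 2.3+; the older torch.backends.cuda.sdp_kernel is deprecated):

from torch.nn.attention import SDPBackend, sdpa_kernel

torch.backends.cuda.enable_flash_sdp(True)  # on by default in PyTorch 2.x
# q, k, v from the previous example are (batch, seqlen, heads, headdim); SDPA wants heads first
qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(qh, kh, vh, is_causal=True)
# out shape: (1, 32, 4096, 128)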

The flash backend covers Ampere and newer GPUs (including A100 and H100) with FP16/BF16 inputs and head dimensions up to 256. For FP8, you need a Hopper GPU and the FA3 kernels called directly.

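Flash Attention also integrates with Hugging Face Transformers via attn_implementation="flash_attention_2" in from_pretrained, which uses the FA2 kernels from the flash-attn package and requires FP16 or BF16 weights:

import torch
from transformers import AutoModelForCausalLM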

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

This swaps in the Flash Attention implementation during model loading and is a common path in training and fine-tuning pipelines.

Common pitfalls

  • Head-dimension support varies by version. FA1 handles head dimensions up to 128, FA2 up to 256, and the FA3 beta shipped kernels for 64, 128, and 256. flash_attn_func raises an error for an unsupported shape, but SDPA silently falls back to another backend, so a model with an unusual head dimension can quietly lose the speedup; forcing the flash backend with sdpa_kernel surfaces the problem.
  • FP8 has higher numerical error on outlier-heavy models. Flash Attention 3's FP8 path uses per-block scaling and incoherent processing to limit quantization error, but attention inputs with large outlier values can still amplify the relative error. Compare the output distribution on a few samples before trusting FP8 for your use case.
  • Not all GPUs support all versions. FA1 runs on Turing and Ampere GPUs (not V100). FA2 runs on Ampere, Ada, and Hopper (use FA1 on Turing). FA3 requires Hopper (H100/H200); its SM 90 kernels will not load on Ada Lovelace or Ampere.
  • Gains shrink at short sequence lengths. At N < 512, the overhead of block iteration and SRAM management can make Flash Attention slower than a well-tuned unfused kernel. SDPA chooses a backend with its own heuristics, but if you call flash_attn_func directly at short context, benchmark first.
  • Dropout in attention is not free. FA never materializes the dropout mask: it saves the random-number generator's seed and offset, then regenerates the same mask block by block when it recomputes S and P in the backward pass, which costs extra compute. In practice, most modern LLMs don't use attention dropout, so this rarely matters.
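Storing a seed instead of a mask, as a NumPy sketch: each (i, j) score tile derives its own deterministic generator from one saved seed, so the backward pass can rebuild exactly the mask the forward pass used. The helper is hypothetical, and NumPy's seed-sequence mixing stands in for the Philox counter offsets the CUDA kernels actually use.

```python
import numpy as np

def dropout_keep_mask(seed, i, j, shape, p):
    """Hypothetical helper: derive the keep-mask of score tile (i, j) from one saved seed.
    Any pass that calls this with the same arguments gets the same mask back."""
    rng = np.random.default_rng([seed, i, j])   # independent stream per (seed, tile)
    return rng.random(shape) >= p               # True = keep, probability 1 - p

seed, p = 1234, 0.1
forward_mask = dropout_keep_mask(seed, 2, 5, (64, 64), p)    # used in the forward pass
backward_mask = dropout_keep_mask(seed, 2, 5, (64, 64), p)   # regenerated in the backward pass
print(np.array_equal(forward_mask, backward_mask))            # True
```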

When NOT to use it

Flash Attention is the wrong tool if:

  • Your workload isn't memory-bound. On very small batch sizes with short contexts, attention's HBM traffic is small enough that the memory system is not the bottleneck, and Flash Attention's per-block tiling overhead can regress performance at N < 512 on high-end GPUs.
  • You need FP32 attention for research or numerical experiments. The flash-attn kernels and PyTorch's flash backend accept only FP16/BF16 (plus FP8 in FA3), so FP32 inputs go to a different backend. Flash Attention is also exact in the mathematical sense, not bitwise: its summation order differs from the unfused computation, so FP16/BF16 results agree only up to rounding. For most LLM work this doesn't matter (BF16 is the training standard), but it's worth flagging.
  • Your model uses an attention variant the kernels don't cover. Beyond standard softmax attention with causal masking, FA2's flash_attn_func supports sliding-window (local) attention via window_size and ALiBi via alibi_slopes, and position encodings applied to Q and K beforehand (RoPE, xPos) compose fine. Linear attention, state-space models such as Mamba, and arbitrary custom masks or score modifications need other kernels.
  • You're on a production inference stack that already uses prefix caching. Flash Attention and prefix caching both sit in the attention layer, and they compose — but only if your serving engine (vLLM / SGLang) has implemented the combined kernel. As of v0.22, vLLM does not fuse FA3 with its prefix-caching kernel. You get one or the other, not both simultaneously (though this is a known work-in-progress).
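For checking a sliding-window model against a reference: per its docstring, flash_attn_func's window_size=(left, right) lets query i attend to keys in [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right], with -1 meaning unbounded. A NumPy sketch of the equivalent boolean mask (hypothetical helper, not part of the library) that can be plugged into a naive implementation.

```python
import numpy as np

def local_attention_mask(seqlen_q, seqlen_k, window_left=-1, window_right=-1):
    """True where query i may attend to key j under window_size=(left, right) semantics:
    keys in [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right], -1 = unbounded."""
    i = np.arange(seqlen_q)[:, None] + (seqlen_k - seqlen_q)  # query aligned to the end of K
    j = np.arange(seqlen_k)[None, :]
    offset = j - i                                             # key position relative to query
    mask = np.ones((seqlen_q, seqlen_k), dtype=bool)
    if window_left >= 0:
        mask &= offset >= -window_left
    if window_right >= 0:
        mask &= offset <= window_right
    return mask

# causal attention with a 2-token look-back window
print(local_attention_mask(5, 5, window_left=2, window_right=0).astype(int))
```

Causal attention is the special case window_right = 0 with an unbounded window_left.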

TL;DR

  • Flash Attention tiles the Q, K, V matrices into SRAM-sized blocks, computes softmax on each block, and merges the results using online statistics. The output is mathematically exact (equal to standard attention up to floating-point rounding in FP16/BF16), not approximate.
  • Original insight: standard attention is HBM-bandwidth-bound, not compute-bound. Cutting HBM accesses from Θ(Nd + N²) to Θ(N²d²/M), where M is the SRAM size, is where the speedup comes from.
  • v1 (NeurIPS 2022) proved the concept on A100s. v2 (2023) roughly doubled performance with better parallelism and work partitioning. v3 (2024) adds Hopper-specific asynchrony and FP8, reaching 4–6× vs vanilla on H100s.
  • Use it through PyTorch 2.x scaled_dot_product_attention (auto-dispatch) or Hugging Face attn_implementation="flash_attention_2" for the easiest path.
  • Skip it for sequences under 512 tokens, FP32 research, or unusual attention variants that don't use standard softmax.

Next post: Mixture of Experts (MoE) — what practitioners need to know about routing, load balancing, and the engineering decisions behind Mixtral and DeepSeek-V3.
