DEV Community

Tech_Nuggets


Flash Attention: what it does and why it matters


Your training job is paying for an A100 at $3/hour. The loss is going down, gradients are flowing, and the loss curve looks textbook. But profile the step time and look at what the GPU is actually doing, and you can find something alarming: the compute units sit idle 40-60% of the time. The bottleneck isn't arithmetic; it's memory bandwidth. The GPU's HBM (high-bandwidth memory, 1.5-2 TB/s on an A100) cannot feed data as fast as the compute units can consume it. In long-sequence transformer training and inference, one of the largest sources of that memory traffic is attention. In its naive form, attention writes the full N x N attention matrix to HBM and reads it back on every forward pass.

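Flash Attention exists to solve that one problem. It eliminates the redundant HBM traffic by fusing the whole attention computation into a single kernel that works on tiles held in the GPU's SRAM. SRAM is the fast on-chip memory: 192 KB per streaming multiprocessor, or roughly 20 MB across the 108 SMs of an A100. The result is up to a 2-4x end-to-end speedup on attention-bound workloads, with no model changes required. Flash Attention computes exact attention rather than an approximation, so its outputs match standard attention up to floating-point rounding.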

Why attention memory costs matter

A standard self-attention layer on a single head works with three matrices Q, K, V, each of shape (N, d) where N is the sequence length and d is the head dimension. The naive computation:

  1. Compute S = Q @ K^T / sqrt(d) -- shape (N, N)
  2. Compute P = softmax(S, dim=-1) -- shape (N, N)
  3. Compute O = P @ V -- shape (N, d)
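The three steps above can be written directly in NumPy. This is a minimal reference sketch (single head, no batching; the function name is illustrative). It deliberately materializes S and P so the (N, N) intermediates are visible, and it uses the standard 1/sqrt(d) scaling of scaled dot-product attention.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference single-head attention that materializes the full N x N matrices."""
    d = Q.shape[1]
    S = Q @ K.T / np.sqrt(d)                       # (N, N) scaled scores
    S_shifted = S - S.max(axis=1, keepdims=True)   # subtract row max for numerical safety
    P = np.exp(S_shifted)
    P /= P.sum(axis=1, keepdims=True)              # (N, N) row-wise softmax
    O = P @ V                                      # (N, d) output
    return O, S, P

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, S, P = naive_attention(Q, K, V)
print(S.shape, P.shape, O.shape)  # (256, 256) (256, 256) (256, 64)
```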

The critical cost is that S and P each have N x N entries. For a 4096-token sequence with d=128, that's about 16.8 million entries per head; the head dimension does not affect this count. At FP16 (2 bytes per entry), that's 32 MB per head. With 32 heads, the N x N matrices across all heads take 1 GB per layer for each sequence in the batch. That is far larger than the ~20 MB of SRAM on an A100. The standard implementation makes four passes over that 1 GB:

  1. It writes S to HBM.
  2. It reads S back to compute the softmax.
  3. It writes P to HBM.
  4. It reads P again for the multiplication by V.
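The sizes are easy to check with a short calculation. The helper name is illustrative, and the figures come out in binary units (MiB/GiB), which is what the 32 MB and 1 GB above correspond to.

```python
def score_matrix_bytes(seq_len, num_heads=1, bytes_per_elem=2):
    """Bytes needed to store an N x N score matrix for num_heads heads (FP16 = 2 bytes)."""
    return seq_len * seq_len * num_heads * bytes_per_elem

per_head = score_matrix_bytes(4096)                 # one head
all_heads = score_matrix_bytes(4096, num_heads=32)  # 32 heads

MiB, GiB = 2**20, 2**30
print(per_head / MiB, "MiB per head")       # 32.0 MiB per head
print(all_heads / GiB, "GiB for 32 heads")  # 1.0 GiB for 32 heads
```

The head dimension d never appears in the calculation, which is why a smaller head size does not shrink S or P.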

Flash Attention avoids materializing this N x N matrix entirely by tiling the softmax computation across blocks small enough to fit in SRAM.

What Flash Attention actually does

The core insight from Tri Dao and colleagues at Stanford (2022) was that attention is IO-bound, not compute-bound. The dominant cost is moving data between HBM and SRAM, not the arithmetic. On an A100, aggregate SRAM bandwidth (compute units to SRAM) is roughly 19-20 TB/s, while HBM bandwidth is 1.5-2 TB/s, roughly a 10x difference. If the computation can be structured so that its intermediates stay in SRAM, it wins.
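To make the bandwidth gap concrete, here is a back-of-the-envelope estimate using the numbers from the previous section. The 32-head, 4096-token example moves about 4 GiB of N x N data across HBM per layer (write S, read S, write P, read P). This is an idealized lower bound that assumes peak bandwidth and ignores latency. The SRAM line is only there to show the ratio: with Flash Attention, that N x N data never crosses HBM at all.

```python
def transfer_time_ms(num_bytes, bandwidth_bytes_per_s):
    """Ideal time to move num_bytes at peak bandwidth, in milliseconds."""
    return num_bytes / bandwidth_bytes_per_s * 1e3

GiB = 2**30
nxn_traffic = 4 * GiB  # write S, read S, write P, read P (32 heads, N = 4096, FP16)

hbm_ms = transfer_time_ms(nxn_traffic, 2e12)    # ~2 TB/s HBM on an A100
sram_ms = transfer_time_ms(nxn_traffic, 20e12)  # ~20 TB/s aggregate SRAM bandwidth
print(f"HBM: {hbm_ms:.2f} ms, SRAM: {sram_ms:.2f} ms")  # HBM: 2.15 ms, SRAM: 0.21 ms
```

That is about 2 ms of pure data movement per layer per sequence before any arithmetic happens.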

The mechanism is algorithmically straightforward:

  1. Block the Q, K, V matrices into tiles small enough to fit in SRAM.
  2. Compute a partial softmax for each block using the online softmax algorithm (a numerically safe softmax that can be updated incrementally).
  3. Accumulate partial results into the output, keeping the per-row running max and running sum (the rescaling statistics) in registers.
  4. Write the final output to HBM once, instead of writing and re-reading the N x N intermediates S and P.

This is a classic tiling technique applied to a problem where tiling is not obvious. Softmax is a global normalization: each row needs a denominator over the full row, so you cannot process tiles independently and simply sum the results. The key algorithmic ingredient is an online, numerically safe softmax. Each tile computes a local softmax against a running max, and the running output is rescaled as new tiles arrive, so the final result equals the full-row softmax.
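Stripped of the matrix products, the online softmax is just a running max and a running sum over chunks of one row. This minimal NumPy sketch processes a score row chunk by chunk. It checks that the streaming statistics match those computed over the full row. The function name and chunk size are illustrative.

```python
import numpy as np

def online_softmax_stats(row, chunk_size):
    """Stream over `row` in chunks, keeping only a running max m and running sum l."""
    m, l = -np.inf, 0.0
    for start in range(0, len(row), chunk_size):
        chunk = row[start:start + chunk_size]
        m_new = max(m, chunk.max())
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = l * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    return m, l

row = np.random.default_rng(0).standard_normal(1000) * 10
m, l = online_softmax_stats(row, chunk_size=128)

# Full-row reference: max and sum of exp(x - max) computed in one pass.
m_ref = row.max()
l_ref = np.exp(row - m_ref).sum()
softmax_streamed = np.exp(row - m) / l
print(np.isclose(m, m_ref), np.isclose(l, l_ref))  # True True
```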

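# Flash Attention forward pass for one block of queries (runnable NumPy sketch;
# NumPy cannot control SRAM placement, so this shows the math, not the memory layout)
import numpy as np

def flash_attention_block(Q_block, K, V, B_c):
    # Q_block: (B_r, d); K, V: (N, d)
    # B_r and B_c are tile sizes chosen so the tiles fit in SRAM
    B_r, d = Q_block.shape
    scale = 1.0 / np.sqrt(d)

    # Per-row running statistics
    m = np.full(B_r, -np.inf)   # running row-wise max for numerical stability
    l = np.zeros(B_r)           # running sum of exp(s - m) for normalization
    O = np.zeros((B_r, d))      # running, already-normalized output

    for start in range(0, K.shape[0], B_c):      # for each K, V tile
        K_tile, V_tile = K[start:start + B_c], V[start:start + B_c]
        S = (Q_block @ K_tile.T) * scale         # local attention scores (B_r, B_c)
        m_new = np.maximum(m, S.max(axis=1))     # update running max
        P = np.exp(S - m_new[:, None])           # unnormalized local probabilities
        l_new = np.exp(m - m_new) * l + P.sum(axis=1)
        # Rescale the previous output to the new max and sum, then add this tile
        O = (l * np.exp(m - m_new) / l_new)[:, None] * O + (P / l_new[:, None]) @ V_tile
        m, l = m_new, l_new

    return O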

Each query block is read from HBM once, and its output is written once. However, every query block has to stream through all of K and V, so K and V are re-read once per query block. What the algorithm never does is write the N x N matrices S and P to HBM. The FlashAttention paper's IO analysis puts standard attention at O(Nd + N^2) HBM accesses and Flash Attention at O(N^2 d^2 / M), where M is the SRAM size. For typical head dimensions (64-128) and M around 100 KB, d^2 is many times smaller than M, so Flash Attention makes many times fewer HBM accesses.
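Plugging numbers into those asymptotic bounds gives a feel for the gap. This sketch drops all constant factors, so only the rough ratio is meaningful. It also assumes M is counted in elements, using M = 100,000 as an illustrative stand-in for "around 100 KB".

```python
def hbm_accesses(N, d, M):
    """Asymptotic HBM access counts from the FlashAttention IO analysis, constants dropped."""
    standard = N * d + N * N        # O(N d + N^2)
    flash = N * N * d * d / M       # O(N^2 d^2 / M)
    return standard, flash

for d in (64, 128):
    standard, flash = hbm_accesses(N=4096, d=d, M=100_000)
    print(f"d={d}: standard/flash ~ {standard / flash:.1f}x")
```

With these assumptions, the ratio comes out around 25x for d=64 and around 6x for d=128. The advantage shrinks as d grows, because d appears squared in the Flash Attention bound.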

The following diagram shows how tiling avoids materializing the full attention matrix:

flowchart TB
    subgraph SRAM["GPU SRAM (192 KB per SM, ~20 MB total on A100)"]
        QB["Q tile<br/>(B_r x d)"]
        KB["K tile<br/>(B_c x d)"]
        VB["V tile<br/>(B_c x d)"]
        ST["Partial S = QB @ KB^T<br/>(B_r x B_c)"]
        OT["Partial O accumulator<br/>(B_r x d)"]
    end
    subgraph HBM["GPU HBM (~40-80 GB)"]
        QF["Full Q<br/>(N x d)"]
        KF["Full K<br/>(N x d)"]
        VF["Full V<br/>(N x d)"]
        OF["Full O<br/>(N x d)"]
    end

    QF -->|read once| QB
    KF -->|streamed tile by tile<br/>per Q block| KB
    VF -->|streamed tile by tile<br/>per Q block| VB
    QB --> ST
    KB --> ST
    ST -->|online softmax| OT
    VB -->|partial products| OT
    OT -->|write once| OF

    style SRAM fill:#1e293b,stroke:#38bdf8,color:#e2e8f0
    style HBM fill:#0f172a,stroke:#64748b,color:#94a3b8

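Each arrow from HBM to SRAM is a transfer over the comparatively slow HBM interface. The naive implementation pays for N x N-sized transfers per head (writing and re-reading S and P), so its traffic grows quadratically with sequence length. Flash Attention streams K and V through SRAM once per query block and writes each output block once. The N x N intermediates never leave the chip.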

Flash Attention v1 vs v2 vs v3

| Version | Year | Key improvements | Reported speedup vs naive | GPU focus |
|---|---|---|---|---|
| v1 | 2022 | Tiling + online softmax; never materializes the N x N matrix | 2x | A100 (Ampere) |
| v2 | 2023 | Reduced non-matmul ops, better parallelism, non-power-of-2 lengths supported | 2-3.5x | A100, H100 |
| v3 | 2024-2025 | WGMMA (warp-group matrix multiply-accumulate) for H100 Tensor Cores, async pipelining, FP8 support | 3-7x | H100 (Hopper) |

Flash Attention v2 cut down the non-matrix-multiply work around the matmuls: the elementwise rescaling, masking, and softmax bookkeeping. One example is rescaling the output accumulator only once at the end instead of after every tile. This matters because Tensor Cores only accelerate matrix multiplication. On an A100, FP16 matmul peaks at 312 TFLOPS, versus 19.5 TFLOPS for FP32 non-matmul arithmetic, so each non-matmul FLOP is roughly 16x more expensive. The v2 paper reported that a single forward pass on a 65M-parameter model went from 6.5ms (PyTorch standard) to 2.6ms (Flash Attention v2).
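The deferred-normalization idea is easy to show in isolation. Instead of dividing the output by the running sum after every tile, keep an unnormalized accumulator and rescale it only by exp(m_old - m_new). Then divide by the final sum once at the end. This NumPy sketch checks the result against a direct softmax. The function name and tile size are illustrative, and the sketch says nothing about how v2 actually schedules work across GPU warps.

```python
import numpy as np

def attention_deferred_norm(Q, K, V, block_c=64):
    """Tiled attention that normalizes the output once at the end (v2-style)."""
    n_q, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n_q, -np.inf)      # running row max
    l = np.zeros(n_q)              # running sum of exp(s - m)
    acc = np.zeros((n_q, d))       # unnormalized output accumulator
    for start in range(0, K.shape[0], block_c):
        S = Q @ K[start:start + block_c].T * scale
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)            # rescale factor for earlier tiles
        P = np.exp(S - m_new[:, None])       # unnormalized probabilities
        l = alpha * l + P.sum(axis=1)
        acc = alpha[:, None] * acc + P @ V[start:start + block_c]
        m = m_new
    return acc / l[:, None]                  # single normalization at the end

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((200, 32)) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P_ref = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P_ref / P_ref.sum(axis=1, keepdims=True)) @ V
print(np.allclose(attention_deferred_norm(Q, K, V), O_ref))  # True
```

Compared with a per-tile normalized update, the loop body drops a division over the (B_r, d) output on every tile. That division is exactly the kind of non-matmul work the v2 changes target.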

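Flash Attention v3, published in 2024, targets the H100's Hopper architecture and leans on its asynchronous hardware. WGMMA (warp-group matrix multiply-accumulate) issues Tensor Core matmuls asynchronously. The Tensor Memory Accelerator (TMA) performs HBM-to-shared-memory copies asynchronously. With warp specialization, some warps load the next tiles while others compute, and the softmax of one block can overlap with the matmuls of another. v3 also adds FP8 support, which halves the bytes per element relative to FP16 and roughly doubles Tensor Core throughput. The cost is some accuracy, which the paper mitigates with block quantization and incoherent processing.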

Where Flash Attention is used today

Flash Attention is integrated into virtually every major LLM framework. The most common path is PyTorch's scaled_dot_product_attention (SDPA), which has shipped a Flash Attention backend since PyTorch 2.0 (upgraded to Flash Attention v2 in PyTorch 2.2):

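import torch
import torch.nn.functional as F

# SDPA uses the Flash Attention backend only when its conditions are met, e.g.:
# - CUDA GPU (recent PyTorch builds need Ampere/sm80 or newer for this backend)
# - dtype is half precision (FP16 or BF16)
# - head_dim is a multiple of 8 (the upper limit depends on the PyTorch version)
# - no explicit attn_mask (causal masking via is_causal=True is supported)
# SDPA inputs are laid out as (batch, num_heads, seq_len, head_dim).
query, key, value = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))

attn_output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True
)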

You don't need to import flash_attn directly in most cases. SDPA dispatches automatically to the best available backend: Flash Attention if its conditions are met, otherwise the memory-efficient kernel, and finally the unfused math implementation.

For direct access, the flash-attn package on PyPI exposes the kernels as functions such as flash_attn_func:

pip install flash-attn --no-build-isolation

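The installer tries to fetch a prebuilt wheel matching your CUDA and PyTorch combination (PyPI wheels are available starting with v2.8.x). If no wheel exists for your configuration, it builds from source, which requires the CUDA toolkit. The build can take anywhere from a few minutes (with ninja on a machine with many cores) to a couple of hours (without ninja).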

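import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seq_len, num_heads, head_dim) in FP16/BF16 on a CUDA GPU
q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
scale = q.shape[-1] ** -0.5  # same as the default used when softmax_scale=None

output = flash_attn_func(
    q, k, v,
    dropout_p=0.0,
    softmax_scale=scale,
    causal=True
)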

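flash_attn_func exposes options that SDPA does not, such as sliding-window attention (window_size) and ALiBi slopes (alibi_slopes). Hugging Face transformers calls into these kernels directly when you set attn_implementation="flash_attention_2", and vLLM ships its own fork of them.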
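One practical trap when switching between the two APIs is tensor layout. SDPA takes (batch, num_heads, seq_len, head_dim), while flash_attn_func expects (batch, seq_len, num_heads, head_dim). A minimal sketch of the conversion follows; the helper names are hypothetical, and the code only moves axes, so it runs on CPU.

```python
import torch

def sdpa_to_flash_layout(x: torch.Tensor) -> torch.Tensor:
    """(batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim)."""
    return x.transpose(1, 2)

def flash_to_sdpa_layout(x: torch.Tensor) -> torch.Tensor:
    """(batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)."""
    return x.transpose(1, 2)

q_sdpa = torch.randn(2, 8, 1024, 64)      # (batch, heads, seq, head_dim)
q_flash = sdpa_to_flash_layout(q_sdpa)
print(tuple(q_flash.shape))               # (2, 1024, 8, 64)
```

transpose returns a view that leaves the head dimension contiguous (stride 1), so no copy is needed for that axis.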

Common pitfalls

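The is_causal / padding interaction. Batched sequences of different lengths need a padding mask on top of the causal mask, and combining the two is where mistakes happen. With right padding, is_causal=True alone already keeps every real token from attending to the pad positions that follow it; you only need to exclude the pad positions from the loss. Left padding breaks that shortcut, because real tokens can then attend to the pad positions before them. Passing an explicit full N x N attn_mask with -inf in the right places is correct, but SDPA will not use the Flash backend with an arbitrary mask. For variable lengths, flash-attn itself provides flash_attn_varlen_func, which takes packed sequences plus cumulative sequence lengths instead of padding.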
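flash_attn_varlen_func takes all sequences packed end to end. It also takes cumulative sequence lengths (cu_seqlens, an int32 array of length batch + 1) instead of a padding mask. Building those inputs from a padded batch is mostly bookkeeping. This NumPy sketch shows the conversion; the helper name is hypothetical, and a real call would pass CUDA tensors rather than NumPy arrays.

```python
import numpy as np

def pack_padded_batch(x, attention_mask):
    """Pack a padded batch into varlen-style inputs.

    x: (batch, max_len, num_heads, head_dim); attention_mask: (batch, max_len), 1 = real token.
    Returns packed (total_tokens, num_heads, head_dim), cu_seqlens (batch + 1,), max_seqlen.
    """
    mask = attention_mask.astype(bool)
    seqlens = mask.sum(axis=1).astype(np.int32)
    cu_seqlens = np.zeros(len(seqlens) + 1, dtype=np.int32)
    cu_seqlens[1:] = np.cumsum(seqlens)            # start offset of each sequence
    packed = x[mask]                               # drops padded positions, keeps order
    return packed, cu_seqlens, int(seqlens.max())

mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1],
                 [1, 1, 0, 0, 0]])
x = np.random.default_rng(0).standard_normal((3, 5, 4, 8))
packed, cu_seqlens, max_seqlen = pack_padded_batch(x, mask)
print(packed.shape, cu_seqlens.tolist(), max_seqlen)  # (10, 4, 8) [0, 3, 8, 10] 5
```

Because packing selects real tokens by mask, the same helper works for left-padded batches. Those are the batches where "pad and set causal=True" goes wrong.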

Head dimension limits. Flash Attention has historically constrained the head dimension. v1 required head_dim <= 128; v2 raised the limit to 256, and v3 also supports up to 256. If your model uses head_dim=64 or head_dim=96, you are fine. If you are experimenting with head_dim=512 (rare, but seen in some vision transformers), the Flash Attention kernels will not accept it, and SDPA will fall back to a slower backend.

CUDA graph compatibility. Flash Attention uses a variable amount of shared memory depending on the tile size, which can cause issues with CUDA graph capture. If you use torch.compile with mode="reduce-overhead" (which relies on CUDA graphs), test that the Flash Attention kernel does not prevent graph capture. v2.8.x has improved this, but the interaction is not guaranteed across all PyTorch versions.

AMD GPUs and non-CUDA backends. Flash Attention was written as CUDA kernels. AMD support exists: the flash-attn repository includes a ROCm build based on AMD's Composable Kernel library for MI200/MI300-class GPUs, as well as a Triton backend. However, performance characteristics and feature coverage differ from the CUDA version, so it is not a guaranteed drop-in replacement. If you are on AMD GPUs, benchmark before assuming parity.

Automatic fallback in SDPA can hide problems. SDPA silently falls back to another backend when the Flash Attention conditions are unmet, so different GPU types (or a stray attn_mask) can quietly give you different kernels. SDPA does not report which backend it picked. If you care about reproducible performance, pin the backend in your benchmarks and tests so that an unmet condition raises an error, or check the kernel names in a profiler trace.

When NOT to use it

Flash Attention is the wrong optimization if:

  • Your bottleneck is the MLP layers, not attention. For inference with batch size 1 and short sequences (under 512 tokens), attention is a small fraction of total time, and the MLP projections dominate. Optimizing attention might give you a 5-10% speedup instead of 2-4x. Profile first.

  • You are on CPU inference. Flash Attention requires a CUDA-capable GPU. CPUs use entirely different attention paths.

  • You need integer-only attention (e.g., quantized KV cache on CPU/edge devices). Flash Attention is implemented in CUDA and expects FP16/BF16 data (plus FP8 in v3). Quantized attention kernels (MatMul-free LLMs, etc.) use different algorithms.

  • You are training a small model for quick iteration. If your model takes 30 seconds per epoch, optimizing attention will not move the bottleneck. The overhead of importing and configuring Flash Attention (not large, but nonzero) is wasted effort.

  • Your sequence length is extremely long (100K+ tokens). Flash Attention keeps attention memory linear in sequence length, but compute still grows as N^2, and the activations for a single sequence may no longer fit on one GPU. Ring Attention, DeepSpeed Ulysses, and Striped Attention are better suited above 100K tokens. They shard the sequence across GPUs instead of tiling within a single GPU's SRAM, and they often still run Flash Attention kernels on each shard.
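The "profile first" advice is Amdahl's law. If attention is a fraction f of step time and you speed up only that part by a factor s, the whole step speeds up by 1 / ((1 - f) + f / s). A quick calculation, with illustrative fractions rather than measured ones:

```python
def end_to_end_speedup(attention_fraction, attention_speedup):
    """Amdahl's law: overall speedup when only the attention fraction gets faster."""
    f, s = attention_fraction, attention_speedup
    return 1.0 / ((1.0 - f) + f / s)

for f in (0.1, 0.5, 0.8):
    print(f"attention = {f:.0%} of step, 3x faster kernel -> {end_to_end_speedup(f, 3.0):.2f}x overall")
```

With attention at 10% of the step, even a 3x faster kernel buys only about 7% overall. That is the 5-10% regime described in the first bullet.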

TL;DR

  • Flash Attention tiles the Q, K, V matrices into blocks that fit in GPU SRAM, computing the softmax online without ever materializing the full N x N attention matrix in HBM.
  • v2.8.3.post1 is the current stable release (June 2026). v2 improved parallelism and removed length restrictions. v3 added H100-specific WGMMA instructions and FP8 support.
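  • Reported speedups are 2-4x on A100-class GPUs and 3-7x on H100, with no model architecture changes required. In FP16/BF16, the result is exact attention (it matches the standard computation up to floating-point rounding); v3's FP8 mode trades some accuracy for speed.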
  • You get it automatically through PyTorch F.scaled_dot_product_attention or directly via the flash_attn package.
  • Watch for head_dim limits (max 256 in v2/v3), CUDA graph compatibility, and the silent SDPA backend fallback that can hide performance regressions.
  • Do not use Flash Attention if your bottleneck is not attention, you are on CPU-only hardware, or you have extreme sequence lengths that require inter-GPU sharding. On AMD GPUs, benchmark before assuming parity with the CUDA kernels.

Next post: a practical comparison of sampling strategies -- temperature, top-p, top-k, min-p, and what actually produces better output quality in production systems.
