Table of contents
- Introduction to FlashAttention
- Vanilla self-attention mechanism without FlashAttention
- Conceptual Overview: Fused Tiled Attention
- FlashAttention by hand
- Walkthrough of the FlashAttention diagram
- Appendix A - How to derive activation matrices from weight matrices
- Appendix B - The Paper's pseudocode vs. this walkthrough
Introduction to FlashAttention
FlashAttention builds upon the principles of online softmax to create an efficient single-pass algorithm for the entire self-attention mechanism. Its key innovation is to compute the final output O (where O = softmax(QK^T) · V) directly, without ever forming the full attention matrix A. This is achieved by fusing the matrix multiplications (QK^T and A · V) and the online softmax calculation into a single GPU kernel.
By avoiding the need to write and read the large L x L attention matrix (where L represents the number of input tokens) to and from global memory (DRAM), FlashAttention significantly reduces memory access, which is often the primary bottleneck in attention calculations. This makes it substantially faster and more memory-efficient, especially for long sequences.
This explanation will walk through the FlashAttention algorithm by hand, using the same tiled approach from the online softmax example in my previous article. This walkthrough is based on "From Online Softmax to FlashAttention" by Ye Zihao (2023).
This article was written with the assistance of Google Gemini 2.5 Pro.
Vanilla self-attention mechanism without FlashAttention
In order to better appreciate the efficiency gains with FlashAttention, we first begin with a vanilla implementation of self-attention. Let's illustrate this with our familiar 1 x 6
example used in the previous article on online softmax. Imagine this vector represents the dot products of a single query vector q
with six key vectors k
.
- Dot Products (Logits): X = [1, 2, 3, 6, 2, 1]
We also need a corresponding V
matrix. For simplicity, let's assume V
has 6 rows (one for each key) and a dimension of 2.
- Value Matrix: V = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]] (row i is v_i)
The standard process is broken down into three distinct, sequential stages:
- Calculate Logits: Compute X = Q · K^T.
- Calculate Attention Scores: Compute A = softmax(X). This creates a dense attention score matrix.
- Calculate Final Output: Compute O = A · V.
Since the logits X
are already given, our walkthrough will start from stage 2, calculating the attention scores.
(Note: Q
, K
and V
here are activation matrices which are the result of multiplying the input embeddings by the weight matrices W_Q
, W_K
and W_V
. These activation matrices are different for every input sequence and are the actual inputs to the attention calculation kernel that FlashAttention replaces. Refer to Appendix A to understand how these activation matrices are derived from the weight matrices.)
Step 1: Calculate the Full Attention Score Matrix (A = softmax(X))
This step computes the softmax function over the entire logit vector X
. This is a multi-pass process, requiring a pass to find the maximum, a pass to compute the denominator, and a pass to compute the final probabilities. The key difference from FlashAttention is that we must compute and store this entire attention vector A
before we can even begin to use the V
matrix.
1a. Find the Global Maximum (m)
First, we scan the entire logit vector X to find its maximum value for numerical stability (the "safe softmax" trick).
m = max(1, 2, 3, 6, 2, 1) = 6
1b. Compute the Exponentials and the Denominator (d)
Next, we subtract the maximum from each logit, exponentiate the result, and sum them all up to get the denominator.
- Subtract max: X - m = [1-6, 2-6, 3-6, 6-6, 2-6, 1-6] = [-5, -4, -3, 0, -4, -5]
- Exponentiate: e^(X-m) = [0.0067, 0.0183, 0.0498, 1, 0.0183, 0.0067]
- Sum to get the denominator d: d = 0.0067 + 0.0183 + 0.0498 + 1 + 0.0183 + 0.0067 ≈ 1.0998
1c. Normalize to get the Final Attention Vector A
Finally, we divide the exponentiated values by the denominator d to get the final probabilities. The result is the complete attention score vector A.
A = [0.0067, 0.0183, 0.0498, 1, 0.0183, 0.0067] / 1.0998 ≈ [0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061]
At this point, the 1 x 6 vector A is fully computed and stored in memory.
Step 2: Multiply Attention Scores by the Value Matrix (O = A · V)
Now that we have the attention scores, we can perform the final step: a matrix multiplication between A
and V
. The attention scores in A
act as weights for the corresponding value vectors in V
. The output O
is the weighted sum of the value vectors.
The operation is O = A · V = Σ_i A_i · v_i.
Let's calculate the weighted sum:
O = 0.0061·[1, 1] + 0.0167·[2, 2] + 0.0453·[3, 3] + 0.9092·[4, 4] + 0.0167·[5, 5] + 0.0061·[6, 6]
Now we compute each term and then sum them up:
O = [0.0061, 0.0061] + [0.0333, 0.0333] + [0.1358, 0.1358] + [3.6367, 3.6367] + [0.0833, 0.0833] + [0.0368, 0.0368]
Summing the vectors component-wise:
O ≈ [3.932, 3.932]
This result, [3.932, 3.932], is the final output vector O, the weighted combination of all the value vectors.
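To make the three stages concrete, here is a minimal NumPy sketch of the vanilla path on the toy numbers above (the variable names are my own, not from any library). It materializes the full attention vector `A` before touching `V`, which is exactly the memory traffic FlashAttention is designed to avoid.

```python
import numpy as np

# Toy inputs from the walkthrough: one query's logits against six keys,
# and a 6 x 2 value matrix.
X = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])          # logits x_i = q . k_i
V = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0],
              [4.0, 4.0], [5.0, 5.0], [6.0, 6.0]])

# Stage 2: safe softmax over the whole logit vector (multi-pass).
m = X.max()                      # pass 1: global maximum
e = np.exp(X - m)                # pass 2: exponentials
d = e.sum()                      #         denominator
A = e / d                        # pass 3: the full attention vector is stored

# Stage 3: only now can we use V.
O = A @ V
print(A.round(4))                # [0.0061 0.0167 0.0453 0.9092 0.0167 0.0061]
print(O.round(3))                # [3.932 3.932]
```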
Comparison and Key Differences
- Intermediate Matrix: The standard method materialized the full 1 x 6 attention vector `A`. In a real-world scenario with a sequence length `L`, this would be a large L x L matrix. This is the main bottleneck. I will demonstrate how FlashAttention computes the final output without ever creating or storing this matrix.
- Memory Access: The standard method requires at least two major memory operations: writing the entire `A` matrix to memory (DRAM), and then reading it all back in to multiply with `V`. I will show how FlashAttention fuses these operations, keeping tiles of `Q`, `K`, and `V` in fast SRAM and avoiding the slow roundtrip to DRAM.
- Computation Flow: The process is strictly sequential. You cannot start the multiplication until the entire softmax calculation for `A` is complete. We will show how FlashAttention integrates these steps, updating a running output vector as it iterates through tiles of the `K` and `V` matrices.
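To put rough numbers on that bottleneck, here is a back-of-the-envelope sketch; the sequence length, dtype, block size, and head dimension are hypothetical values chosen only for illustration.

```python
# Hypothetical sizes: sequence length 4096, fp16 elements (2 bytes each).
L, bytes_per_elem = 4096, 2
attn_matrix_bytes = L * L * bytes_per_elem
print(attn_matrix_bytes / 2**20, "MiB")   # 32.0 MiB written to and read from DRAM

# A FlashAttention-style tile of K and V (block of 128 rows, head dim 64, fp16).
block, head_dim = 128, 64
tile_bytes = 2 * block * head_dim * bytes_per_elem
print(tile_bytes / 2**10, "KiB")          # 32.0 KiB kept in fast SRAM instead
```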
Conceptual Overview: Fused Tiled Attention
FlashAttention processes the input matrices Q
, K
, and V
in a tiled manner. For each row of the output matrix O
, it iterates through the corresponding rows of K
and V
in blocks. At each step, it calculates the attention scores for just that block, updates the running statistics (the maximum and the denominator, just like in online softmax), and immediately applies these scores to the corresponding block of V
to update a running output vector.
The core idea is to maintain three running statistics for each row of the output:
- `m_running`: The running maximum of the dot products (q · k).
- `d_running`: The running denominator of the softmax.
- `o_running`: The running output vector, which is a weighted sum of the `V` vectors, scaled by the current (and incomplete) softmax probabilities.
We will use the exact same input data as before with the 1 x 6
example. This vector represents the dot products of a single query vector q
with six key vectors k
.
- Dot Products (Logits): X = [1, 2, 3, 6, 2, 1]
We also need a corresponding V
matrix. For simplicity, let's assume V
has 6 rows (one for each key) and a dimension of 2.
- Value Matrix: V = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]] (row i is v_i)
We will process this with a tile size of 3, breaking X
and V
into two tiles:
- Tile 1: Logits [1, 2, 3] and Values v_1, v_2, v_3
- Tile 2: Logits [6, 2, 1] and Values v_4, v_5, v_6
The algorithm follows the logic of "Algorithm FlashAttention (Tiling)" on page 6 of "From Online Softmax to FlashAttention" by Ye Zihao (2023).
(Note: For intuitive clarity, this walkthrough maintains an un-normalized running output o_running
and performs a single normalization at the end. The paper's pseudocode maintains a normalized running output o'
at each step. The final result is mathematically identical. An explanation of why this is true is in Appendix B.)
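Before working through the numbers by hand, here is a minimal sketch of what one tile contributes to the running state. It follows the un-normalized bookkeeping described in the note above; the function name `update_tile` and the variable names are this walkthrough's, not the paper's.

```python
import numpy as np

def update_tile(m_running, d_running, o_running, logits_tile, v_tile):
    """Fold one tile of logits and value rows (NumPy arrays) into the running state.

    Keeps o_running *un-normalized*, as in this walkthrough; the single
    division by d_running happens only after the last tile.
    """
    m_new = max(m_running, logits_tile.max())   # new overall maximum
    scores = np.exp(logits_tile - m_new)        # un-normalized scores for this tile
    scale = np.exp(m_running - m_new)           # rescaling factor for the old state
    d_new = d_running * scale + scores.sum()
    o_new = o_running * scale + scores @ v_tile
    return m_new, d_new, o_new
```

Calling it once per tile and dividing `o_running` by `d_running` after the last call reproduces, step for step, the numbers computed by hand below.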
FlashAttention by hand
Before the main loop begins, the running statistics are initialized.
- `m_running = -∞`
- `d_running = 0`
- `o_running = [0, 0]` (a zero vector of the same dimension as a `v` vector)
Link back to the FlashAttention algorithm
These initial values correspond to the state before the for loop, where m_0 = -∞, d_0 = 0, and o'_0 = [0, 0].
Step 1: Process Tile 1 (i=1)
We now begin the first iteration of the for i ← 1, #tiles do
loop.
1a. Find the New Maximum
First, we calculate the dot products (logits) for this tile and find the maximum value.
- Logits for Tile 1: x_1 = [q·k_1, q·k_2, q·k_3] = [1, 2, 3]
- Local max of Tile 1: m_local = max(1, 2, 3) = 3
- New overall maximum: m_new = max(m_running, m_local) = max(-∞, 3) = 3
Link back to the FlashAttention algorithm
- The calculation of logits corresponds to the first line inside the loop: x_i ← Q[k,:] K^T[:, (i-1)b : ib]
- Finding the local max corresponds to the second line: m_local = max_j(x_i[j])
- Updating the running max corresponds to the third line: m_i ← max(m_{i-1}, m_local)
(Here, m_{i-1} is our `m_running` from before the step.)
1b. Calculate Local Denominator and Local Output
Next, we compute the un-normalized attention scores for this tile using m_new
, and then use them to calculate a local denominator d_local
and a local weighted output o_local
.
- Un-normalized scores: [e^(1-3), e^(2-3), e^(3-3)] = [0.1353, 0.3679, 1]
- Local denominator: d_local = 0.1353 + 0.3679 + 1 = 1.5032
- Local output (un-normalized weighted sum of V vectors): o_local = 0.1353·[1, 1] + 0.3679·[2, 2] + 1·[3, 3] = [3.8711, 3.8711]
Link back to the FlashAttention algorithm
- The calculation of d_local corresponds to the summation part of the denominator update rule: Σ_j e^(x_i[j] - m_i)
- The calculation of o_local corresponds to the numerator of the second term in the output update rule (before normalization): Σ_j e^(x_i[j] - m_i) · V[j + (i-1)b, :]
1c. Update Running Statistics
Now, we update our global running statistics.
- `m_old = -∞`, `d_old = 0`, `o_old = [0, 0]`
- `m_new = 3`
- `d_local = 1.5032`, `o_local = [3.8711, 3.8711]`

Update the denominator: d_running = d_old · e^(m_old - m_new) + d_local = 0 + 1.5032 = 1.5032
Update the output: o_running = o_old · e^(m_old - m_new) + o_local = [0, 0] + [3.8711, 3.8711] = [3.8711, 3.8711]

After processing the first tile, our statistics are: m_running = 3, d_running ≈ 1.5032, o_running ≈ [3.8711, 3.8711].
Link back to the FlashAttention algorithm
This entire step corresponds to the final two lines of the loop body for i=1
.
- Denominator update: d_i ← d_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i)
  Since d_0 = 0, the first term is zero, and d_1 becomes just the local sum, matching our result.
- Output update: o'_i ← o'_{i-1} · (d_{i-1} · e^(m_{i-1} - m_i)) / d_i + (Σ_j e^(x_i[j] - m_i) · V[j + (i-1)b, :]) / d_i
  Similarly, since d_0 and o'_0 are zero, the first term vanishes. Our un-normalized `o_running` is equivalent to o'_1 · d_1, which is simply the numerator of the second term, matching our calculation of o_local.
Step 2: Process Tile 2 (i=2)
We now proceed to the second and final iteration of the loop.
2a. Find the New Maximum
- Logits for Tile 2: x_2 = [q·k_4, q·k_5, q·k_6] = [6, 2, 1]
- Local max of Tile 2: m_local = max(6, 2, 1) = 6
- New overall maximum: m_new = max(m_running, m_local) = max(3, 6) = 6
Link back to the FlashAttention algorithm
This again maps to the first three lines of the loop body for i=2
. This time, the local max (6) is greater than m_1 = 3, so the new maximum is correctly found as m_2 = 6.
2b. Calculate Local Denominator and Local Output
We repeat the process for the second tile's data, using the new max, 6.
- Un-normalized scores: [e^(6-6), e^(2-6), e^(1-6)] = [1, 0.0183, 0.0067]
- Local denominator: d_local = 1 + 0.0183 + 0.0067 = 1.025
- Local output: o_local = 1·[4, 4] + 0.0183·[5, 5] + 0.0067·[6, 6] = [4.1317, 4.1317]
Link back to the FlashAttention algorithm
This again corresponds to the summation parts of the update rules for i=2
.
2c. Update Running Statistics
Now we perform the final update, including the crucial rescaling step.
- `m_old = 3`, `d_old ≈ 1.5032`, `o_old ≈ [3.8711, 3.8711]`
- `m_new = 6`
- `d_local ≈ 1.025`, `o_local ≈ [4.1317, 4.1317]`

First, update the denominator:
d_running = d_old · e^(m_old - m_new) + d_local = 1.5032 · e^(3-6) + 1.025 ≈ 0.0748 + 1.025 ≈ 1.0998
Next, update the output vector:
o_running = o_old · e^(m_old - m_new) + o_local = [3.8711, 3.8711] · e^(3-6) + [4.1317, 4.1317] ≈ [0.1927, 0.1927] + [4.1317, 4.1317] ≈ [4.3244, 4.3244]
Link back to the FlashAttention algorithm
This step corresponds to the full update rules for i=2
.
- Denominator update: d_2 ← d_1 · e^(m_1 - m_2) + Σ_j e^(x_2[j] - m_2)
  The term d_1 · e^(m_1 - m_2) is the crucial rescaling factor, which perfectly matches the d_old · e^(3-6) part of our calculation.
- Output update: o'_2 ← o'_1 · (d_1 · e^(m_1 - m_2)) / d_2 + (Σ_j e^(x_2[j] - m_2) · V[j + b, :]) / d_2
  Again, our un-normalized `o_running` is equivalent to o'_2 · d_2. If you multiply the paper's update rule by d_2, you get o'_2 · d_2 = o'_1 · d_1 · e^(m_1 - m_2) + Σ_j e^(x_2[j] - m_2) · V[j + b, :]. This exactly matches our formula: o_running = o_old · e^(m_old - m_new) + o_local.
Final Result
After the loop finishes, we have the final un-normalized output and the final denominator.
- `d_final ≈ 1.0998`
- `o_final ≈ [4.3244, 4.3244]`

The last step is to normalize the output vector by dividing by the final denominator:
O = o_final / d_final = [4.3244, 4.3244] / 1.0998 ≈ [3.932, 3.932]
This matches the output we obtained with the vanilla three-stage computation.
Link back to the FlashAttention algorithm
This final normalization step is implicitly the result of the algorithm. The final output of the loop, o'_N (here o'_2), is the correctly normalized output row. Our method simply defers this division to the very end for clarity. The final output vector in the algorithm is this final, normalized value.
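As a numerical sanity check, here is a short sketch that runs the same two-tile loop end to end on the walkthrough's data; the tile size of 3 and the names mirror this article rather than the paper's pseudocode.

```python
import numpy as np

X = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])            # logits for one query
V = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0],
              [4.0, 4.0], [5.0, 5.0], [6.0, 6.0]])
tile = 3

m_running, d_running, o_running = float("-inf"), 0.0, np.zeros(2)

for start in range(0, len(X), tile):                     # one iteration per tile
    x_tile, v_tile = X[start:start + tile], V[start:start + tile]
    m_new = max(m_running, x_tile.max())                 # new overall maximum
    scores = np.exp(x_tile - m_new)                      # un-normalized scores
    scale = np.exp(m_running - m_new)                    # rescale the old state
    d_running = d_running * scale + scores.sum()
    o_running = o_running * scale + scores @ v_tile      # un-normalized output
    m_running = m_new

print((o_running / d_running).round(3))                  # [3.932 3.932]
```

The printed result matches both the hand calculation above and the vanilla three-stage computation.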
Walkthrough of the FlashAttention diagram
Having completed the step-by-step walkthrough of the FlashAttention algorithm, we can now walk through the FlashAttention schematic from the original paper ["FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"](https://arxiv.org/abs/2205.14135) by Tri Dao et al. (2022).
High-Level Overview
The diagram shows a tiled computation where the goal is to compute one block of the final output matrix, softmax(QK^T)V
, at a time. The loops control which tiles (blocks) of the input matrices (Q
, K
, V
) are loaded from slow global memory (HBM) into fast local memory (SRAM) to perform a piece of the calculation.
Our walkthrough, where we computed a single output row vector, corresponds to one full pass of the "Inner Loop" in this diagram.
Mapping the Diagram Components to Our Walkthrough
Let's look at each element of the diagram and connect it to our example.
1. The Loops
- Inner Loop (Blue Arrows): This loop iterates over the queries (rows of `Q`) and the corresponding output rows. In our example, we only had a single query vector `q` that produced a single output row `o`. Therefore, our entire walkthrough represents a single iteration of the Inner Loop.
  - The `Q` matrix block being copied is our single query `q`.
  - The `Output to HBM` block at the bottom is our `o_running` vector.
- Outer Loop (Red Arrows): This loop iterates over the key-value pairs (columns of `K^T` and rows of `V`). This loop is the core of the FlashAttention mechanism and maps directly to the steps of our walkthrough.
  - Iteration 1 of the Outer Loop corresponds to our "Step 1: Process Tile 1".
  - Iteration 2 of the Outer Loop corresponds to our "Step 2: Process Tile 2".
2. The Memory Hierarchy (SRAM vs. HBM)
- HBM (High-Bandwidth Memory): This represents the GPU's main global memory (DRAM). This is where the full, large `Q`, `K`, `V`, and final `O` matrices reside.
- SRAM (Fast On-Chip Memory): This is the small but extremely fast "workbench" memory. The diagram shows that we only ever copy small blocks (orange squares) into SRAM to work on them. In our walkthrough, the SRAM would hold:
  - Our query vector `q`.
  - The current block of keys and values we are processing (e.g., k_1, k_2, k_3 and v_1, v_2, v_3 for Tile 1).
  - The running statistics: `m_running`, `d_running`, and `o_running`.
3. Putting it all Together: Tracing Our Walkthrough on the Diagram
Let's trace the flow for our single query q
.
Initialization:
- Before the loops start, the
Output to HBM
block (ouro_running
vector) is initialized to[0, 0]
. The running statisticsm_running = -∞
andd_running = 0
are initialized in SRAM.
Inner Loop Begins (One and only one iteration for our example):
- `Copy` from `Q`: Our query vector `q` is loaded from HBM into SRAM. It will stay in SRAM for the entire duration of the Outer Loop.
Outer Loop - Iteration 1 (Our "Process Tile 1"):
- `Copy` from `K^T` and `V`: The first block of `K^T` (corresponding to logits `1, 2, 3`) and the first block of `V` (`v_1, v_2, v_3`) are loaded from HBM into SRAM.
- `Compute Block on SRAM`: This is the central computation.
  - The dot product is calculated.
  - The local max, local denominator, and local output are computed.
  - The running statistics `m_running`, `d_running`, and `o_running` (which live in SRAM) are updated. After this step, `o_running` is [3.8711, 3.8711]. The `+` sign with the purple dotted arrow signifies this update step.
Outer Loop - Iteration 2 (Our "Process Tile 2"):
- `Copy` from `K^T` and `V`: The previous blocks of `K^T` and `V` are discarded. The second block (logits `6, 2, 1` and values `v_4, v_5, v_6`) is loaded into SRAM.
- `Compute Block on SRAM`: The computation is repeated.
  - A new, larger global maximum (`m_new = 6`) is found.
  - This triggers the crucial rescaling of the existing `d_running` and `o_running` vectors.
  - The local contributions are calculated and added. After this step, the final un-normalized `o_running` is [4.3244, 4.3244].
Outer Loop Finishes:
- The loop is complete. The final normalization is performed in SRAM (o_running / d_running = [4.3244, 4.3244] / 1.0998) to get the final output vector [3.932, 3.932].
Output to HBM:
- The final, correct output vector is written from SRAM back to its designated row in the main output matrix in HBM.
Inner Loop Finishes:
- If there are more rows in
Q
, the Inner Loop will continue to the next row and repeat the entire Outer Loop process.
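To connect the diagram's two loops to code, here is a hedged sketch that extends the single-query loop to a full `Q` matrix: the outer loop streams blocks of `K` and `V` (the red arrows), while every query row's running statistics are updated inside it (the work the blue-arrow loop distributes across rows). This is a plain NumPy simulation of the access pattern, not a real fused GPU kernel, and the function name is made up for this article.

```python
import numpy as np

def flash_attention_sketch(Q, K, V, block=3):
    """Tiled attention for all rows of Q without materializing softmax(QK^T)."""
    L, d_out = Q.shape[0], V.shape[1]
    O = np.zeros((L, d_out))                       # final output (lives in "HBM")
    m = np.full(L, -np.inf)                        # running max per query row
    d = np.zeros(L)                                # running denominator per row

    for j in range(0, K.shape[0], block):          # outer loop: blocks of K and V
        Kj, Vj = K[j:j + block], V[j:j + block]    # "copy block to SRAM"
        S = Q @ Kj.T                               # logits of this block for all rows
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)                  # rescale the old statistics
        P = np.exp(S - m_new[:, None])             # un-normalized scores
        d = d * scale + P.sum(axis=1)
        O = O * scale[:, None] + P @ Vj            # running un-normalized output
        m = m_new

    return O / d[:, None]                          # single normalization at the end
```

With `block` equal to the full sequence length this degenerates into the standard three-stage computation; smaller blocks trade more passes over `K` and `V` for a smaller SRAM working set.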
Appendix A - How to derive activation matrices from weight matrices
To keep the focus on the core concepts of FlashAttention, the walkthrough covered only the core attention calculation step: O = softmax(QK^T) · V. This is the specific operation that FlashAttention optimizes.
This appendix provides the "prequel" step that was omitted to derive the activation matrices.
The Two Sets of Q, K, V
It's crucial to distinguish between:
- The Weight Matrices (`W_Q`, `W_K`, `W_V`): These are the trainable parameters learned during the training of the Transformer model. They are part of a standard Linear layer. Their job is to project the input token embeddings into the query, key, and value spaces. They are the same for every input sequence.
- The Activation Matrices (`Q`, `K`, `V`): These are the intermediate representations or activations. They are the result of multiplying the input embeddings by the weight matrices. These matrices are different for every input sequence and are the actual inputs to the attention calculation kernel that FlashAttention replaces. They are not trainable parameters themselves.
Think of it like a recipe:
- The weight matrices (`W_Q`, `W_K`, `W_V`) are the instructions in the recipe book (fixed, learned).
- The activation matrices (`Q`, `K`, `V`) are the actual ingredients you've prepared for one specific meal (changes every time you cook).
FlashAttention's innovation is in how to efficiently combine the prepared ingredients (Q
, K
, V
), not in the initial preparation step itself.
The Omitted "Prequel" Step: Creating Q, K, and V
Here is the step that happens before our walkthrough begins.
Let's assume we have an input sequence of 6 tokens, and each token has an embedding dimension of 4. This is our input matrix X
(not to be confused with the logit vector X
from the walkthrough).
- Input Embeddings
X
(size 6x4):
Now, let's define our trainable weight matrices. Let's say the head dimension d
is 2. The weight matrices must project from the embedding dimension (4) to the head dimension (2). So, they will all be size 4x2.
- Weight Matrices
W_Q
,W_K
,W_V
(size 4x2, trainable):
The Q
, K
, and V
activation matrices are created with standard matrix multiplication: Q = X · W_Q, K = X · W_K, V = X · W_V.
The V
matrix we get from this calculation is precisely the V
matrix we used in the walkthrough: V = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]].
Connecting to the Logits
In our walkthrough, we focused on a single query vector q
attending to all the keys. This corresponds to taking the first row of the Q
matrix.
The logit vector X from the walkthrough would then be calculated as X = q · K^T.
This multiplication would result in the 1 x 6 vector we started with: X = [1, 2, 3, 6, 2, 1].
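Since the concrete embedding and weight values are not reproduced here, the sketch below uses randomly generated placeholder matrices purely to show the shapes and operations; only the shapes (6 x 4 embeddings, 4 x 2 weights) and the formulas Q = X·W_Q, K = X·W_K, V = X·W_V come from the text above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical values: 6 tokens with embedding dimension 4
# (shapes from the text, numbers invented for illustration only).
X_emb = rng.normal(size=(6, 4))

# Hypothetical trainable weight matrices projecting 4 -> 2 (the head dimension).
W_Q = rng.normal(size=(4, 2))
W_K = rng.normal(size=(4, 2))
W_V = rng.normal(size=(4, 2))

# The "prequel" step: plain GEMMs that produce the activation matrices.
Q = X_emb @ W_Q          # shape (6, 2)
K = X_emb @ W_K          # shape (6, 2)
V = X_emb @ W_V          # shape (6, 2)

# One query row attending to all keys gives a 1 x 6 logit vector,
# the starting point of the walkthrough.
q = Q[0]                 # first row of Q
logits = q @ K.T         # shape (6,)
print(Q.shape, K.shape, V.shape, logits.shape)
```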
Summary
So, the full, un-omitted process is:
- Linear Projections (The Omitted Prequel):
  - Start with input embeddings `X`.
  - Compute Q = X · W_Q, K = X · W_K, and V = X · W_V using the trainable weight matrices. This is done with standard, highly optimized matrix multiplication libraries (GEMM).
- FlashAttention Calculation (The Walkthrough):
  - Take the resulting `Q`, `K`, and `V` activation matrices as input.
  - Efficiently compute O = softmax(QK^T) · V in a single kernel without materializing the full attention matrix.
Appendix B - The Paper's pseudocode vs. this walkthrough
The goal of my walkthrough was to make the calculation as intuitive as possible to follow by hand. To do this, I slightly rearranged the math to keep the intermediate numbers as simple as possible, while still being mathematically equivalent to the paper's algorithm.
1. The Paper's "Algorithm FlashAttention (Tiling)"
The algorithm in the PDF maintains a normalized output vector o'
at every step. Let's look closely at the update rule for the output:
o'_i ← o'_{i-1} · (d_{i-1} · e^(m_{i-1} - m_i)) / d_i + (Σ_j e^(x_i[j] - m_i) · V[j + (i-1)b, :]) / d_i
Notice that both terms are divided by d_i, the new total denominator. This means that at the end of each iteration i, the vector o'_i is the correctly normalized attention output for all the data processed up to that tile.
2. My Walkthrough's Method
Doing this normalization (division) at every single step can be cumbersome for a manual example. It's easier to work with un-normalized sums and perform a single division at the very end.
So, this walkthrough maintains an un-normalized running output o_running
. Our update rule was:
o_running = o_old · e^(m_old - m_new) + o_local
This is mathematically equivalent to the numerator of the paper's formula. You can see this if you multiply the paper's entire update rule by d_i:
o'_i · d_i = o'_{i-1} · d_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j + (i-1)b, :]
If we define O_i as the paper's o'_i · d_i, then this equation is exactly the one used in the walkthrough:
- My o_running (after tile i) is O_i = o'_i · d_i
- My o_old · e^(m_old - m_new) is o'_{i-1} · d_{i-1} · e^(m_{i-1} - m_i) = O_{i-1} · e^(m_{i-1} - m_i)
- My o_local is Σ_j e^(x_i[j] - m_i) · V[j + (i-1)b, :]
The note was an attempt to explain this simplification: we chose to track the un-normalized numerator throughout the process for clarity and then perform the final division only once.
Why normalizing only at the last step works
The reason this works is that the update rules for the un-normalized numerator (o_running) and the denominator (d_running) are designed to maintain a consistent relationship. At every step i, the correctly normalized output is simply the ratio of the running numerator to the running denominator at that step. By deferring the division to the end, we arrive at the same final ratio. The proof by induction below demonstrates this formally.
The Two Methods
Let's formally define the two methods we are comparing. We will use the paper's notation where o'_i is the normalized output after tile i, and we'll introduce O_i (capital O) as our un-normalized running output numerator from the walkthrough.
Method 1: Normalize at Each Step (The Paper's Algorithm)
The state after tile i is defined by:
d_i = d_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i)
o'_i = (o'_{i-1} · d_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j, :]) / d_i
The final result is o'_N after the last tile N.
Method 2: Normalize Only at the End (The Walkthrough's Method)
The state after tile i is defined by an un-normalized numerator O_i and the same denominator d_i:
O_i = O_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j, :]
d_i = d_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i)
The final result is calculated as O_N / d_N at the very end.
The Proof of Equivalence
We want to prove that o'_N = O_N / d_N. We can prove this by induction, showing that the relationship o'_i = O_i / d_i holds true for every step i.
1. Base Case (i=1)
Let's check the first tile. Both methods start with m_0 = -∞, d_0 = 0, and o'_0 = O_0 = [0, 0].
- Method 1: o'_1 = (Σ_j e^(x_1[j] - m_1) · V[j, :]) / d_1 (the first term vanishes because d_0 = 0)
- Method 2: O_1 = Σ_j e^(x_1[j] - m_1) · V[j, :], so O_1 / d_1 = (Σ_j e^(x_1[j] - m_1) · V[j, :]) / d_1
Comparing the two, we see they are identical. The base case holds.
2. Inductive Hypothesis
Assume that the relationship is true for step i-1. That is, assume:
o'_{i-1} = O_{i-1} / d_{i-1}
3. Inductive Step
Now we must prove that the relationship holds for step i. Let's start with the formula for o'_i from Method 1 and show it equals O_i / d_i.
Start with the definition of o'_i:
o'_i = o'_{i-1} · (d_{i-1} · e^(m_{i-1} - m_i)) / d_i + (Σ_j e^(x_i[j] - m_i) · V[j, :]) / d_i
Let's combine the two fractions over the common denominator d_i:
o'_i = ((o'_{i-1} · d_{i-1}) · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j, :]) / d_i
Now, look at the term in the parentheses: o'_{i-1} · d_{i-1}. According to our Inductive Hypothesis, this is exactly equal to O_{i-1}. Let's substitute it in:
o'_i = (O_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j, :]) / d_i
Now, look at the entire numerator: O_{i-1} · e^(m_{i-1} - m_i) + Σ_j e^(x_i[j] - m_i) · V[j, :]. This is precisely the definition of O_i from Method 2.
So, we can substitute O_i for the numerator:
o'_i = O_i / d_i
This completes the proof. We have shown that if the relationship holds for step i-1
, it must hold for step i
. Since it holds for the base case i=1
, it holds for all steps.
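The induction can also be checked numerically. The sketch below runs both recurrences side by side on the toy data and asserts o'_i = O_i / d_i after every tile; the names follow this article, not the paper.

```python
import numpy as np

X = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
V = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0],
              [4.0, 4.0], [5.0, 5.0], [6.0, 6.0]])

m, d = float("-inf"), 0.0
o_norm = np.zeros(2)        # Method 1: kept normalized at every step
O_unnorm = np.zeros(2)      # Method 2: un-normalized numerator

for start in range(0, len(X), 3):
    x_t, v_t = X[start:start + 3], V[start:start + 3]
    m_new = max(m, x_t.max())
    p = np.exp(x_t - m_new)
    scale = np.exp(m - m_new)
    d_new = d * scale + p.sum()
    o_norm = o_norm * (d * scale / d_new) + (p @ v_t) / d_new   # normalize each step
    O_unnorm = O_unnorm * scale + p @ v_t                       # normalize at the end
    assert np.allclose(o_norm, O_unnorm / d_new)                # o'_i == O_i / d_i
    m, d = m_new, d_new

print(o_norm.round(3), (O_unnorm / d).round(3))   # both print [3.932 3.932]
```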
Intuitive Analogy: Calculating a Weighted Average
Think about calculating the final grade in a class. You have different assignments with different weights (scores).
Method 1 (Normalize at Each Step): After you get your first grade, you calculate your current average. When you get your second grade, you update your average based on the new grade and its weight relative to the first. You keep re-calculating your "running average" after every single assignment.
Method 2 (Normalize at the End): You collect all your points scored for each assignment (
score * weight
) in one column. You collect the total possible points (sum of all weights
) in another column. You do this for the whole semester. At the very end, you do one single division:(total points scored) / (total possible points)
.
Both methods give you the exact same final grade. The second method simply defers the division. We can apply the same principle in FlashAttention: accumulate the un-normalized numerator (o_running
) and the un-normalized denominator (d_running
) separately and effectively perform the division at the end.
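The grade analogy fits in a few lines of Python as well; the scores and weights below are invented for the example.

```python
scores = [0.8, 0.9, 0.7]           # hypothetical assignment scores
weights = [10, 30, 60]             # hypothetical assignment weights

# Method 1: re-normalize the running average after every assignment.
avg, total_w = 0.0, 0.0
for s, w in zip(scores, weights):
    total_w += w
    avg += (s * w - avg * w) / total_w   # fold in the new grade

# Method 2: accumulate numerator and denominator, divide once at the end.
num = sum(s * w for s, w in zip(scores, weights))
den = sum(weights)

print(round(avg, 4), round(num / den, 4))   # both 0.77
```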