Lewis Won

FlashAttention by hand

Introduction to FlashAttention

FlashAttention builds upon the principles of online softmax to create an efficient single-pass algorithm for the entire self-attention mechanism. Its key innovation is to compute the final output O (where O = \text{softmax}(QK^T)V) directly, without ever forming the full attention matrix A. This is achieved by fusing the matrix multiplications (QK^T and AV) and the online softmax calculation into a single GPU kernel.

By avoiding the need to write and read the large L \times L attention matrix (where L represents the number of input tokens) to and from global memory (DRAM), FlashAttention significantly reduces memory access, which is often the primary bottleneck in attention calculations. This makes it substantially faster and more memory-efficient, especially for long sequences.

This explanation will walk through the FlashAttention algorithm by hand, using the same tiled approach from the online softmax example in my previous article. This walkthrough is based on "From Online Softmax to FlashAttention" by Ye Zihao (2023).

This article was written with the assistance of Google Gemini 2.5 Pro.


Vanilla self-attention mechanism without FlashAttention

To better appreciate the efficiency gains of FlashAttention, we begin with a vanilla implementation of self-attention. Let's illustrate this with the familiar 1 x 6 example used in the previous article on online softmax. Imagine this vector represents the dot products of a single query vector q with six key vectors k.

  • Dot Products (Logits):
    X = qK^T = \begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}

We also need a corresponding V matrix. For simplicity, let's assume V has 6 rows (one for each key) and a dimension of 2.

  • Value Matrix:
    V=[v1v2v3v4v5v6]=[112233445566]V = \begin{bmatrix} v_1 \newline v_2 \newline v_3 \newline v_4 \newline v_5 \newline v_6 \end{bmatrix} = \begin{bmatrix} 1 & 1 \newline 2 & 2 \newline 3 & 3 \newline 4 & 4 \newline 5 & 5 \newline 6 & 6 \end{bmatrix}

The standard process is broken down into three distinct, sequential stages:

  1. Calculate Logits: Compute X = QK^T.
  2. Calculate Attention Scores: Compute A = \text{softmax}(X). This creates a dense attention score matrix.
  3. Calculate Final Output: Compute O = A \cdot V.

Since the logits X are already given, our walkthrough will start from Step 2.

(Note: Q, K and V here are activation matrices, which are the result of multiplying the input embeddings by the weight matrices W_Q, W_K and W_V. These activation matrices are different for every input sequence and are the actual inputs to the attention calculation kernel that FlashAttention replaces. Refer to Appendix A to understand how these activation matrices are derived from the weight matrices.)


Step 1: Calculate the Full Attention Score Matrix (A = softmax(X))

This step computes the softmax function over the entire logit vector X. This is a multi-pass process, requiring a pass to find the maximum, a pass to compute the denominator, and a pass to compute the final probabilities. The key difference from FlashAttention is that we must compute and store this entire attention vector A before we can even begin to use the V matrix.

1a. Find the Global Maximum (m)

First, we scan the entire logit vector X to find its maximum value for numerical stability (the "safe softmax" trick).

m = \max(\begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}) = 6

1b. Compute the Exponentials and the Denominator (d)

Next, we subtract the maximum from each logit, exponentiate the result, and sum them all up to get the denominator.

  • Subtract max:
    X - m = \begin{bmatrix} 1-6 & 2-6 & 3-6 & 6-6 & 2-6 & 1-6 \end{bmatrix} = \begin{bmatrix} -5 & -4 & -3 & 0 & -4 & -5 \end{bmatrix}
  • Exponentiate:
    e^{(X-m)} = \begin{bmatrix} e^{-5} & e^{-4} & e^{-3} & e^{0} & e^{-4} & e^{-5} \end{bmatrix} \approx \begin{bmatrix} 0.0067 & 0.0183 & 0.0498 & 1 & 0.0183 & 0.0067 \end{bmatrix}
  • Sum to get the denominator d:
    d = \sum e^{(X-m)} \approx 0.0067 + 0.0183 + 0.0498 + 1 + 0.0183 + 0.0067 = 1.0998

1c. Normalize to get the Final Attention Vector A

Finally, we divide the exponentiated values by the denominator d to get the final probabilities. The result is the complete attention score vector A.

\begin{aligned} A &= \frac{e^{(X-m)}}{d} \newline &\approx \frac{\begin{bmatrix} 0.0067 & 0.0183 & 0.0498 & 1 & 0.0183 & 0.0067 \end{bmatrix}}{1.0998} \newline &\approx \begin{bmatrix} 0.0061 & 0.0167 & 0.0453 & 0.9092 & 0.0167 & 0.0061 \end{bmatrix} \end{aligned}

At this point, the 1 x 6 vector A is fully computed and stored in memory.
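If you want to verify these numbers yourself, the multi-pass safe softmax is a few lines of NumPy. This is a sketch of the naive path only (not FlashAttention), and it reproduces the A vector above:

```python
import numpy as np

# Logits for a single query against six keys (from the walkthrough).
X = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])

m = X.max()            # 1a. global maximum for numerical stability (= 6.0)
e = np.exp(X - m)      # 1b. exponentials ...
d = e.sum()            #     ... and the denominator (≈ 1.0998)
A = e / d              # 1c. normalize to get the attention vector

print(A.round(4))      # [0.0061 0.0167 0.0453 0.9092 0.0167 0.0061]
```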


Step 2: Multiply Attention Scores by the Value Matrix (O = A \cdot V)

Now that we have the attention scores, we can perform the final step: a matrix multiplication between A and V. The attention scores in A act as weights for the corresponding value vectors in V. The output O is the weighted sum of the value vectors.

The operation is (1 \times 6) \cdot (6 \times 2) \rightarrow (1 \times 2).

O = A \cdot V \approx \begin{bmatrix} 0.0061 & 0.0167 & 0.0453 & 0.9092 & 0.0167 & 0.0061 \end{bmatrix} \cdot \begin{bmatrix} 1 & 1 \newline 2 & 2 \newline 3 & 3 \newline 4 & 4 \newline 5 & 5 \newline 6 & 6 \end{bmatrix}

Let's calculate the weighted sum:

\begin{aligned} O \approx & (0.0061 \cdot \begin{bmatrix}1 & 1\end{bmatrix}) + \newline & (0.0167 \cdot \begin{bmatrix}2 & 2\end{bmatrix}) + \newline & (0.0453 \cdot \begin{bmatrix}3 & 3\end{bmatrix}) + \newline & (0.9092 \cdot \begin{bmatrix}4 & 4\end{bmatrix}) + \newline & (0.0167 \cdot \begin{bmatrix}5 & 5\end{bmatrix}) + \newline & (0.0061 \cdot \begin{bmatrix}6 & 6\end{bmatrix}) \end{aligned}

Now we compute each term and then sum them up:

\begin{aligned} O \approx & \begin{bmatrix}0.0061 & 0.0061\end{bmatrix} + \newline & \begin{bmatrix}0.0334 & 0.0334\end{bmatrix} + \newline & \begin{bmatrix}0.1359 & 0.1359\end{bmatrix} + \newline & \begin{bmatrix}3.6368 & 3.6368\end{bmatrix} + \newline & \begin{bmatrix}0.0835 & 0.0835\end{bmatrix} + \newline & \begin{bmatrix}0.0366 & 0.0366\end{bmatrix} \end{aligned}

Summing the vectors component-wise:

O \approx \begin{bmatrix}3.9323 & 3.9323\end{bmatrix}

This gives the final output, O ≈ [3.932, 3.932]. FlashAttention must produce exactly this result, but without ever materializing A.
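The weighted sum is equally easy to check in NumPy (a sketch using the rounded A from above, so the last decimal place may wobble):

```python
import numpy as np

A = np.array([0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061])  # from Step 1
V = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]], dtype=float)

O = A @ V              # (1 x 6) · (6 x 2) -> (1 x 2) weighted sum of value rows
print(O.round(4))      # ≈ [3.9323 3.9323]
```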


Comparison and Key Differences

  • Intermediate Matrix: The standard method materialized the full 1 x 6 attention vector A. In a real-world scenario with a sequence length L, this would be a large L x L matrix. This is the main bottleneck. I will demonstrate how FlashAttention computes the final output without ever creating or storing this matrix.

  • Memory Access: The standard method requires at least two major memory operations: writing the entire A matrix to memory (DRAM), and then reading it all back in to multiply with V. I will show how FlashAttention fuses these operations, keeping tiles of Q, K, and V in fast SRAM and avoiding the slow roundtrip to DRAM.

  • Computation Flow: The process is strictly sequential. You cannot start the A \cdot V multiplication until the entire softmax calculation for A is complete. I will show how FlashAttention integrates these steps, updating a running output vector as it iterates through tiles of the K and V matrices.


Conceptual Overview: Fused Tiled Attention

FlashAttention processes the input matrices Q, K, and V in a tiled manner. For each row of the output matrix O, it iterates through the corresponding rows of K and V in blocks. At each step, it calculates the attention scores for just that block, updates the running statistics (the maximum and the denominator, just like in online softmax), and immediately applies these scores to the corresponding block of V to update a running output vector.

The core idea is to maintain three running statistics for each row of the output:

  1. m_running: The running maximum of the dot products (q · k).
  2. d_running: The running denominator of the softmax.
  3. o_running: The running output vector, which is a weighted sum of the V vectors, scaled by the current (and incomplete) softmax probabilities.

We will use the exact same input data as before with the 1 x 6 example. This vector represents the dot products of a single query vector q with six key vectors k.

  • Dot Products (Logits):
    X = qK^T = \begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}

We also need a corresponding V matrix. For simplicity, let's assume V has 6 rows (one for each key) and a dimension of 2.

  • Value Matrix:
    V=[v1v2v3v4v5v6]=[112233445566]V = \begin{bmatrix} v_1 \newline v_2 \newline v_3 \newline v_4 \newline v_5 \newline v_6 \end{bmatrix} = \begin{bmatrix} 1 & 1 \newline 2 & 2 \newline 3 & 3 \newline 4 & 4 \newline 5 & 5 \newline 6 & 6 \end{bmatrix}

We will process this with a tile size of 3, breaking X and V into two tiles:

  • T_1: Logits \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} and Values \begin{bmatrix} v_1 \newline v_2 \newline v_3 \end{bmatrix}
  • T_2: Logits \begin{bmatrix} 6 & 2 & 1 \end{bmatrix} and Values \begin{bmatrix} v_4 \newline v_5 \newline v_6 \end{bmatrix}

The algorithm follows the logic of "Algorithm FlashAttention (Tiling)" on page 6 of "From Online Softmax to FlashAttention" by Ye Zihao (2023).

(Figure: the "Algorithm FlashAttention (Tiling)" pseudocode from Ye Zihao, 2023.)

(Note: For intuitive clarity, this walkthrough maintains an un-normalized running output o_running and performs a single normalization at the end. The paper's pseudocode maintains a normalized running output o' at each step. The final result is mathematically identical. An explanation of why this is true is in Appendix B.)
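To make the conceptual picture concrete before the hand calculation, here is the whole fused pass as a minimal single-query NumPy sketch. It implements the walkthrough's un-normalized variant on precomputed logits; it is not the paper's exact pseudocode and certainly not a real GPU kernel, which would keep the tiles in SRAM:

```python
import numpy as np

def flash_attention_row(x, V, tile_size=3):
    """Single-query pass over precomputed logits x, one tile at a time.

    Maintains running (max, denominator, un-normalized output) statistics
    and performs a single normalization at the very end.
    """
    m_run = -np.inf                  # running maximum
    d_run = 0.0                      # running denominator
    o_run = np.zeros(V.shape[1])     # running un-normalized output

    for start in range(0, len(x), tile_size):
        x_tile = x[start:start + tile_size]
        V_tile = V[start:start + tile_size]

        m_new = max(m_run, x_tile.max())   # new overall maximum
        p = np.exp(x_tile - m_new)         # un-normalized scores for this tile
        scale = np.exp(m_run - m_new)      # rescaling factor for old statistics

        d_run = d_run * scale + p.sum()
        o_run = o_run * scale + p @ V_tile
        m_run = m_new

    return o_run / d_run                   # single final normalization

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
V = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]], dtype=float)
print(flash_attention_row(x, V).round(3))  # ≈ [3.932 3.932]
```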


FlashAttention by hand

Before the main loop begins, the running statistics are initialized.

  • m_running = -∞
  • d_running = 0
  • o_running = [0, 0] (a zero vector of the same dimension as a v vector)

Link back to the FlashAttention algorithm

These initial values correspond to the state before the for loop, where m_0 = -\infty, d^\prime_0 = 0, and o^\prime_0 = \vec{0}.


Step 1: Process Tile 1 (i=1)

We now begin the first iteration of the for i ← 1, #tiles do loop.

1a. Find the New Maximum

First, we calculate the dot products (logits) for this tile and find the maximum value.

  • Logits for Tile 1:
    x_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix}
  • Local max of Tile 1:
    m_{T_1} = \max(\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}) = 3
  • New overall maximum:
    m_{new} = \max(m_{running}, m_{T_1}) = \max(-\infty, 3) = 3

Link back to the FlashAttention algorithm

  • The calculation of logits corresponds to the first line inside the loop:
x_i \leftarrow Q[k,:] \, K^T[:, (i-1)b:ib]
  • Finding the local max corresponds to the second line:
m_i^{(\text{local})} = \max_{j=1..b}(x_i[j])
  • Updating the running max corresponds to the third line:
m_i \leftarrow \max(m_{i-1}, m_i^{(\text{local})})

(Here, m_{i-1} is our m_running from before the step.)

1b. Calculate Local Denominator and Local Output

Next, we compute the un-normalized attention scores for this tile using m_new, and then use them to calculate a local denominator d_local and a local weighted output o_local.

  • Un-normalized scores:
    P_1 = \begin{bmatrix} e^{1-3} \newline e^{2-3} \newline e^{3-3} \end{bmatrix} = \begin{bmatrix} e^{-2} \newline e^{-1} \newline e^{0} \end{bmatrix} \approx \begin{bmatrix} 0.1353 \newline 0.3679 \newline 1 \end{bmatrix}
  • Local denominator:
    d_{T_1} = \sum P_1 \approx 0.1353 + 0.3679 + 1 = 1.5032
  • Local output (un-normalized weighted sum of V vectors):
    \begin{aligned} o_{T_1} &= (0.1353 \cdot v_1) + (0.3679 \cdot v_2) + (1 \cdot v_3) \newline &\approx (0.1353 \cdot \begin{bmatrix}1 & 1\end{bmatrix}) + (0.3679 \cdot \begin{bmatrix}2 & 2\end{bmatrix}) + (1 \cdot \begin{bmatrix}3 & 3\end{bmatrix}) \newline &\approx \begin{bmatrix}0.1353 & 0.1353\end{bmatrix} + \begin{bmatrix}0.7358 & 0.7358\end{bmatrix} + \begin{bmatrix}3 & 3\end{bmatrix} \newline &= \begin{bmatrix}3.8711 & 3.8711\end{bmatrix} \end{aligned}

Link back to the FlashAttention algorithm

  • The calculation of d_{T_1} corresponds to the summation part of the denominator update rule:
\sum_{j=1}^b e^{x_i[j] - m_i}
  • The calculation of o_{T_1} corresponds to the numerator of the second term in the output update rule (before normalization):
\sum_{j=1}^b e^{x_i[j]-m_i} V[j + (i-1)b, :]

1c. Update Running Statistics

Now, we update our global running statistics.

\begin{aligned} d_{new} &= d_{old} \cdot e^{m_{old} - m_{new}} + d_{local} \newline o_{new} &= o_{old} \cdot e^{m_{old} - m_{new}} + o_{local} \end{aligned}
  • m_old = -∞, d_old = 0, o_old = [0, 0]
  • m_new = 3
  • d_local = 1.5032, o_local = [3.8711, 3.8711]
\begin{aligned} d_{running} &= (0 \cdot e^{-\infty - 3}) + 1.5032 = 1.5032 \newline o_{running} &= (\begin{bmatrix}0 & 0\end{bmatrix} \cdot e^{-\infty - 3}) + \begin{bmatrix}3.8711 & 3.8711\end{bmatrix} = \begin{bmatrix}3.8711 & 3.8711\end{bmatrix} \end{aligned}

After processing the first tile, our statistics are: m_running = 3, d_running ≈ 1.5032, o_running ≈ [3.8711, 3.8711].
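You can check the tile-1 numbers in a couple of lines (a sketch of this tile's local statistics only):

```python
import numpy as np

# Tile 1: local statistics with the new running maximum m_new = 3.
p1 = np.exp(np.array([1.0, 2.0, 3.0]) - 3.0)                 # ≈ [0.1353 0.3679 1.]
d_t1 = p1.sum()                                              # local denominator
o_t1 = p1 @ np.array([[1, 1], [2, 2], [3, 3]], dtype=float)  # local output
print(d_t1.round(4), o_t1.round(4))                          # 1.5032 [3.8711 3.8711]
```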

Link back to the FlashAttention algorithm

This entire step corresponds to the final two lines of the loop body for i=1.

  • Denominator update:
d^\prime_i \leftarrow d^\prime_{i-1}e^{m_{i-1}-m_i} + \sum_{j=1}^b e^{x_i[j]-m_i}
Since d^\prime_0 = 0, the first term is zero, and d^\prime_1 becomes just the local sum, matching our result.
  • Output update:
o^\prime_i \leftarrow o^\prime_{i-1}\frac{d^\prime_{i-1}e^{m_{i-1}-m_i}}{d^\prime_i} + \frac{\sum e^{x_i[j]-m_i}V[\dots]}{d^\prime_i}
Similarly, since o^\prime_0 and d^\prime_0 are zero, the first term vanishes. Our un-normalized `o_running` is equivalent to o^\prime_1 \cdot d^\prime_1, which is simply the numerator of the second term, matching our calculation of o_{T_1}.

Step 2: Process Tile 2 (i=2)

We now proceed to the second and final iteration of the loop.

2a. Find the New Maximum

  • Logits for Tile 2:
    x_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix}
  • Local max of Tile 2:
    m_{T_2} = \max(\begin{bmatrix} 6 & 2 & 1 \end{bmatrix}) = 6
  • New overall maximum:
    m_{new} = \max(m_{running}, m_{T_2}) = \max(3, 6) = 6

Link back to the FlashAttention algorithm

This again maps to the first three lines of the loop body for i=2. This time, m_{i-1} = m_1 = 3, so the new maximum is correctly found as \max(3, 6) = 6.

2b. Calculate Local Denominator and Local Output

We repeat the process for the second tile's data, using the new max, 6.

  • Un-normalized scores:
    P_2 = \begin{bmatrix} e^{6-6} \newline e^{2-6} \newline e^{1-6} \end{bmatrix} = \begin{bmatrix} e^{0} \newline e^{-4} \newline e^{-5} \end{bmatrix} \approx \begin{bmatrix} 1 \newline 0.0183 \newline 0.0067 \end{bmatrix}
  • Local denominator:
    d_{T_2} = \sum P_2 \approx 1 + 0.0183 + 0.0067 = 1.025
  • Local output:
    \begin{aligned} o_{T_2} &= (1 \cdot v_4) + (0.0183 \cdot v_5) + (0.0067 \cdot v_6) \newline &\approx (1 \cdot \begin{bmatrix}4 & 4\end{bmatrix}) + (0.0183 \cdot \begin{bmatrix}5 & 5\end{bmatrix}) + (0.0067 \cdot \begin{bmatrix}6 & 6\end{bmatrix}) \newline &\approx \begin{bmatrix}4 & 4\end{bmatrix} + \begin{bmatrix}0.0915 & 0.0915\end{bmatrix} + \begin{bmatrix}0.0402 & 0.0402\end{bmatrix} \newline &= \begin{bmatrix}4.1317 & 4.1317\end{bmatrix} \end{aligned}

Link back to the FlashAttention algorithm

This again corresponds to the summation parts of the update rules for i=2.

2c. Update Running Statistics

Now we perform the final update, including the crucial rescaling step.

  • m_old = 3, d_old ≈ 1.5032, o_old ≈ [3.8711, 3.8711]
  • m_new = 6
  • d_local ≈ 1.025, o_local ≈ [4.1317, 4.1317]

First, update the denominator:

\begin{aligned} d_{running} &= d_{old} \cdot e^{m_{old} - m_{new}} + d_{local} \newline &\approx (1.5032 \cdot e^{3 - 6}) + 1.025 \newline &\approx (1.5032 \cdot 0.04979) + 1.025 \newline &\approx 0.0748 + 1.025 = 1.0998 \end{aligned}

Next, update the output vector:

\begin{aligned} o_{running} &= o_{old} \cdot e^{m_{old} - m_{new}} + o_{local} \newline &\approx (\begin{bmatrix}3.8711 & 3.8711\end{bmatrix} \cdot e^{3 - 6}) + \begin{bmatrix}4.1317 & 4.1317\end{bmatrix} \newline &\approx (\begin{bmatrix}3.8711 & 3.8711\end{bmatrix} \cdot 0.04979) + \begin{bmatrix}4.1317 & 4.1317\end{bmatrix} \newline &\approx \begin{bmatrix}0.1927 & 0.1927\end{bmatrix} + \begin{bmatrix}4.1317 & 4.1317\end{bmatrix} \newline &= \begin{bmatrix}4.3244 & 4.3244\end{bmatrix} \end{aligned}
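The rescaling is the step most worth checking by hand. In NumPy, using the rounded intermediate values from above:

```python
import numpy as np

m_old, m_new = 3.0, 6.0
scale = np.exp(m_old - m_new)             # e^{-3} ≈ 0.0498, the rescaling factor

d_run = 1.5032 * scale + 1.025            # old denominator rescaled + local sum
o_run = np.array([3.8711, 3.8711]) * scale + np.array([4.1317, 4.1317])
print(round(d_run, 4), o_run.round(4))    # 1.0998 [4.3244 4.3244]
```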

Link back to the FlashAttention algorithm

This step corresponds to the full update rules for i=2.

  • Denominator update:
d^\prime_2 \leftarrow d^\prime_1 e^{m_1-m_2} + \sum_{j=1}^b e^{x_2[j]-m_2}
The term d^\prime_1 e^{m_1-m_2} is the crucial rescaling factor, which matches the (1.5032 \cdot e^{3 - 6}) part of our calculation exactly.
  • Output update:
o^\prime_2 \leftarrow o^\prime_1\frac{d^\prime_1e^{m_1-m_2}}{d^\prime_2} + \frac{\sum e^{x_2[j]-m_2}V[\dots]}{d^\prime_2}
Again, our un-normalized `o_running` is equivalent to o^\prime_2 \cdot d^\prime_2. If you multiply the paper's update rule by d^\prime_2, you get (o^\prime_1 d^\prime_1) e^{m_1-m_2} + \text{local sum}, which is exactly our formula o_{old} \cdot e^{m_{old} - m_{new}} + o_{local}.

Final Result

After the loop finishes, we have the final un-normalized output and the final denominator.

  • d_final ≈ 1.0998
  • o_final ≈ [4.3244, 4.3244]

The last step is to normalize the output vector by dividing by the final denominator.

\begin{aligned} O_{final} &= \frac{o_{final}}{d_{final}} \newline &\approx \frac{\begin{bmatrix}4.3244 & 4.3244\end{bmatrix}}{1.0998} \newline &\approx \begin{bmatrix}3.932 & 3.932\end{bmatrix} \end{aligned}

Link back to the FlashAttention algorithm

In the paper's algorithm, this final normalization is built into the update rule itself: the final output of the loop, o^\prime_{N/b}, is already the correctly normalized output row. Our method simply defers this division to the very end for clarity. The final output vector O[k, :] in the algorithm is this final, normalized value.

O[k,:] \leftarrow o^\prime_{N/b}
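As a final sanity check, the deferred-normalization result matches the naive three-stage computation from earlier, as it must, since FlashAttention computes exact attention (a sketch; the hard-coded 4.3244 and 1.0998 are the rounded values from the walkthrough):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
V = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]], dtype=float)

naive = (np.exp(x - x.max()) / np.exp(x - x.max()).sum()) @ V   # standard path
fused = np.array([4.3244, 4.3244]) / 1.0998                     # o_final / d_final
print(naive.round(3), fused.round(3))                           # both ≈ [3.932 3.932]
```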

Walkthrough of the FlashAttention diagram

Having completed the step-by-step walkthrough of the FlashAttention algorithm, we can now walk through the FlashAttention schematic from the original paper, ["FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"](https://arxiv.org/abs/2205.14135) by Tri Dao et al. (2022).

(Figure: the tiled attention diagram from the FlashAttention paper, showing the outer and inner loops over blocks of K, V and Q copied between HBM and SRAM.)

High-Level Overview

The diagram shows a tiled computation where the goal is to compute one block of the final output matrix, softmax(QK^T)V, at a time. The loops control which tiles (blocks) of the input matrices (Q, K, V) are loaded from slow global memory (HBM) into fast local memory (SRAM) to perform a piece of the calculation.

Our walkthrough, where we computed a single output row vector, corresponds to one full pass of the "Inner Loop" in this diagram.


Mapping the Diagram Components to Our Walkthrough

Let's look at each element of the diagram and connect it to our example.

1. The Loops

  • Inner Loop (Blue Arrows): This loop iterates over the queries (rows of Q) and the corresponding output rows. In our example, we only had a single query vector q that produced a single output row o. Therefore, our entire walkthrough represents a single iteration of the Inner Loop.

    • The Q matrix block being copied is our single query q.
    • The Output to HBM block at the bottom is our o_running vector \begin{bmatrix}o_1 & o_2\end{bmatrix}, which is being progressively built.
  • Outer Loop (Red Arrows): This loop iterates over the key-value pairs (columns of K^T and rows of V). This loop is the core of the FlashAttention mechanism and maps directly to the steps of our walkthrough.

    • Iteration 1 of the Outer Loop corresponds to our "Step 1: Process Tile 1".
    • Iteration 2 of the Outer Loop corresponds to our "Step 2: Process Tile 2".

2. The Memory Hierarchy (SRAM vs. HBM)

  • HBM (High-Bandwidth Memory): This represents the GPU's main global memory (DRAM). This is where the full, large Q, K, V, and final O matrices reside.
  • SRAM (Fast On-Chip Memory): This is the small but extremely fast "workbench" memory. The diagram shows that we only ever copy small blocks (orange squares) into SRAM to work on them. In our walkthrough, the SRAM would hold:
    • Our query vector q.
    • The current block of keys and values we are processing (e.g., k_1, k_2, k_3 and v_1, v_2, v_3 for Tile 1).
    • The running statistics: m_running, d_running, and o_running.

3. Putting it all Together: Tracing Our Walkthrough on the Diagram

Let's trace the flow for our single query q.

Initialization:

  • Before the loops start, the Output to HBM block (our o_running vector) is initialized to [0, 0]. The running statistics m_running = -∞ and d_running = 0 are initialized in SRAM.

Inner Loop Begins (One and only one iteration for our example):

  • Copy from Q: Our query vector q is loaded from HBM into SRAM. It will stay in SRAM for the entire duration of the Outer Loop.

Outer Loop - Iteration 1 (Our "Process Tile 1"):

  1. Copy from K^T and V: The first block of K^T (corresponding to logits 1, 2, 3) and the first block of V (v_1, v_2, v_3) are loaded from HBM into SRAM.
  2. Compute Block on SRAM: This is the central computation.
    • The dot product q \cdot K^T_{\text{block 1}} is calculated.
    • The local max, local denominator, and local output are computed.
    • The running statistics m_running, d_running, and o_running (which live in SRAM) are updated. After this step, o_running is \approx [3.8711, 3.8711]. The + sign with the purple dotted arrow signifies this update step.

Outer Loop - Iteration 2 (Our "Process Tile 2"):

  1. Copy from K^T and V: The previous blocks of K^T and V are discarded. The second block (logits 6, 2, 1 and values v_4, v_5, v_6) is loaded into SRAM.
  2. Compute Block on SRAM: The computation is repeated.
    • A new, larger global maximum (m_new = 6) is found.
    • This triggers the crucial rescaling of the existing d_running and o_running vectors.
    • The local contributions are calculated and added. After this step, the final un-normalized o_running is \approx [4.3244, 4.3244].

Outer Loop Finishes:

  • The loop is complete. The final normalization is performed in SRAM (o_{\text{running}} / d_{\text{running}}) to get the final output vector \approx [3.932, 3.932].

Output to HBM:

  • The final, correct output vector is written from SRAM back to its designated row in the main output matrix in HBM.

Inner Loop Finishes:

  • If there are more rows in Q, the Inner Loop will continue to the next row and repeat the entire Outer Loop process, as sketched below.
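To make that last point concrete, here is a hypothetical multi-row sketch (the function name and structure are my own simplification; a real kernel also tiles Q and keeps every block in SRAM). Each output row is computed independently, which is why the Inner Loop can simply repeat the whole Outer Loop per query row:

```python
import numpy as np

def flash_attention(Q, K, V, tile_size=3):
    out = np.zeros((Q.shape[0], V.shape[1]))
    for k, q in enumerate(Q):                      # "Inner Loop": one output row per query
        m, d, o = -np.inf, 0.0, np.zeros(V.shape[1])
        for s in range(0, K.shape[0], tile_size):  # "Outer Loop": tiles of K and V
            x_tile = q @ K[s:s + tile_size].T      # logits for this tile
            m_new = max(m, x_tile.max())
            p = np.exp(x_tile - m_new)
            scale = np.exp(m - m_new)
            d = d * scale + p.sum()
            o = o * scale + p @ V[s:s + tile_size]
            m = m_new
        out[k] = o / d                             # normalize, then "write to HBM"
    return out
```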

Appendix A - How to derive activation matrices from weight matrices

To keep the focus on the core concepts of FlashAttention, the walkthrough covered only the core attention calculation step, \text{softmax}(QK^T)V. This is the specific operation that FlashAttention optimizes.

This appendix provides the omitted "prequel" step: how the activation matrices are derived.

The Two Sets of Q, K, V

It's crucial to distinguish between:

  1. The Weight Matrices (W_Q, W_K, W_V): These are the trainable parameters learned during the training of the Transformer model. They are part of a standard Linear layer. Their job is to project the input token embeddings into the query, key, and value spaces. They are the same for every input sequence.

  2. The Activation Matrices (Q, K, V): These are the intermediate representations or activations. They are the result of multiplying the input embeddings by the weight matrices. These matrices are different for every input sequence and are the actual inputs to the attention calculation kernel that FlashAttention replaces. They are not trainable parameters themselves.

Think of it like a recipe:

  • The weight matrices (W_Q, W_K, W_V) are the instructions in the recipe book (fixed, learned).
  • The activation matrices (Q, K, V) are the actual ingredients you've prepared for one specific meal (changes every time you cook).

FlashAttention's innovation is in how to efficiently combine the prepared ingredients (Q, K, V), not in the initial preparation step itself.

The Omitted "Prequel" Step: Creating Q, K, and V

Here is the step that happens before our walkthrough begins.

Let's assume we have an input sequence of 6 tokens, and each token has an embedding dimension of 4. This is our input matrix X (not to be confused with the logit vector X from the walkthrough).

  • Input Embeddings X (size 6x4):
    X = \begin{bmatrix} \text{embedding for token 1} \newline \text{embedding for token 2} \newline \vdots \newline \text{embedding for token 6} \end{bmatrix}

Now, let's define our trainable weight matrices. Let's say the head dimension d is 2. The weight matrices must project from the embedding dimension (4) to the head dimension (2). So, they will all be size 4x2.

  • Weight Matrices W_Q, W_K, W_V (size 4x2, trainable):
    W_Q = \begin{bmatrix} \dots \end{bmatrix}, \quad W_K = \begin{bmatrix} \dots \end{bmatrix}, \quad W_V = \begin{bmatrix} \dots \end{bmatrix}

The Q, K, and V activation matrices are created with standard matrix multiplication:

\begin{aligned} Q &= X \cdot W_Q \quad (\text{results in a } 6 \times 2 \text{ matrix}) \newline K &= X \cdot W_K \quad (\text{results in a } 6 \times 2 \text{ matrix}) \newline V &= X \cdot W_V \quad (\text{results in a } 6 \times 2 \text{ matrix}) \end{aligned}

The V matrix we get from this calculation is precisely the V matrix we used in the walkthrough:

V = \begin{bmatrix} 1 & 1 \newline 2 & 2 \newline 3 & 3 \newline 4 & 4 \newline 5 & 5 \newline 6 & 6 \end{bmatrix}

Connecting to the Logits

In our walkthrough, we focused on a single query vector q attending to all the keys. This corresponds to taking the first row of the Q matrix.

q = Q[0, :] \quad (\text{the first row, size } 1 \times 2)

The logit vector X from the walkthrough would then be calculated as:

\text{Logits} = q \cdot K^T

This multiplication would result in the 1 x 6 vector we started with:

\text{Logits} = \begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}
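A minimal NumPy sketch of this prequel step, with random placeholder weights (the real W_Q, W_K, W_V are learned parameters, so these values will not reproduce the walkthrough's exact numbers, only the shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
X_emb = rng.normal(size=(6, 4))                 # 6 tokens, embedding dimension 4
W_Q, W_K, W_V = (rng.normal(size=(4, 2)) for _ in range(3))  # head dimension 2

Q = X_emb @ W_Q                                 # (6, 2) activation matrices
K = X_emb @ W_K
V = X_emb @ W_V

q = Q[0, :]                                     # first query row
logits = q @ K.T                                # the walkthrough's logit vector X
print(Q.shape, K.shape, V.shape, logits.shape)  # (6, 2) (6, 2) (6, 2) (6,)
```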

Summary

So, the full, un-omitted process is:

  1. Linear Projections (The Omitted Prequel):

    • Start with input embeddings X.
    • Compute Q = X \cdot W_Q, K = X \cdot W_K, V = X \cdot W_V using the trainable weight matrices. This is done with standard, highly optimized matrix multiplication libraries (GEMM).
  2. FlashAttention Calculation (The Walkthrough):

    • Take the resulting Q, K, and V activation matrices as input.
    • Efficiently compute O = \text{softmax}(QK^T)V in a single kernel without materializing the full attention matrix.

Appendix B - The Paper's pseudocode vs. this walkthrough

The goal of my walkthrough was to make the calculation as intuitive as possible to follow by hand. To do this, I slightly rearranged the math to keep the intermediate numbers as simple as possible, while still being mathematically equivalent to the paper's algorithm.

1. The Paper's "Algorithm FlashAttention (Tiling)"

The algorithm in the PDF maintains a normalized output vector o' at every step. Let's look closely at the update rule for the output:

o^\prime_i \leftarrow o^\prime_{i-1}\frac{d^\prime_{i-1}e^{m_{i-1}-m_i}}{d^\prime_i} + \frac{\sum_{j=1}^b e^{x_i[j]-m_i}V[j + (i-1)b, :]}{d^\prime_i}

Notice that both terms are divided by d^\prime_i, the new total denominator. This means that at the end of each iteration i, the vector o^\prime_i is the correctly normalized attention output for all the data processed up to that tile.

2. My Walkthrough's Method

Doing this normalization (division) at every single step can be cumbersome for a manual example. It's easier to work with un-normalized sums and perform a single division at the very end.

So, this walkthrough maintains an un-normalized running output o_running. Our update rule was:

o_{\text{running}} = o_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + o_{\text{local}}

This is mathematically equivalent to the numerator of the paper's formula. You can see this if you multiply the paper's entire update rule by d^\prime_i:

\begin{aligned} o^\prime_i \cdot d^\prime_i &= \left( o^\prime_{i-1}\frac{d^\prime_{i-1}e^{m_{i-1}-m_i}}{d^\prime_i} + \frac{\sum \dots V[\dots]}{d^\prime_i} \right) \cdot d^\prime_i \newline o^\prime_i \cdot d^\prime_i &= (o^\prime_{i-1} d^\prime_{i-1}) e^{m_{i-1}-m_i} + \sum \dots V[\dots] \end{aligned}

If we define o_{\text{running}} as the paper's o^\prime_i \cdot d^\prime_i, then this equation is exactly the one used in the walkthrough:

  • My o_{\text{new}} is o^\prime_i \cdot d^\prime_i
  • My o_{\text{old}} is o^\prime_{i-1} \cdot d^\prime_{i-1}
  • My o_{\text{local}} is \sum \dots V[\dots]

The note in the walkthrough refers to this simplification: we track the un-normalized numerator throughout the process for clarity and then perform the final division o_{\text{final}} / d_{\text{final}} only once.

Why normalizing only at the last step works

The reason this works is that the update rules for the un-normalized numerator (o_running) and the denominator (d_running) are designed to maintain a consistent relationship. At every step i, the correctly normalized output is simply the ratio of the running numerator to the running denominator at that step. By deferring the division to the end, we arrive at the same final ratio. The proof by induction below demonstrates this formally.

The Two Methods

Let's formally define the two methods we are comparing. We will use the paper's notation, where o^\prime_i is the normalized output after tile i, and we'll introduce O_i (capital O) as our un-normalized running output numerator from the walkthrough.

Method 1: Normalize at Each Step (The Paper's Algorithm)
The state after tile i is defined by:

  1. d^\prime_i = d^\prime_{i-1} e^{m_{i-1}-m_i} + \left( \sum_{j \in \text{Tile}_i} e^{x_j-m_i} \right)
  2. o^\prime_i = o^\prime_{i-1}\frac{d^\prime_{i-1}e^{m_{i-1}-m_i}}{d^\prime_i} + \frac{\sum_{j \in \text{Tile}_i} e^{x_j-m_i}V_j}{d^\prime_i}

The final result is o^\prime_N after the last tile N.

Method 2: Normalize Only at the End (The Walkthrough's Method)
The state after tile i is defined by an un-normalized numerator O_i and the same denominator d^\prime_i:

  1. d^\prime_i = d^\prime_{i-1} e^{m_{i-1}-m_i} + \left( \sum_{j \in \text{Tile}_i} e^{x_j-m_i} \right)
  2. O_i = O_{i-1} e^{m_{i-1}-m_i} + \sum_{j \in \text{Tile}_i} e^{x_j-m_i}V_j

The final result is calculated as O_N / d^\prime_N at the very end.
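Before the formal proof, here is a quick numeric check on our running example that the two bookkeeping schemes agree (a sketch; both loops follow the update rules defined above):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
V = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]], dtype=float)

m, d = -np.inf, 0.0
o_norm = np.zeros(2)      # Method 1: normalized at every step (o'_i)
O_unnorm = np.zeros(2)    # Method 2: un-normalized numerator (O_i)

for s in range(0, 6, 3):  # two tiles of size 3
    x_t, V_t = x[s:s + 3], V[s:s + 3]
    m_new = max(m, x_t.max())
    p = np.exp(x_t - m_new)
    d_new = d * np.exp(m - m_new) + p.sum()
    o_norm = o_norm * (d * np.exp(m - m_new) / d_new) + (p @ V_t) / d_new
    O_unnorm = O_unnorm * np.exp(m - m_new) + p @ V_t
    m, d = m_new, d_new

print(np.allclose(o_norm, O_unnorm / d))   # True: both give ≈ [3.932 3.932]
```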

The Proof of Equivalence

We want to prove that o^\prime_N = O_N / d^\prime_N. We can prove this by induction, showing that the relationship o^\prime_i = O_i / d^\prime_i holds true for every step i.

1. Base Case (i=1)

Let's check the first tile. Both methods start with d^\prime_0 = 0, o^\prime_0 = \vec{0}, and O_0 = \vec{0}.

  • Method 1:

    o^\prime_1 = \underbrace{o^\prime_0 \cdot (\dots)}_{0} + \frac{\sum_{j \in T_1} e^{x_j-m_1}V_j}{d^\prime_1} = \frac{\sum_{j \in T_1} e^{x_j-m_1}V_j}{d^\prime_1}
  • Method 2:

    O_1 = \underbrace{O_0 \cdot (\dots)}_{0} + \sum_{j \in T_1} e^{x_j-m_1}V_j = \sum_{j \in T_1} e^{x_j-m_1}V_j
    The final result would be O_1 / d^\prime_1.

Comparing the two, we see they are identical. The base case holds.

2. Inductive Hypothesis

Assume that the relationship is true for step i-1. That is, assume:

o^\prime_{i-1} = \frac{O_{i-1}}{d^\prime_{i-1}} \quad \implies \quad O_{i-1} = o^\prime_{i-1} d^\prime_{i-1}

3. Inductive Step

Now we must prove that the relationship holds for step i. Let's start with the formula for o^\prime_i from Method 1 and show it equals O_i / d^\prime_i.

Start with the definition of o^\prime_i:

o^\prime_i = o^\prime_{i-1}\frac{d^\prime_{i-1}e^{m_{i-1}-m_i}}{d^\prime_i} + \frac{\sum_{j \in T_i} e^{x_j-m_i}V_j}{d^\prime_i}

Let's combine the two fractions over the common denominator d^\prime_i:

o^\prime_i = \frac{\left(o^\prime_{i-1} d^\prime_{i-1}\right) e^{m_{i-1}-m_i} + \sum_{j \in T_i} e^{x_j-m_i}V_j}{d^\prime_i}

Now, look at the term in the parentheses: (o^\prime_{i-1} d^\prime_{i-1}). According to our Inductive Hypothesis, this is exactly equal to O_{i-1}. Let's substitute it in:

o^\prime_i = \frac{O_{i-1} e^{m_{i-1}-m_i} + \sum_{j \in T_i} e^{x_j-m_i}V_j}{d^\prime_i}

Now, look at the entire numerator: O_{i-1} e^{m_{i-1}-m_i} + \sum_{j \in T_i} e^{x_j-m_i}V_j. This is precisely the definition of O_i from Method 2.

So, we can substitute O_i for the numerator:

o^\prime_i = \frac{O_i}{d^\prime_i}

This completes the proof. We have shown that if the relationship holds for step i-1, it must hold for step i. Since it holds for the base case i=1, it holds for all steps.

Intuitive Analogy: Calculating a Weighted Average

Think about calculating the final grade in a class. You have different assignments with different weights (scores).

  • Method 1 (Normalize at Each Step): After you get your first grade, you calculate your current average. When you get your second grade, you update your average based on the new grade and its weight relative to the first. You keep re-calculating your "running average" after every single assignment.

  • Method 2 (Normalize at the End): You collect all your points scored for each assignment (score * weight) in one column. You collect the total possible points (sum of all weights) in another column. You do this for the whole semester. At the very end, you do one single division: (total points scored) / (total possible points).

Both methods give you the exact same final grade. The second method simply defers the division. FlashAttention applies the same principle: accumulate the un-normalized numerator (o_running) and the denominator (d_running) separately, and perform the division once at the end.
