Table of Contents
- Why study online softmax? The key to flash attention
- Conceptual overview: tiled "safe" softmax
- Conceptual overview: online softmax
- Mathematical conversion from element-wise to tile-wise formula
- Appendix A: A simple example of numerical stability
- Appendix B: Proof of correctness of online softmax
Why study online softmax? The key to flash attention
Softmax normalization is critical in the self-attention mechanism of transformers because it helps to manage "extreme values and offers more favorable gradient properties during training" (source: Build a Large Language Model (From Scratch), chapter 3.3.1). For numerical stability, a "safe" softmax is widely used, where computing the final probability for any single element in a row of a matrix requires two "global" statistics over the entirety of that row:
a. The maximum value of the row, m = max(x).
b. The normalization denominator, d = Σ_j e^(x_j - m), which is the sum over all elements in the row.
(Note: Please refer to Appendix A for an illustration of the importance of the maximum value max(x) for numerical stability.)
With a naive implementation of "safe" softmax, the two global statistics can force a "multi-pass" approach, where the following needs to be computed sequentially for each row:
1. Read the entire row to find the max.
2. Read the row again to calculate the denominator.
3. Compute the softmax values.
With online softmax, steps 1 (find the max) and 2 (calculate the denominator) can be carried out within the same read of the matrix, instead of two sequential reads. Online softmax achieves this by calculating the necessary global statistics in a single pass over the data. This contrasts with a naive tiled approach that requires two passes to find the statistics, followed by a third pass for the final calculation. By reducing the number of memory passes from three to two, online softmax offers significant efficiency gains, especially since matrix operations in GPUs tend to be memory-bound.
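To make the three passes concrete, here is a minimal NumPy sketch of the element-wise "safe" softmax; the function name and looping structure are my own illustration of Algorithm 2, not code from the paper:

```python
import numpy as np

def safe_softmax_three_pass(x: np.ndarray) -> np.ndarray:
    """Element-wise 'safe' softmax: three sequential passes over x."""
    V = len(x)

    # Pass 1: find the global maximum (lines 1-4 of Algorithm 2).
    m = -np.inf
    for k in range(V):
        m = max(m, x[k])

    # Pass 2: accumulate the normalization denominator (lines 5-8).
    d = 0.0
    for j in range(V):
        d += np.exp(x[j] - m)

    # Pass 3: compute the final probabilities (lines 9-11).
    y = np.empty_like(x, dtype=np.float64)
    for i in range(V):
        y[i] = np.exp(x[i] - m) / d
    return y

print(safe_softmax_three_pass(np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])))
```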
This article provides an intuitive understanding of online softmax by demonstrating by hand how both the max value and denominator of the "safe" softmax can be calculated in a single pass. This walkthrough is based on the article Online normalizer calculation for softmax by Maxim Milakov and Natalia Gimelshein, 2018.
Online softmax also provides the basis for Flash Attention to fuse the entire attention calculation into a single GPU kernel, which I will discuss in my next article.
Conceptual overview: tiled "safe" softmax
The "safe" softmax is inherently a multi-pass algorithm. This is because calculating the final probability for any single element requires knowing two global properties of the entire input vector:
- The absolute maximum value, max(x), for numerical stability.
- The sum of all exponentiated values, which serves as the normalization denominator.
Tiling is a technique used to break down large matrices or vectors into smaller blocks, or "tiles". This is crucial for efficiency, especially on hardware like GPUs, as it allows smaller chunks of data to be loaded into fast, local memory (SRAM) for processing, rather than repeatedly accessing slower global memory (DRAM).
When applying tiling to the normal softmax, we process the input tile-by-tile. However, because we need the global maximum and the global sum, we cannot complete the calculation for any tile in a single go. This results in a multi-pass approach over the tiled data. (Note: The algorithm below describes an element-wise process; I will show its relationship to a tile-based approach.)
Let's illustrate this with a simple example. Consider the following 1x6 input vector (which can be thought of as a single row in a larger matrix):
X = [1, 2, 3, 6, 2, 1]
We will process this vector with a tile size of 3. This breaks our input X into two tiles:
Tile 1: X1 = [1, 2, 3]
Tile 2: X2 = [6, 2, 1]
The process follows the logic of "Algorithm 2: Safe softmax" from Milakov and Gimelshein's paper, but adapted for tiles.
Step 1: First Pass — Find the Global Maximum
In the first pass, we iterate through each tile to find its local maximum. Then, we find the maximum among all local maximums to determine the global maximum.
1a. Find Local Maximums
For each tile, we compute the maximum value, which we'll call m_local.
- For Tile 1 (X1 = [1, 2, 3]): m_local1 = max(1, 2, 3) = 3
- For Tile 2 (X2 = [6, 2, 1]): m_local2 = max(6, 2, 1) = 6
1b. Find Global Maximum
Now, we find the maximum of the local maximums to get the global maximum, m_global:
m_global = max(m_local1, m_local2) = max(3, 6) = 6
This value, m_global, corresponds to m_V in the provided "Algorithm 2".
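In NumPy terms, a minimal sketch of this first pass over our two tiles (variable names are mine):

```python
import numpy as np

x1, x2 = np.array([1.0, 2.0, 3.0]), np.array([6.0, 2.0, 1.0])  # Tile 1, Tile 2
m_local1, m_local2 = x1.max(), x2.max()  # local maxes: 3.0 and 6.0
m_global = max(m_local1, m_local2)       # global max: 6.0 (m_V in Algorithm 2)
```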
Link back to the safe softmax algorithm
Line 1: m_0 ← -∞
- Explanation: This initializes the variable that will hold the maximum value. It's set to negative infinity to ensure the first element of the vector x_1 becomes the first maximum.
- In the Example: Before "Pass 1," we start with a conceptual maximum of -∞.
Lines 2-4: for k ← 1, V do ... m_k ← max(m_{k-1}, x_k)
- Explanation: This is the first pass. The loop iterates through every element of the input vector X to find the single largest value. After the loop finishes, m_V holds the global maximum.
- In the Example: This corresponds to our "Step 1: First Pass".
  - We first found the local max for Tile 1 (m_local1 = 3) and Tile 2 (m_local2 = 6).
  - We then found the maximum of these, yielding the global max m_global = 6. This is the final value of m_V after this loop completes.
Step 2: Second Pass — Calculate the Global Denominator
In the second pass, we again iterate through each tile. This time, we use m_global to calculate a local sum of exponentials for each tile. These local sums are then aggregated to get the global denominator.
The formula for each element within a tile is e^(x_i - m_global).
2a. Calculate Local Denominators
- For Tile 1 (X1 = [1, 2, 3]): d_local1 = e^(1-6) + e^(2-6) + e^(3-6) ≈ 0.0067 + 0.0183 + 0.0498 ≈ 0.0748
- For Tile 2 (X2 = [6, 2, 1]): d_local2 = e^(6-6) + e^(2-6) + e^(1-6) ≈ 1 + 0.0183 + 0.0067 ≈ 1.0251
2b. Calculate Global Denominator
The global denominator, d_global, is the sum of all local denominators. This corresponds to d_V in the algorithm:
d_global = d_local1 + d_local2 ≈ 0.0748 + 1.0251 ≈ 1.0999
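The same step as a minimal NumPy sketch (self-contained, names are mine):

```python
import numpy as np

x1, x2 = np.array([1.0, 2.0, 3.0]), np.array([6.0, 2.0, 1.0])
m_global = 6.0  # global max from the first pass

d_local1 = np.exp(x1 - m_global).sum()  # e^-5 + e^-4 + e^-3 ≈ 0.0748
d_local2 = np.exp(x2 - m_global).sum()  # e^0  + e^-4 + e^-5 ≈ 1.0251
d_global = d_local1 + d_local2          # ≈ 1.0999 (d_V in Algorithm 2)
```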
Link back to the safe softmax algorithm
Line 5: d_0 ← 0
- Explanation: This initializes the variable for the denominator (the sum of exponentials).
- In the Example: Before "Pass 2," our sum is zero.
Lines 6-8: for j ← 1, V do ... d_j ← d_{j-1} + e^(x_j - m_V)
- Explanation: This is the second pass. The loop iterates through every element x_j again. For each element, it subtracts the global maximum m_V (found in the first pass), exponentiates the result, and adds it to the running sum. After the loop, d_V holds the global denominator.
- In the Example: This corresponds to our "Step 2: Second Pass".
  - For Tile 1, we calculated d_local1 ≈ 0.0748.
  - For Tile 2, we calculated d_local2 ≈ 1.0251.
  - The loop's final result, d_V, is the sum 0.0748 + 1.0251 ≈ 1.0999.
Step 3: Third Pass — Calculate Final Softmax Probabilities
In the final pass, we iterate through the tiles one last time. For each element x_i, we compute its exponentiated value (scaled by m_global) and then divide by d_global to get the final softmax probability y_i.
The formula is:
y_i = e^(x_i - m_global) / d_global
3a. Compute Softmax for Each Tile
- For Tile 1 (X1 = [1, 2, 3]):
  y_1 = e^(1-6) / 1.0999 ≈ 0.0061
  y_2 = e^(2-6) / 1.0999 ≈ 0.0167
  y_3 = e^(3-6) / 1.0999 ≈ 0.0453
- For Tile 2 (X2 = [6, 2, 1]):
  y_4 = e^(6-6) / 1.0999 ≈ 0.9092
  y_5 = e^(2-6) / 1.0999 ≈ 0.0167
  y_6 = e^(1-6) / 1.0999 ≈ 0.0061
Link back to the safe softmax algorithm
Lines 9-11: for i ← 1, V do ... y_i ← e^(x_i - m_V) / d_V
- Explanation: This is the third pass. This final loop computes the softmax probability y_i for each element. It re-calculates the exponentiated value for x_i (just as in the second pass) and then divides it by the global denominator d_V.
- In the Example: This corresponds to our "Step 3: Third Pass".
  - y_1 = e^(1-6) / 1.0999 ≈ 0.0061
  - y_4 = e^(6-6) / 1.0999 ≈ 0.9092
  - ...and so on for all six elements.
Final Result
Combining the results from all tiles gives the final softmax output vector Y:
Y ≈ [0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061]
The sum of these probabilities is approximately 1.0, as expected. This multi-pass process is the standard, "normal" way of computing softmax in a tiled fashion, and stands in contrast to the "online softmax" which is designed to reduce the number of passes over the data.
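Putting the whole multi-pass procedure together, here is a minimal NumPy sketch of the tiled "safe" softmax; the function and variable names are mine, chosen to mirror the walkthrough:

```python
import numpy as np

def tiled_safe_softmax(x: np.ndarray, tile_size: int = 3) -> np.ndarray:
    """Three-pass 'safe' softmax over tiles, mirroring the walkthrough above."""
    tiles = [x[i:i + tile_size] for i in range(0, len(x), tile_size)]

    # Pass 1: local maxes, then the global max (m_global / m_V).
    m_global = max(tile.max() for tile in tiles)

    # Pass 2: local denominators, then the global denominator (d_global / d_V).
    d_global = sum(np.exp(tile - m_global).sum() for tile in tiles)

    # Pass 3: final probabilities, tile by tile.
    return np.concatenate([np.exp(tile - m_global) / d_global for tile in tiles])

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
print(tiled_safe_softmax(x))  # ≈ [0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061]
```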
Conceptual overview: online softmax
The "Online Softmax," as described in the paper by Milakov and Gimelshein, is a clever single-pass algorithm. Its key innovation is the ability to calculate the softmax probabilities tile-by-tile without needing to first compute the global maximum and global sum over the entire input. It achieves this by maintaining a running maximum and a running denominator. When a new tile is processed, these running statistics are updated.
The core of the method lies in how the running denominator is updated when a new, larger maximum value is found. It uses a "telescoping sum" property to rescale the previous sum, effectively correcting it to be consistent with the new maximum. This removes the need for a separate second pass over the data just to compute the denominator, which is a significant advantage for memory-bound operations on hardware like GPUs. A proof that online softmax is correct is in Appendix B. Note that the original proof is in Online normalizer calculation for softmax, pages 3-4, but I instead referenced Yi Wang's FlashAttention (Part 2): Online Softmax as I found his writing to be more intuitive. (Note: The algorithm below describes an element-wise process; I will show its relationship to a tile-based approach.)
Let's use the same example vector and tile size:
X = [1, 2, 3, 6, 2, 1]
With a tile size of 3, we have two tiles:
Tile 1: X1 = [1, 2, 3]
Tile 2: X2 = [6, 2, 1]
The process will follow the logic of "Algorithm 3: Safe softmax with online normalizer calculation". We will process the vector tile by tile, from left to right, maintaining and updating two running values:
- m_running: The running maximum found so far.
- d_running: The running denominator (the sum of exponentials), scaled by the current m_running.
Initialization
Before processing the first tile, we initialize our running statistics as per the algorithm:
m_running = -∞
d_running = 0
These are represented in lines 1 and 2 of the online softmax algorithm.
Step 1: Process Tile 1 (X1 = [1, 2, 3])
We now process the first tile. For this tile, we will compute a new maximum (m_new) and a new denominator (d_new).
1a. Find the New Maximum for the Tile
First, find the maximum value within the current tile and compare it to our running maximum.
- Local max of X1: max(1, 2, 3) = 3
- The new overall maximum is: m_new = max(m_running, 3) = max(-∞, 3) = 3
1b. Update the Running Denominator
Next, we update the running denominator. The formula from Algorithm 3 (line 5), applied tile-wise, combines the previous denominator with the contribution from the current tile:
d_new = d_old × e^(m_old - m_new) + Σ_i e^(x_i - m_new)
Let's apply this:
d_new = 0 × e^(-∞ - 3) + (e^(1-3) + e^(2-3) + e^(3-3)) ≈ 0 + (0.1353 + 0.3679 + 1) ≈ 1.5032
1c. Update Running Statistics
After processing the first tile, our running statistics are:
m_running = 3
d_running ≈ 1.5032
At this point, we could calculate intermediate (and incorrect) softmax values for the first tile using these stats, but the key insight of the online method is to wait until all tiles are processed. The final values will need to be rescaled.
Step 2: Process Tile 2 (X2 = [6, 2, 1])
Now we move to the second tile and repeat the process, using the statistics from the previous step as our "old" values.
2a. Find the New Maximum
- Local max of X2: max(6, 2, 1) = 6
- The new overall maximum is: m_new = max(m_old, 6) = max(3, 6) = 6
We have found a new, larger maximum. This will trigger the rescaling part of the update formula.
2b. Update the Running Denominator
We use the same formula as before, but now our old values are the results from Tile 1.
d_new = d_old × e^(m_old - m_new) + Σ_i e^(x_i - m_new)
      = 1.5032 × e^(3-6) + (e^(6-6) + e^(2-6) + e^(1-6))
      ≈ 1.5032 × 0.0498 + (1 + 0.0183 + 0.0067)
      ≈ 0.0748 + 1.0251
      ≈ 1.0999
The first term, d_old × e^(m_old - m_new), is the crucial rescaling step. It takes the previous sum, which was calculated relative to the old maximum of 3, and correctly scales it to be relative to the new maximum of 6.
2c. Update Final Running Statistics
After processing all tiles, our final global statistics are:
m_final = 6
d_final ≈ 1.0999
These values match the global maximum and denominator we found in the multi-pass "normal" softmax example.
Link back to the online softmax algorithm
Lines 3-6: for j ← 1, V do ...
- Explanation: This is the main single pass of the algorithm. It iterates through the input vector element by element, updating the running max and denominator at each step.
- In the Example: We processed this tile-by-tile, which is a block-wise version of this element-wise loop.
Line 4: m_j ← max(m_{j-1}, x_j)
- Explanation: At each step j, this line updates the running maximum.
- In the Example:
  - When processing Tile 1 (X1 = [1, 2, 3]), the running max m_j becomes 1, then 2, then finally 3. After Tile 1, our example had m_running = 3.
  - When processing Tile 2 (X2 = [6, 2, 1]), the max becomes max(3, 6) = 6. Our example showed this change when moving from Tile 1 to Tile 2, where m_new became 6.
Line 5: d_j ← d_{j-1} × e^(m_{j-1} - m_j) + e^(x_j - m_j)
- Explanation: This is the core of the online update.
  - The second term, e^(x_j - m_j), adds the contribution of the current element, scaled by the new running max.
  - The first term, d_{j-1} × e^(m_{j-1} - m_j), is the crucial rescaling factor. If a new maximum is found (m_j > m_{j-1}), this term correctly scales down the entire previous sum to be relative to the new, larger maximum. If the maximum doesn't change (m_j = m_{j-1}), this term simply becomes d_{j-1} × e^0 = d_{j-1}, so we just keep adding to the sum.
- In the Example:
  - For Tile 1, the max was always updating, but the sum started at 0. The final sum for the tile was d_running ≈ 1.5032.
  - The most important step was processing Tile 2. Our running denominator was d_old ≈ 1.5032 and m_old = 3. The new max became m_new = 6. The update d_new = 1.5032 × e^(3-6) + Σ_i e^(x_i - 6) is the direct application of this rule, rescaling the entire sum from the first tile, after which we added the contributions from Tile 2's elements.
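As a minimal NumPy sketch of this element-wise algorithm (naming is mine), the single statistics pass followed by the output pass looks like:

```python
import numpy as np

def online_softmax(x: np.ndarray) -> np.ndarray:
    """Online softmax: one pass for the statistics, one pass for the outputs."""
    m = -np.inf  # running maximum (line 1)
    d = 0.0      # running denominator (line 2)

    # Single pass (lines 3-6): update the max and rescale the denominator.
    for x_j in x:
        m_new = max(m, x_j)
        d = d * np.exp(m - m_new) + np.exp(x_j - m_new)
        m = m_new

    # Final pass (lines 7-9): compute probabilities with the final statistics.
    return np.exp(x - m) / d

print(online_softmax(np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])))
```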
Mathematical conversion from element-wise to tile-wise formula
You may notice that the online softmax pseudocode does not have a summation Σ, which appears in the illustrative calculations. This is because the pseudocode describes the update rule on a per-element basis. The summation Σ appears in the example because we are performing a per-tile update: it is simply the result of applying the single-element rule to every element within the tile at the same time. This is a practical optimization that is mathematically equivalent.
Let's see how the tile-wise formula is derived directly from the element-wise formula.
The core element-wise update rule is:
d_j = d_{j-1} × e^(m_{j-1} - m_j) + e^(x_j - m_j)
Imagine we have just finished processing Tile 1. Our running statistics are m_old = 3 and d_old ≈ 1.5032. Now we want to process Tile 2, which contains the elements x_a = 6, x_b = 2, x_c = 1.
Instead of processing the whole tile at once, let's apply the element-wise rule three times in a row.
1. Processing the first element of the tile, x_a:
m_new_a = max(m_old, x_a) = max(3, 6) = 6
d_new_a = d_old × e^(m_old - m_new_a) + e^(x_a - m_new_a)
2. Processing the second element, x_b (using the results from step 1):
m_new_b = max(m_new_a, x_b) = max(6, 2) = 6
d_new_b = d_new_a × e^(m_new_a - m_new_b) + e^(x_b - m_new_b)
Let's substitute the expression for d_new_a into this equation:
d_new_b = d_old × e^(m_old - m_new_b) + e^(x_a - m_new_b) + e^(x_b - m_new_b)
Notice the pattern: the original d_old is now scaled by the newest max, and the contributions from the elements processed so far (x_a, x_b) are summed up, also scaled by the newest max.
3. Processing the third element, x_c (using the results from step 2):
m_new_c = max(m_new_b, x_c) = max(6, 1) = 6
d_new_c = d_new_b × e^(m_new_b - m_new_c) + e^(x_c - m_new_c)
Substituting the expression for d_new_b:
d_new_c = d_old × e^(m_old - m_new_c) + e^(x_a - m_new_c) + e^(x_b - m_new_c) + e^(x_c - m_new_c)
The Tile-Wise Optimization
In a practical implementation on parallel hardware (like a GPU), it's far more efficient to do the following:
- Load the entire tile into fast local memory.
- Find the local maximum for the tile: m_local = max(x_a, x_b, x_c) = 6.
- Calculate the new overall maximum in one step: m_new = max(m_old, m_local) = max(3, 6) = 6. Note that this final m_new is mathematically identical to the m_new_c we found after processing the last element in the step-by-step derivation above.
- Apply the update rule in one go.
Looking at our final equation from the element-wise derivation:
d_new_c = d_old × e^(m_old - m_new_c) + (e^(x_a - m_new_c) + e^(x_b - m_new_c) + e^(x_c - m_new_c))
We can rewrite the part in the parentheses using a summation:
d_new_c = d_old × e^(m_old - m_new_c) + Σ_i e^(x_i - m_new_c)
If we substitute m_new_c with our tile-wise m_new, we get the exact formula used in the example:
d_new = d_old × e^(m_old - m_new) + Σ_i e^(x_i - m_new)
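As a code sketch of the derived tile-wise formula (function name and structure are my own), applied to the two example tiles:

```python
import numpy as np

def update_tile(m_old: float, d_old: float, tile: np.ndarray) -> tuple[float, float]:
    """One tile-wise online-softmax update, as derived above."""
    m_new = max(m_old, tile.max())  # tile-wise max update
    d_new = d_old * np.exp(m_old - m_new) + np.exp(tile - m_new).sum()  # rescale + add
    return m_new, d_new

# Reproduce the walkthrough: Tile 1, then Tile 2.
m, d = -np.inf, 0.0
for tile in (np.array([1.0, 2.0, 3.0]), np.array([6.0, 2.0, 1.0])):
    m, d = update_tile(m, d, tile)
print(m, d)  # 6.0 and ≈ 1.0999
```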
Step 3: Final Pass — Calculate Final Softmax Probabilities
Because the online algorithm only makes a single pass to find the statistics of max and sum, a final pass is still required to compute the output probabilities using the final global statistics.
The formula is the same as in the normal softmax:
y_i = e^(x_i - m_final) / d_final
- For Tile 1 (X1 = [1, 2, 3]): y ≈ [0.0061, 0.0167, 0.0453]
- For Tile 2 (X2 = [6, 2, 1]): y ≈ [0.9092, 0.0167, 0.0061]
Link back to the online softmax algorithm
Lines 7-9: for i ← 1, V do ... y_i ← e^(x_i - m_V) / d_V
- Explanation: This is a final pass to compute the results. After the main loop (lines 3-6) is complete, m_V and d_V hold the final, correct global statistics. This loop then uses those final values to calculate each y_i.
- In the Example: This corresponds to our "Step 3: Final Pass". We took the final m_final = 6 and d_final ≈ 1.0999 and applied them to all original x_i values to get the final probabilities, which is exactly what this loop does.
Final Result
The final softmax vector is identical to the one produced by the normal softmax method:
Y ≈ [0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061]
The critical advantage is that the online method calculated the necessary global statistics (m_final and d_final) in a single pass over the data, whereas the normal tiled method required two passes just to get the statistics, followed by a third pass for the final calculation. This reduction in memory passes is what gives online softmax its efficiency gains.
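For completeness, here is a minimal sketch combining the tile-wise statistics pass with the final output pass, checked against a reference softmax (all names are mine):

```python
import numpy as np

def tiled_online_softmax(x: np.ndarray, tile_size: int = 3) -> np.ndarray:
    """One statistics pass over tiles, then one final pass for the outputs."""
    m, d = -np.inf, 0.0

    # Pass 1: per-tile online update of the running max and denominator.
    for i in range(0, len(x), tile_size):
        tile = x[i:i + tile_size]
        m_new = max(m, tile.max())
        d = d * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new

    # Pass 2: compute the probabilities with the final statistics.
    return np.exp(x - m) / d

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(tiled_online_softmax(x), reference)
```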
Appendix A: A simple example of numerical stability
Let's illustrate why this stabilization is crucial with a simple vector, especially when using lower-precision floating-point numbers like float16, which is common in modern GPUs for accelerating deep learning workloads. The maximum value for a float16 is 65,504.
Consider the input vector:
x = [2, 4, 12]
Case 1: The Naive (Unstable) Approach
Using the direct definition of softmax, we compute for each element:
y_i = e^(x_i) / Σ_j e^(x_j)
Now, let's analyze this from a float16 perspective. With e^2 ≈ 7.389, e^4 ≈ 54.598, and e^12 ≈ 162,754.791, we proceed to calculate the denominator:
d = e^2 + e^4 + e^12 ≈ 162,816.777
As the maximum finite positive value for a float16 is 65,504, the sum 162,816.777 exceeds it (in fact, e^12 ≈ 162,754.791 already overflows on its own), and hence causes a catastrophic overflow in float16, resulting in inf (infinity). As a result, the final probabilities would be computed as:
- y_1 = e^2 / inf = 0
- y_2 = e^4 / inf = 0
- y_3 = inf / inf = NaN (Not a Number)
The result is a vector of [0, 0, NaN], which is completely incorrect and would destroy any subsequent training or inference.
Case 2: The Safe (Stable) Approach
Now, let's use the numerically stable formula.
Step 1: Find the maximum value, m.
m = max(2, 4, 12) = 12
Step 2: Subtract m from each element.
x - m = [2 - 12, 4 - 12, 12 - 12] = [-10, -8, 0]
Step 3: Compute the exponentials of the new values.
e^(-10) ≈ 0.0000454, e^(-8) ≈ 0.000335, e^0 = 1
Notice that all these values are small, well-behaved numbers between 0 and 1. There is absolutely no risk of overflow.
Step 4: Calculate the denominator, d'.
d' ≈ 0.0000454 + 0.000335 + 1 ≈ 1.00038
Step 5: Compute the final probabilities.
y ≈ [0.0000454 / 1.00038, 0.000335 / 1.00038, 1 / 1.00038] ≈ [0.0000454, 0.000335, 0.9996]
The resulting softmax vector is approximately [0.0000454, 0.000335, 0.9996]. The sum is 1.0, and we have avoided any numerical errors. This demonstrates that subtracting the maximum value is not just a theoretical trick; it is an essential step for making softmax work in practice.
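A quick NumPy snippet (my own, not from the source article) reproduces both cases in float16:

```python
import numpy as np

x = np.array([2.0, 4.0, 12.0], dtype=np.float16)

# Case 1: naive softmax overflows; e^12 is already inf in float16.
with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(x)          # [7.39, 54.6, inf]
    print(e / e.sum())     # [0. 0. nan]

# Case 2: safe softmax stays in range.
e_safe = np.exp(x - x.max())  # [4.54e-05, 3.35e-04, 1.0]
print(e_safe / e_safe.sum())  # ≈ [0.0000454, 0.000335, 0.9996]
```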
Appendix B: Proof of correctness of online softmax
I have reproduced the proof by Yi Wang in FlashAttention (Part 2): Online Softmax below.
The definition of softmax is as follows:

y_i = e^(x_i) / Σ_{j=1}^{V} e^(x_j)

If any x_i is large (e.g. x_i = 12, since e^12 ≈ 162,755 exceeds 65,504), e^(x_i) exceeds the maximum value of float16. To address this numerical instability, we compute an alternative form which gives an equivalent result but is numerically stable:

y_i = e^(x_i - m_V) / Σ_{j=1}^{V} e^(x_j - m_V)

where m_V = max_{j=1..V} x_j. This form is safe because x_i - m_V ≤ 0, ensuring that e^(x_i - m_V) ≤ 1.

As mentioned in this article, online softmax seeks to allow the calculation of the max value and the denominator in parallel. More concretely, for any integer 1 ≤ t ≤ V, we want to be able to:

- Calculate an intermediate d_t so that d_t = Σ_{j=1}^{t} e^(x_j - m_t), where m_t = max_{j=1..t} x_j.
- Since d_t is inductive, it should depend on d_{t-1}.
- To allow parallel execution, d_t must not depend on future values such as x_{t+1} or m_{t+1}.

We begin by considering:

d_t = Σ_{j=1}^{t} e^(x_j - m_t)

To ensure d_t depends on d_{t-1}, which is:

d_{t-1} = Σ_{j=1}^{t-1} e^(x_j - m_{t-1})

we need to split d_t into two parts: one involving d_{t-1} (which should not depend on x_t or m_t), and the remaining terms that depend on x_t and m_t. The first step is straightforward – we separate the last term in the summation:

d_t = Σ_{j=1}^{t-1} e^(x_j - m_t) + e^(x_t - m_t)

Now, x_t only appears in the second term. However, m_t still appears in the summation. Let's take the next step:

d_t = [Σ_{j=1}^{t-1} e^(x_j - m_{t-1})] × e^(m_{t-1} - m_t) + e^(x_t - m_t)

The expression inside the square brackets is exactly d_{t-1}. Therefore, we have:

d_t = d_{t-1} × e^(m_{t-1} - m_t) + e^(x_t - m_t)

This allows us to compute d_t inductively in parallel with m_t.
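As a quick numerical sanity check of this recurrence (my own snippet, not part of Yi Wang's proof):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=16)

m, d = -np.inf, 0.0
for t in range(len(x)):
    m_new = max(m, x[t])
    d = d * np.exp(m - m_new) + np.exp(x[t] - m_new)  # the derived recurrence
    m = m_new
    # Direct definition: d_t = sum_{j<=t} e^(x_j - m_t)
    assert np.isclose(d, np.exp(x[:t + 1] - m).sum())
```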