Online softmax by hand

Lewis Won



Why study online softmax? The key to flash attention

Softmax normalization is critical in the self-attention mechanism of transformers because it helps to manage "extreme values and offers more favorable gradient properties during training" (source: Building LLM from Scratch, chapter 3.3.1). For numerical stability, a "safe" softmax is widely used: to compute the final probability of any single element in a row of a matrix, we first need two "global" statistics over the entire row:
a. The maximum value of the row, m = \max(S_i) .
b. The normalization denominator, d = \sum_j \exp(S_{ij} - m) , which sums over all elements in the row.

(Note: Please refer to Appendix A for an illustration of the importance of the absolute maximum value (max(x)) for numerical stability.)

With a naive implementation of "safe" softmax, these two global statistics force a "multi-pass" approach in which, for each row, the following steps are computed sequentially:

  1. Read the entire row to find the max.
  2. Calculate the denominator.
  3. Compute the softmax values.

With online softmax, steps 1 (find the max) and 2 (calculate the denominator) can be carried out within the same read of the matrix, instead of two sequential reads. Online softmax achieves this by calculating the necessary global statistics in a single pass over the data. This contrasts with a naive tiled approach that requires two passes to find the statistics, followed by a third pass for the final calculation. By reducing the number of memory passes from three to two, online softmax offers significant efficiency gains, especially since matrix operations in GPUs tend to be memory-bound.
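To make the pass structure concrete, here is a minimal NumPy sketch of the row-wise "safe" softmax (my own illustration, not code from the paper); each commented line corresponds to one of the passes above:

```python
import numpy as np

def safe_softmax_row(x):
    """Row-wise safe softmax: three conceptual passes over the data."""
    m = np.max(x)                  # pass 1: find the row maximum
    d = np.sum(np.exp(x - m))      # pass 2: accumulate the normalization denominator
    return np.exp(x - m) / d       # pass 3: compute the final probabilities

row = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
print(safe_softmax_row(row))  # ~[0.0061 0.0167 0.0453 0.9092 0.0167 0.0061]
```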

This article provides an intuitive understanding of online softmax by demonstrating by hand how both the max value and denominator of the "safe" softmax can be calculated in a single pass. This walkthrough is based on the article Online normalizer calculation for softmax by Maxim Milakov and Natalia Gimelshein, 2018.

Online softmax also provides the basis for Flash Attention to fuse the entire attention calculation into a single GPU kernel, which I will discuss in my next article.


Conceptual overview: tiled "safe" softmax

The "safe" softmax is inherently a multi-pass algorithm. This is because calculating the final probability for any single element requires knowing two global properties of the entire input vector:

  1. The absolute maximum value (max(x)) for numerical stability.
  2. The sum of all exponentiated values, which serves as the normalization denominator.

(Image: "Algorithm 2: Safe softmax")

Tiling is a technique used to break down large matrices or vectors into smaller blocks, or "tiles". This is crucial for efficiency, especially on hardware like GPUs, as it allows smaller chunks of data to be loaded into fast, local memory (SRAM) for processing, rather than repeatedly accessing slower global memory (DRAM).

When applying tiling to the normal softmax, we process the input tile-by-tile. However, because we need the global maximum and the global sum, we cannot complete the calculation for any tile in a single go. This results in a multi-pass approach over the tiled data. (Note: The algorithm below describes an element-wise process; I will show its relationship to a tile-based approach.)

Algorithm for safe softmax

Let's illustrate this with a simple example. Consider the following 1x6 input vector (which can be thought of as a single row in a larger matrix):

X = \begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}

We will process this vector with a tile size of 3. This breaks our input X into two tiles:

T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix}
T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix}

The process follows the logic of "Algorithm 2: Safe softmax" from the image, but adapted for tiles.


Step 1: First Pass — Find the Global Maximum

In the first pass, we iterate through each tile to find its local maximum. Then, we find the maximum among all local maximums to determine the global maximum.

1a. Find Local Maximums

For each tile, we compute the maximum value, which we'll call m_local.

  • For Tile 1 ( T_1 ):

    m_1 = \max(\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}) = 3
  • For Tile 2 ( T_2 ):

    m_2 = \max(\begin{bmatrix} 6 & 2 & 1 \end{bmatrix}) = 6

1b. Find Global Maximum

Now, we find the maximum of the local maximums to get the global maximum, m_global.

m_{global} = \max(m_1, m_2) = \max(3, 6) = 6

This value, m_global, corresponds to m_V in the provided "Algorithm 2".
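In code, this first pass is just a max-reduction over the tiles. A minimal sketch with the example values (variable names are my own):

```python
import numpy as np

tiles = [np.array([1.0, 2.0, 3.0]), np.array([6.0, 2.0, 1.0])]
local_maxes = [tile.max() for tile in tiles]  # m_1 = 3.0, m_2 = 6.0
m_global = max(local_maxes)                   # 6.0
```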

Link back to the safe softmax algorithm

Line 1: m_0 ← -∞

  • Explanation: This initializes the variable that will hold the maximum value. It's set to negative infinity to ensure the first element of the vector x_1 becomes the first maximum.
  • In the Example: Before "Pass 1," we start with a conceptual maximum of -\infty .

Lines 2-4: for k ← 1, V do ... m_k ← max(m_{k-1}, x_k)

  • Explanation: This is the first pass. The loop iterates through every element of the input vector X to find the single largest value. After the loop finishes, m_V holds the global maximum.
  • In the Example: This corresponds to our "Step 1: First Pass".
    • We first found the local max for Tile 1 ( m_1 = 3 ) and Tile 2 ( m_2 = 6 ).
    • We then found the maximum of these, yielding the global max m_global = 6. This is the final value of m_V after this loop completes.

Step 2: Second Pass — Calculate the Global Denominator

In the second pass, we again iterate through each tile. This time, we use m_global to calculate a local sum of exponentials for each tile. These local sums are then aggregated to get the global denominator.

The formula for each element within a tile is e^{(x_j - m_{global})} .

2a. Calculate Local Denominators

  • For Tile 1 ( T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} ):
\begin{aligned} d_1 &= e^{(1 - 6)} + e^{(2 - 6)} + e^{(3 - 6)} \newline &= e^{-5} + e^{-4} + e^{-3} \newline &\approx 0.006738 + 0.018316 + 0.049787 \newline &= 0.074841 \end{aligned}
  • For Tile 2 ( T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix} ):
\begin{aligned} d_2 &= e^{(6 - 6)} + e^{(2 - 6)} + e^{(1 - 6)} \newline &= e^{0} + e^{-4} + e^{-5} \newline &\approx 1 + 0.018316 + 0.006738 \newline &= 1.025054 \end{aligned}

2b. Calculate Global Denominator

The global denominator, d_global, is the sum of all local denominators. This corresponds to d_V in the algorithm.

\begin{aligned} d_{global} &= d_1 + d_2 \newline &\approx 0.074841 + 1.025054 \newline &= 1.099895 \end{aligned}
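A quick check of these numbers with NumPy (an illustrative snippet, not an optimized kernel):

```python
import numpy as np

m_global = 6.0
d1 = np.sum(np.exp(np.array([1.0, 2.0, 3.0]) - m_global))  # ~0.074841
d2 = np.sum(np.exp(np.array([6.0, 2.0, 1.0]) - m_global))  # ~1.025054
d_global = d1 + d2                                         # ~1.099895
```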

Link back to the safe softmax algorithm

Line 5: d_0 ← 0

  • Explanation: This initializes the variable for the denominator (the sum of exponentials).
  • In the Example: Before "Pass 2," our sum is zero.

Lines 6-8: for j ← 1, V do ... d_j ← d_{j-1} + e^(x_j - m_V)

  • Explanation: This is the second pass. The loop iterates through every element x_j again. For each element, it subtracts the global maximum m_V (found in the first pass), exponentiates the result, and adds it to the running sum. After the loop, d_V holds the global denominator.
  • In the Example: This corresponds to our "Step 2: Second Pass".
    • For Tile 1, we calculated d_1 = e^{(1-6)} + e^{(2-6)} + e^{(3-6)} \approx 0.0748 .
    • For Tile 2, we calculated d_2 = e^{(6-6)} + e^{(2-6)} + e^{(1-6)} \approx 1.0251 .
    • The loop's final result, d_V, is the sum d_1 + d_2 \approx 1.0999 .

Step 3: Third Pass — Calculate Final Softmax Probabilities

In the final pass, we iterate through the tiles one last time. For each element x_i, we compute its exponentiated value (shifted by m_global) and then divide by d_global to get the final softmax probability y_i.

The formula is:

y_i = \frac{e^{(x_i - m_{global})}}{d_{global}}

3a. Compute Softmax for Each Tile

  • For Tile 1 ( T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} ):

    • y_1 = \frac{e^{(1 - 6)}}{1.099895} = \frac{e^{-5}}{1.099895} \approx \frac{0.006738}{1.099895} \approx 0.006126
    • y_2 = \frac{e^{(2 - 6)}}{1.099895} = \frac{e^{-4}}{1.099895} \approx \frac{0.018316}{1.099895} \approx 0.016652
    • y_3 = \frac{e^{(3 - 6)}}{1.099895} = \frac{e^{-3}}{1.099895} \approx \frac{0.049787}{1.099895} \approx 0.045266
  • For Tile 2 ( T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix} ):

    • y_4 = \frac{e^{(6 - 6)}}{1.099895} = \frac{e^{0}}{1.099895} \approx \frac{1}{1.099895} \approx 0.909177
    • y_5 = \frac{e^{(2 - 6)}}{1.099895} = \frac{e^{-4}}{1.099895} \approx \frac{0.018316}{1.099895} \approx 0.016652
    • y_6 = \frac{e^{(1 - 6)}}{1.099895} = \frac{e^{-5}}{1.099895} \approx \frac{0.006738}{1.099895} \approx 0.006126

Link back to the safe softmax algorithm

Lines 9-11: for i ← 1, V do ... y_i ← e^(x_i - m_V) / d_V

  • Explanation: This is the third pass. This final loop computes the softmax probability y_i for each element. It re-calculates the exponentiated value for x_i (just as in the second pass) and then divides it by the global denominator d_V.
  • In the Example: This corresponds to our "Step 3: Third Pass".
    • y_1 = e^{(1-6)} / 1.0999 \approx 0.0061
    • y_2 = e^{(2-6)} / 1.0999 \approx 0.0167
    • ...and so on for all six elements.

Final Result

Combining the results from all tiles gives the final softmax output vector Y:

Y \approx \begin{bmatrix} 0.0061 & 0.0167 & 0.0453 & 0.9092 & 0.0167 & 0.0061 \end{bmatrix}

The sum of these probabilities is approximately 1.0, as expected. This multi-pass process is the standard, "normal" way of computing softmax in a tiled fashion, and stands in contrast to the "online softmax" which is designed to reduce the number of passes over the data.
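Putting the three passes together, here is a sketch of the tiled multi-pass safe softmax. This is my own illustrative code under the simple assumptions of this example (a 1D row and a fixed tile size); a real GPU kernel would look very different:

```python
import numpy as np

def tiled_safe_softmax(x, tile_size=3):
    """Three-pass tiled 'safe' softmax: max pass, sum pass, then normalization pass."""
    tiles = [x[i:i + tile_size] for i in range(0, len(x), tile_size)]
    # Pass 1: global max from the per-tile maxima
    m_global = max(tile.max() for tile in tiles)
    # Pass 2: global denominator from the per-tile partial sums
    d_global = sum(np.sum(np.exp(tile - m_global)) for tile in tiles)
    # Pass 3: final probabilities, tile by tile
    return np.concatenate([np.exp(tile - m_global) / d_global for tile in tiles])

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
print(tiled_safe_softmax(x))  # ~[0.0061 0.0167 0.0453 0.9092 0.0167 0.0061]
```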


Conceptual Overview: Online Softmax

The "Online Softmax," as described in the paper by Milakov and Gimelshein, is a clever single-pass algorithm. Its key innovation is the ability to calculate the softmax probabilities tile-by-tile without needing to first compute the global maximum and global sum over the entire input. It achieves this by maintaining a running maximum and a running denominator. When a new tile is processed, these running statistics are updated.

The core of the method lies in how the running denominator is updated when a new, larger maximum value is found. It uses a "telescoping sum" property to rescale the previous sum, effectively correcting it to be consistent with the new maximum. This avoids the need for a second or third pass over the data, which is a significant advantage for memory-bound operations on hardware like GPUs. Proof that online softmax is correct is in Appendix B. Note that the original proof was found in Online normalizer calculation for softmax, pages 3 - 4, but I instead referenced Yi Wang's FlashAttention (Part 2): Online Softmax as I found his writing to be more intuitive. (Note: The algorithm below describes an element-wise process; I will show its relationship to a tile-based approach.)

(Image: "Algorithm 3: Safe softmax with online normalizer calculation")

Let's use the same example vector and tile size:

X = \begin{bmatrix} 1 & 2 & 3 & 6 & 2 & 1 \end{bmatrix}

With a tile size of 3, we have two tiles:

T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix}
T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix}

The process will follow the logic of "Algorithm 3: Safe softmax with online normalizer calculation". We will process the vector tile by tile, from left to right, maintaining and updating two running values:

  • m_{running} : the running maximum found so far.
  • d_{running} : the running denominator (the sum of exponentials), scaled by the current m_{running} .

Initialization

Before processing the first tile, we initialize our running statistics as per the algorithm:

  • m_0 = m_{running} = -\infty
  • d_0 = d_{running} = 0

These are represented in lines 1 and 2 of the online softmax algorithm.

Step 1: Process Tile 1 ( T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} )

We now process the first tile. For this tile, we will compute a new maximum ( m_{new} ) and a new denominator ( d_{new} ).

1a. Find the New Maximum for the Tile

First, find the maximum value within the current tile and compare it to our running maximum.

  • Local max of T_1 :
m_{T_1} = \max(\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}) = 3
  • The new overall maximum is:
m_{new} = \max(m_{running}, m_{T_1}) = \max(-\infty, 3) = 3

1b. Update the Running Denominator

Next, we update the running denominator. The formula from Algorithm 3 (line 5) combines the previous denominator with the contribution from the current tile.

d_{new} \leftarrow d_{old} \times e^{(m_{old} - m_{new})} + \sum_{x_j \in \text{current tile}} e^{(x_j - m_{new})}

Let's apply this:

  • m_{old} = m_{running} = -\infty
  • d_{old} = d_{running} = 0
  • m_{new} = 3
\begin{aligned} d_{new} &= (0 \times e^{(-\infty - 3)}) + (e^{(1 - 3)} + e^{(2 - 3)} + e^{(3 - 3)}) \newline &= 0 + (e^{-2} + e^{-1} + e^{0}) \newline &\approx 0.1353 + 0.3679 + 1 = 1.5032 \end{aligned}

1c. Update Running Statistics

After processing the first tile, our running statistics are:

  • m_{running} = 3
  • d_{running} \approx 1.5032

At this point, we could calculate intermediate (and incorrect) softmax values for the first tile using these stats, but the key insight of the online method is to wait until all tiles are processed. The final values will need to be rescaled.


Step 2: Process Tile 2 ( T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix} )

Now we move to the second tile and repeat the process, using the statistics from the previous step as our "old" values.

2a. Find the New Maximum

  • Local max of T_2 :
m_{T_2} = \max(\begin{bmatrix} 6 & 2 & 1 \end{bmatrix}) = 6
  • The new overall maximum is: m_{new} = \max(m_{running}, m_{T_2}) = \max(3, 6) = 6

We have found a new, larger maximum. This will trigger the rescaling part of the update formula.

2b. Update the Running Denominator

We use the same formula as before, but now our old values are the results from Tile 1.

  • m_{old} = 3
  • d_{old} \approx 1.5032
  • m_{new} = 6
\begin{aligned} d_{new} &= (d_{old} \times e^{(m_{old} - m_{new})}) + (\sum_{x_j \in T_2} e^{(x_j - m_{new})}) \newline &\approx (1.5032 \times e^{(3 - 6)}) + (e^{(6 - 6)} + e^{(2 - 6)} + e^{(1 - 6)}) \newline &\approx (1.5032 \times e^{-3}) + (e^{0} + e^{-4} + e^{-5}) \newline &\approx (1.5032 \times 0.04979) + (1 + 0.01832 + 0.00674) \newline &\approx 0.07484 + 1.02506 \newline &= 1.0999 \end{aligned}

The first term, 1.5032 \times e^{(3 - 6)} , is the crucial rescaling step. It takes the previous sum, which was calculated relative to the old maximum of 3, and correctly scales it to be relative to the new maximum of 6.
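A quick numerical check of this rescaling with the example values (illustrative snippet):

```python
import numpy as np

tile1 = np.array([1.0, 2.0, 3.0])
d_old = np.sum(np.exp(tile1 - 3.0))   # ~1.5032, computed relative to the old max 3
rescaled = d_old * np.exp(3.0 - 6.0)  # ~0.0748, now expressed relative to the new max 6
direct = np.sum(np.exp(tile1 - 6.0))  # ~0.0748, Tile 1's contribution if we had known max 6
print(np.isclose(rescaled, direct))   # True
```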

2c. Update Final Running Statistics

After processing all tiles, our final global statistics are:

  • m_{final} = 6
  • d_{final} \approx 1.0999

These values match the global maximum and denominator we found in the multi-pass "normal" softmax example.
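The whole tile-wise statistics pass fits in a few lines. The sketch below is my own illustration of the Algorithm 3 update applied per tile (names like online_stats are assumptions, not from the paper):

```python
import numpy as np

def online_stats(tiles):
    """Single pass over the tiles, keeping a running max and a rescaled running denominator."""
    m_run, d_run = -np.inf, 0.0
    for tile in tiles:
        m_new = max(m_run, tile.max())
        d_run = d_run * np.exp(m_run - m_new) + np.sum(np.exp(tile - m_new))
        m_run = m_new
    return m_run, d_run

tiles = [np.array([1.0, 2.0, 3.0]), np.array([6.0, 2.0, 1.0])]
print(online_stats(tiles))  # (6.0, ~1.0999)
```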

Link back to the online softmax algorithm

Lines 3-6: for j ← 1, V do ...

  • Explanation: This is the main single pass of the algorithm. It iterates through the input vector element by element, updating the running max and denominator at each step.

  • In the Example: We processed this tile-by-tile, which is a block-wise version of this element-wise loop.

Line 4: m_j ← max(m_{j-1}, x_j)

  • Explanation: At each step j, this line updates the running maximum.

  • In the Example:

    • When processing Tile 1 ( x_1, x_2, x_3 ), the running max m_j becomes 1, then 2, then finally 3. After Tile 1, our example had m_running = 3.
    • When processing x_4 = 6 , the max max(3, 6) becomes 6. Our example showed this change when moving from Tile 1 to Tile 2, where m_new became 6.

Line 5: d_j ← d_{j-1} × e^{(m_{j-1} - m_j)} + e^{(x_j - m_j)}

  • Explanation: This is the core of the online update.

    • The second term, e^{(x_j - m_j)} , adds the contribution of the current element, scaled by the new running max.
    • The first term, d_{j-1} × e^{(m_{j-1} - m_j)} , is the crucial rescaling factor. If a new maximum is found ( m_j > m_{j-1} ), this term correctly scales down the entire previous sum to be relative to the new, larger maximum. If the maximum doesn't change ( m_j = m_{j-1} ), this term simply becomes d_{j-1} × e^0 = d_{j-1} , so we just keep adding to the sum.
  • In the Example:

    • For Tile 1, the max kept updating, but d_{j-1} started at 0. The final sum for the tile was d_1 = e^{(1-3)} + e^{(2-3)} + e^{(3-3)} \approx 1.5032 .
    • The most important step was processing Tile 2. Our running denominator was d_old ≈ 1.5032 and m_old = 3. The new max became m_new = 6. The update (1.5032 × e^{(3 - 6)}) is the direct application of d_{j-1} × e^{(m_{j-1} - m_j)} , rescaling the entire sum from the first tile. We then added the contributions from Tile 2's elements, as sketched in the code below.
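For completeness, here is the element-wise main loop of Algorithm 3 (lines 3-6) as a small Python sketch of my own, which reproduces the same final statistics:

```python
import math

x = [1.0, 2.0, 3.0, 6.0, 2.0, 1.0]
m, d = -math.inf, 0.0          # lines 1-2: initialization
for xj in x:                   # line 3: single pass over the elements
    m_new = max(m, xj)         # line 4: update the running max
    d = d * math.exp(m - m_new) + math.exp(xj - m_new)  # line 5: rescale and accumulate
    m = m_new
print(m, d)  # 6.0, ~1.0999
```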

Mathematical conversion from element-wise to tile-wise formula

You may notice that the online softmax pseudocode does not have a summation Σ which appears in the illustrative calculations. This is because the pseudocode describes the update rule on a per-element basis. The summation Σ appears in the example because we are performing a per-tile update. The summation is simply the result of applying the single-element rule to every element within the tile at the same time. It is a practical optimization that is mathematically equivalent.

Let's see how the tile-wise formula is derived directly from the element-wise formula.

The core element-wise update rule is:

d_j = d_{j-1} \cdot e^{m_{j-1} - m_j} + e^{x_j - m_j}

Imagine we have just finished processing Tile 1. Our running statistics are m_{old} and d_{old} . Now we want to process Tile 2, which contains the elements \{x_a, x_b, x_c\} .

Instead of processing the whole tile at once, let's apply the element-wise rule three times in a row.

1. Processing the first element of the tile, x_a :

m_{new_a} = \max(m_{old}, x_a)
d_{new_a} = d_{old} \cdot e^{m_{old} - m_{new_a}} + e^{x_a - m_{new_a}}

2. Processing the second element, x_b (using the results from step 1):

m_{new_b} = \max(m_{new_a}, x_b)
d_{new_b} = d_{new_a} \cdot e^{m_{new_a} - m_{new_b}} + e^{x_b - m_{new_b}}

Let's substitute the expression for d_{new_a} into this equation:

d_{new_b} = (d_{old} \cdot e^{m_{old} - m_{new_a}} + e^{x_a - m_{new_a}}) \cdot e^{m_{new_a} - m_{new_b}} + e^{x_b - m_{new_b}}
d_{new_b} = d_{old} \cdot e^{m_{old} - m_{new_b}} + e^{x_a - m_{new_b}} + e^{x_b - m_{new_b}}

Notice the pattern: the original d_{old} is now scaled by the newest max, and the contributions from the elements processed so far ( x_a, x_b ) are summed up, also scaled by the newest max.

3. Processing the third element, x_c (using the results from step 2):

m_{new_c} = \max(m_{new_b}, x_c)
d_{new_c} = d_{new_b} \cdot e^{m_{new_b} - m_{new_c}} + e^{x_c - m_{new_c}}

Substituting the expression for d_{new_b} :

d_{new_c} = (d_{old} \cdot e^{m_{old} - m_{new_b}} + e^{x_a - m_{new_b}} + e^{x_b - m_{new_b}}) \cdot e^{m_{new_b} - m_{new_c}} + e^{x_c - m_{new_c}}
d_{new_c} = d_{old} \cdot e^{m_{old} - m_{new_c}} + e^{x_a - m_{new_c}} + e^{x_b - m_{new_c}} + e^{x_c - m_{new_c}}

The Tile-Wise Optimization

In a practical implementation on parallel hardware (like a GPU), it's far more efficient to do the following:

  1. Load the entire tile \{x_a, x_b, x_c\} into fast local memory.
  2. Find the local maximum for the tile: m_{tile} = \max(x_a, x_b, x_c) .
  3. Calculate the new overall maximum in one step: m_{new} = \max(m_{old}, m_{tile}) . Note that this final m_new is mathematically identical to the m_{new_c} we found after processing the last element in the step-by-step derivation above.
  4. Apply the update rule in one go.

Looking at our final equation from the element-wise derivation:

d_{new_c} = d_{old} \cdot e^{m_{old} - m_{new_c}} + (e^{x_a - m_{new_c}} + e^{x_b - m_{new_c}} + e^{x_c - m_{new_c}})

We can rewrite the part in the parentheses using a summation:

\sum_{x_j \in \text{Tile}} e^{x_j - m_{new_c}}

If we substitute m_{new_c} with our tile-wise m_{new} , we get the exact formula used in the example:

d_{new} = d_{old} \cdot e^{m_{old} - m_{new}} + \sum_{x_j \in \text{Tile}} e^{x_j - m_{new}}
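Since the tile-wise update is derived from repeated application of the element-wise rule, the two must agree numerically. A small sketch that checks this on Tile 2 of the running example (my own illustrative code, not from the paper):

```python
import numpy as np

# State after Tile 1, as in the example
m_old = 3.0
d_old = np.sum(np.exp(np.array([1.0, 2.0, 3.0]) - m_old))  # ~1.5032
tile2 = np.array([6.0, 2.0, 1.0])

# Element-wise: apply the per-element rule to x_a, x_b, x_c in turn
m_e, d_e = m_old, d_old
for xj in tile2:
    m_new = max(m_e, xj)
    d_e = d_e * np.exp(m_e - m_new) + np.exp(xj - m_new)
    m_e = m_new

# Tile-wise: one combined update using the tile's local max
m_t = max(m_old, tile2.max())
d_t = d_old * np.exp(m_old - m_t) + np.sum(np.exp(tile2 - m_t))

print(m_e == m_t, np.isclose(d_e, d_t))  # True True
```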

Step 3: Final Pass — Calculate Final Softmax Probabilities

The online algorithm's single pass only computes the statistics (the max and the sum), so a final pass is still required to compute the output probabilities using the final global statistics.

The formula is the same as in the normal softmax: y_i = \frac{e^{(x_i - m_{final})}}{d_{final}}

  • For Tile 1 ( T_1 = \begin{bmatrix} 1 & 2 & 3 \end{bmatrix} ):

    • y_1 = \frac{e^{(1 - 6)}}{1.0999} = \frac{e^{-5}}{1.0999} \approx 0.0061
    • y_2 = \frac{e^{(2 - 6)}}{1.0999} = \frac{e^{-4}}{1.0999} \approx 0.0167
    • y_3 = \frac{e^{(3 - 6)}}{1.0999} = \frac{e^{-3}}{1.0999} \approx 0.0453
  • For Tile 2 ( T_2 = \begin{bmatrix} 6 & 2 & 1 \end{bmatrix} ):

    • y_4 = \frac{e^{(6 - 6)}}{1.0999} = \frac{e^{0}}{1.0999} \approx 0.9092
    • y_5 = \frac{e^{(2 - 6)}}{1.0999} = \frac{e^{-4}}{1.0999} \approx 0.0167
    • y_6 = \frac{e^{(1 - 6)}}{1.0999} = \frac{e^{-5}}{1.0999} \approx 0.0061

Link back to the online softmax algorithm

Lines 7-9: for i ← 1, V do ... y_i ← e^(x_i - m_V) / d_V

  • Explanation: This is a final pass to compute the results. After the main loop (lines 3-6) is complete, m_V and d_V hold the final, correct global statistics. This loop then uses those final values to calculate each y_i.
  • In the Example: This corresponds to our "Step 3: Final Pass". We took the final m_final = 6 and d_final ≈ 1.0999 and applied them to all original x_i values to get the final probabilities, which is exactly what this loop does.

Final Result

The final softmax vector is identical to the one produced by the normal softmax method:

Y \approx \begin{bmatrix} 0.0061 & 0.0167 & 0.0453 & 0.9092 & 0.0167 & 0.0061 \end{bmatrix}

The critical advantage is that the online method calculated the necessary global statistics ( m_{final} and d_{final} ) in a single pass over the data, whereas the normal tiled method required two passes just to get the statistics, followed by a third pass for the final calculation. This reduction in memory passes is what gives online softmax its efficiency gains.
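Everything above can be folded into one small function: a single statistics pass over the tiles followed by a final normalization pass. This is my own illustrative sketch of the idea (not the paper's reference implementation), checked against a plain safe softmax:

```python
import numpy as np

def online_softmax(x, tile_size=3):
    """Online softmax sketch: one pass for the statistics, one pass for the output."""
    m_run, d_run = -np.inf, 0.0
    for i in range(0, len(x), tile_size):   # single pass to get m_final and d_final
        tile = x[i:i + tile_size]
        m_new = max(m_run, tile.max())
        d_run = d_run * np.exp(m_run - m_new) + np.sum(np.exp(tile - m_new))
        m_run = m_new
    return np.exp(x - m_run) / d_run        # final pass: normalize with the global statistics

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
reference = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))
assert np.allclose(online_softmax(x), reference)
print(online_softmax(x))  # ~[0.0061 0.0167 0.0453 0.9092 0.0167 0.0061]
```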


Appendix A: A simple example of numerical stability

Let's illustrate why this stabilization is crucial with a simple vector, especially when using lower precision floating-point numbers like float16, which is common in modern GPUs for accelerating deep learning workloads. The maximum value for a float16 is 65,504.

Consider the input vector:

X = \begin{bmatrix} 1 & 2 & 12 \end{bmatrix}

Case 1: The Naive (Unstable) Approach

Using the direct definition of softmax, we compute \exp(x_i) for each element:

  • \exp(1) \approx 2.718
  • \exp(2) \approx 7.389
  • \exp(12) \approx 162,754.79

Now, let's analyze this from a float16 perspective and calculate the denominator:
d = \sum_{j=1}^N \exp(x_j) \approx 2.718 + 7.389 + 162,754.79 = 162,764.897

As the maximum finite positive value of a float16 is approximately 65,504, the term \exp(12) \approx 162,754.79 (and therefore the whole sum) already exceeds this limit, causing a catastrophic overflow in float16 that results in inf (infinity). As a result, the final probabilities would be computed as:

  • y_1 = \frac{2.718}{\infty} = 0
  • y_2 = \frac{7.389}{\infty} = 0
  • y_3 = \frac{\infty}{\infty} = \text{NaN} (Not a Number; \exp(12) itself has already overflowed to \infty in float16)

The result is a vector of [0, 0, NaN], which is completely incorrect and would destroy any subsequent training or inference.


Case 2: The Safe (Stable) Approach

Now, let's use the numerically stable formula.

Step 1: Find the maximum value, m.

m = \max(\begin{bmatrix} 1 & 2 & 12 \end{bmatrix}) = 12

Step 2: Subtract m from each element.

X - m = \begin{bmatrix} 1-12 & 2-12 & 12-12 \end{bmatrix} = \begin{bmatrix} -11 & -10 & 0 \end{bmatrix}

Step 3: Compute the exponentials of the new values.

  • \exp(-11) \approx 0.0000167
  • \exp(-10) \approx 0.0000454
  • \exp(0) = 1

Notice that all these values are small, well-behaved numbers between 0 and 1. There is absolutely no risk of overflow.

Step 4: Calculate the denominator, d'.

d' = \sum_{j=1}^N \exp(x_j - m) \approx 0.0000167 + 0.0000454 + 1 = 1.0000621

Step 5: Compute the final probabilities.

  • y_1 = \frac{\exp(-11)}{1.0000621} \approx \frac{0.0000167}{1.0000621} \approx 0.0000167
  • y_2 = \frac{\exp(-10)}{1.0000621} \approx \frac{0.0000454}{1.0000621} \approx 0.0000454
  • y_3 = \frac{\exp(0)}{1.0000621} = \frac{1}{1.0000621} \approx 0.9999379

The resulting softmax vector is approximately \begin{bmatrix} 0.000017 & 0.000045 & 0.999938 \end{bmatrix} . The sum is 1.0, and we have avoided any numerical errors. This demonstrates that subtracting the maximum value is not just a theoretical trick; it is an essential step for making softmax work in practice.
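The two cases are easy to reproduce in NumPy by forcing float16 arithmetic (a minimal sketch; NumPy may also emit overflow warnings for the naive case, and the exact printed values can vary slightly):

```python
import numpy as np

x = np.array([1, 2, 12], dtype=np.float16)

# Naive: exp(12) ~ 162,755 overflows float16 (max ~65,504) to inf, and inf/inf gives nan
naive = np.exp(x) / np.sum(np.exp(x))
print(naive)   # [ 0.  0. nan]

# Safe: subtract the max first, so every exponent is <= 0 and every term is in (0, 1]
safe = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))
print(safe)    # ~[1.67e-05 4.54e-05 1.0]
```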


Appendix B: Proof of correctness of online softmax

I have reproduced the proof by Yi Wang in FlashAttention (Part 2): Online Softmax below.

The definition of softmax is as follows:

\left\{ \frac{\exp(x_i)}{\sum_{j=1}^N \exp(x_j)} \right\}_{i=1}^N

If any x_i is large (e.g. x_i \geq 12 , since \exp(12) \approx 162,755 ), \exp(x_i) exceeds the maximum value of float16. To address this numerical instability, we compute an alternative form which gives an equivalent result but is numerically stable:

\left\{ \frac{\exp(x_i - m)}{\sum_{j=1}^N \exp(x_j - m)} \right\}_{i=1}^N

where m = \max_{j=1}^N x_j . This form is safe because x_i - m \le 0 , ensuring that 0 < \exp(x_i - m) \le 1 .

As mentioned in this article, online softmax seeks to allow the calculation of the max value m and the denominator \sum_{j=1}^N \exp(x_j - m) in parallel. More concretely, for any integer 1 \leq i \leq N , we want to be able to:

  1. Calculate an intermediate \delta_i = \sum_{j=1}^i \exp(x_j - m_i) so that \delta_N = \sum_{j=1}^N \exp(x_j - m_N) .

  2. Since \delta_i is inductive, it should depend on \delta_{i-1} .

  3. To allow parallel execution, \delta_i must not depend on future values such as x_{i+1}, \dots or m_{i+1}, \dots .

We begin by considering:

\delta_i = \sum_{j=1}^i \exp(x_j - m_i)

To ensure \delta_i depends on \delta_{i-1} , which is:

\delta_{i-1} = \sum_{j=1}^{i-1} \exp(x_j - m_{i-1})

we need to split \delta_i into two parts: one involving \delta_{i-1} (which should not depend on x_i or m_i ), and the remaining terms that depend on x_i and m_i . The first step is straightforward – we separate the last term in the summation:

\delta_i = \sum_{j=1}^{i-1} \exp(x_j - m_i) + \exp(x_i - m_i)

Now, x_i only appears in the second term. However, m_i still appears in the summation. Let's take the next step:

\begin{aligned} \delta_i &= \sum_{j=1}^{i-1} \exp(x_j - m_{i-1} + m_{i-1} - m_i) + \exp(x_i - m_i) \newline &= \left[ \sum_{j=1}^{i-1} \exp(x_j - m_{i-1}) \right] \exp(m_{i-1} - m_i) + \exp(x_i - m_i) \end{aligned}

The expression inside the square brackets is exactly \delta_{i-1} . Therefore, we have:

\delta_i = \delta_{i-1} \exp(m_{i-1} - m_i) + \exp(x_i - m_i)

This allows us to compute \delta_i inductively in parallel with m_i .
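The induction is easy to verify numerically: after each step, the recurrence for \delta_i should equal the direct sum \sum_{j \le i} \exp(x_j - m_i) . A small check on the example vector from this article (my own sketch):

```python
import math

x = [1.0, 2.0, 3.0, 6.0, 2.0, 1.0]

delta, m = 0.0, -math.inf
for i, xi in enumerate(x, start=1):
    m_new = max(m, xi)
    delta = delta * math.exp(m - m_new) + math.exp(xi - m_new)
    m = m_new
    # Invariant: delta equals the sum of exp(x_j - m_i) over j = 1..i
    direct = sum(math.exp(xj - m) for xj in x[:i])
    assert math.isclose(delta, direct)

print(m, delta)  # 6.0, ~1.0999
```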
