DEV Community

Cover image for 什么是Online Softmax and Flash Attention?
Jim Wang
Jim Wang

Posted on

什么是Online Softmax and Flash Attention?

Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。

Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。

那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。

Softmax的数学表达

首先,softmax作用在一维vector上,而非多维tensor。

softmax(X[1:N])=exij1Nexj1 softmax( \vec{X} [1:N] ) = \frac {e^{x_i}} {\sum_{j-1}^{N} e^{x_j} }(1)

我们需要(1)先计算每个xe^x并进行加总得到总和,(2)然后计算除法。

Safe softmax

但是由于e^x的特性,当x相对较大时,e^x就容易溢出,尤其是使用float16甚至更低精度的浮点表达时。因此我们需要对X进行normalization(归一化),并称之为saft softmax。

softmax(X[1:N])=eximax(X)j1Nexjmax(X)2 softmax( \vec{X} [1:N] ) = \frac {e^{x_i - \max(\vec{X})}} {\sum_{j-1}^{N} e^{x_j - \max(\vec{X})} } (2)

因此,softmax的三遍读取为(1)统计最大值(2)加总(3)除法

如果用Python来表示的话:

vec_x = [random.random() for _ in range(N)]
# 1st pass
max_x = max(vec_x)
# 2nd pass
sum_ex = sum([math.exp(x - max_x) for x in vec_x])
# 3rd pass
softmax_x = [math.exp(x - max_x) / sum_ex for x in vec_x]
Enter fullscreen mode Exit fullscreen mode

Online softmax

但是,我们其实可以将(1)和(2)合并,在寻找最大值的同时进行加总操作。

如果我们能够构造两个数组MS,它们的单元数量和X一样:

mi=max(x[1:i])si=j=1iexjmi3 m_i = \max(\vec{x}[1:i]) \newline s_i = \sum_{j=1}^{i}{e^{x_j - m_i}} (3)

即,数组M保存的是数组X到当前index为止的最大值,而数组S保存的是类似softmax的加总(除数)。数组S的所有数值都和softmax的第二步加总的出来的结果的不一样,除了最后一个。也就是说,如果我们一个个读取数组X,并计算数组M和数组S,在计算到最后一个的时候,数组Sd_N正好等于softmax的除数。这样我们只需要一遍读取就能完成步骤(1)和步骤(2)。我们选择符号m来表示max最大值;选择符号s来表示sum加总。

这种奇思妙想不得不令我们想起dynamic programming。

如果我们展开数组S的计算过程:

s1=14 s_1 = 1 (4)
si=si1×emi1mi+eximi5 s_i = s_{i-1} \times e^{m_{i-1} - m_i} + e^{x_i - m_i} (5)
sN=j=1NexjmN6 s_N = \sum_{j=1}^{N} e^{x_j - m_N} (6)

我们可以证明公式6,即当i=N的时候,s_i即为softmax中加总之后的结果。

sN=sN1×emN1mN+exNmN=j=1N1(exjmN1×emN1mN)+exNmN=j=1N1exjmN+exNmN=j=1NexjmN7 s_N = s_{N-1} \times e^{m_{N-1} - m_N} + e^{x_N - m_N} \newline = \sum_{j=1}^{N-1}{(e^{x_j - m_{N-1}} \times e^{m_{N-1} - m_N})} + e^{x_N - m_N} \newline = \sum_{j=1}^{N-1}{e^{x_j - m_{N}} } + e^{x_N - m_N} \newline = \sum_{j=1}^{N} e^{x_j - m_N} (7)

如果用Python来表示的话:

vec_x = [random.random() for _ in range(N)]

s_prev = 1 # previous s_{i-1}
m_prev = vec_x[0] # previous m_{i-1}
for i in range(1, N): # starting from the 2nd element
  x = vec_x[i] # read x
  m = max(m_prev, x) # calculate m_i
  s = s_prev * math.exp(m_prev - m) + math.exp(x - m) # calculate s_i
  s_prev = s # save as s_{i-1}
Enter fullscreen mode Exit fullscreen mode

Attention的数学表达

A=softmax(mask(QKTK))V8 \mathbb{A} = softmax(mask(\frac{\mathbb{Q} \cdot \mathbb{K}^T}{\sqrt{K}})) \cdot \mathbb{V} (8)

其中A, Q, K, V都是二维matrix,A, Q, K, V的shape分别依次是[N,K], [N,K], [M,K], [M,K]。但对于Transformer架构而言,N == Mmask() 函数旨在加载self-attention或者其他种类的masking,但对masking的探讨超越了本文范围,且因为食element-wise操作,所以不改变本文讨论的内容。

为什么要除以常数√K?
同样是为了归一化,这样在进行 Q·Kᵀ 时(1)不会溢出(2)分布过于极端导致softmax溢出(3)在训练过程中造成不稳定。

简化11/√K scaling通常是在进行点乘运算前对Q或者K进行,因此在后面的讨论中不再赘述。

简化2:以后的讨论我们还会简化masking的步骤。

经过简化,同时寻找X和Y符号来表示中间步骤的变量,attention可以表达为:

softmax(QKT)V=softmax(X)V=YV=A9 softmax(\mathbb{Q} \cdot \mathbb{K}^T) \cdot \mathbb{V} = softmax(\mathbb{X}) \cdot \mathbb{V} = \mathbb{Y} \cdot \mathbb{V} = \mathbb{A} (9)

选择符号X和x来表示softmax的自变量,与之前对于softmax的讨论统一;选择y来表示softmax的结果(可惜s已经被用了);选择A来表示attention的结果。

注意:同时我们需要了解,在LLM中,N和M代表的是 batch_size * num_head * seq_length,而K代表的是head_dim。因此N和M是可以运行时动态变量(runtime dynamic variable),而K是编译时静态变量(compile-time static variable)。这个理解对于进行性能优化非常重要。

应用online softmax

如果将online softmax应用在attention上,同时将有如下两遍循环操作:

循环(1)

完全套用online softmax的操作:

xn,m=X[n,m]=Q[n,:]KT[:,m]10 x_{n,m} = \mathbb{X}[n,m] = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (10)

注意:此处的符号数量限制,为避免误导,将 n, m, k 分别定义为在N, M, K 三个dimension上的index。

可以看出,虽然matrix S是的shape为[N,K],但是softmax只作用于其inner dimension。也就是说其outer dimension是只是重复的循环操作。因此为了简便,我们只看内循环操作,即矩阵X的一行操作:

xm=Q[n,:]KT[:,m]11 x_m = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (11)

其中的符号 n 指代外循环变量 for n in range(N) ,同时符号 m 指代内循环变量 for m in range(M)

gm=max(gm1,xm)12 g_m = max(g_{m-1}, x_m) (12)
sm=sm1×egm1gm+exmgm13 s_{m} = s_{m-1} \times e^{g_{m-1} - g_m} + e^{x_m - g_m} (13)

此处由于符号m被使用了,我们将max()的结果从符号m换成符号g(greatest,因为符号x, a, i都被使用了)

循环(2)

将类似思想套用在 Y•V 这个步骤上,则可以得到:

ym=exmgMsM14 y_m = \frac{e^{x_m - g_M}}{s_M} (14)
A0=0Am=Am1+ym×V[m,:]15 \vec{A_0} = 0 \newline \vec{A_m} = \vec{A_{m-1}} + y_m \times \mathbb{V}[m,:] (15)

需要再次指出的是: m 指代内循环变量 for m in range(M),而 k 指代外循环变量 for k in range(K)。因此(15)计算的是一个vector而不是scalar。这么做的好处就是每次计算出来的 y_m 可以用完就丢弃,同时这个表达方式可以引出我们之后对flash attention的讨论。

如果你跟我一样曾经觉得困惑的话,其实(15)并不难理解,只不过是将矩阵乘法的内外循环互换并将中间结果保存起来。

首先来看矩阵乘法:

A=YVA[n,k]=Y[n,:]V[:,k]=m=1M(Y[n,m]×V[m,k])16 \mathbb{A} = \mathbb{Y} \cdot \mathbb{V} \newline \mathbb{A}[n,k] = \mathbb{Y}[n,:] \cdot \mathbb{V}[:,k] = \sum_{m=1}^M(\mathbb{Y}[n,m] \times \mathbb{V}[m,k]) (16)

matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
    for k in range(K):
        for m in range(M):
            matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
Enter fullscreen mode Exit fullscreen mode

如果我们调换循环K和循环M,结果一致,只是计算顺序变了。我们从一个一个计算A的element,变成了先计算A的一行的中间值(partial sum)然后加总。

matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
    for m in range(M):
        for k in range(K):
            matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
Enter fullscreen mode Exit fullscreen mode

那么数学表示则变成:

A[n,:]=m=1M(Y[n,m]×V[m,:])17 \mathbb{A}[n,:] = \sum_{m=1}^M( \mathbb{Y}[n,m] \times \mathbb{V}[m,:]) (17)

而(15)中的 y_m 则是当前第n行的Y[n,m],因此:

A[n,:]=AM=m=1M(ym×V[m,:])=m=1M(exmgMsM×Vm)18 \mathbb{A}[n,:] = \vec{A_M} = \sum_{m=1}^M(y_m \times \mathbb{V}[m,:]) = \sum_{m=1}^M(\frac{e^{x_m - g_M}}{s_M} \times \vec{V_m}) (18)

所以循环(2)变成了:

Am=m=1M(exmgMsM×Vm)19 \vec{A_m} = \sum_{m=1}^M(\frac{e^{x_m - g_M}}{s_M} \times \vec{V_m}) (19)

优化循环(2)

我们利用oneline softmax同样的思想,将计算 A_M 的过程转变成一个循环即A_m = func(A_{m-1}),然后构造一个新的A'_m来去掉数据依赖,因为我们只需要保证A'_M == A_M即可,而不需要中间过程相等。

Am=i=1m(exigmsm×Vi)20 \vec{A'm} = \sum{i=1}^m(\frac{e^{x_i - g_m}}{s_m} \times \vec{V_i}) (20)

m=M 时,则有:

AM=i=1M(exigMsM×Vi)=AM21 \vec{A'M} = \sum{i=1}^M(\frac{e^{x_i - g_M}}{s_M} \times \vec{V_i}) = \vec{A_M} (21)

如果比较(20)和(19),仅仅是将m替换成了i,将M替换成了m,但是本质原理是用循环计算的方式来避免写入内存(avoid materialize matrix)。

如果我们将(20)展开,则可以得到:

A1=ex1g1s1×V1Am=Am1×sm1×egm1gmsm+exmgmsm×Vm22 \vec{A'1} = \frac{e^{x_1-g_1}}{s_1} \times \vec{V_1} \newline \vec{A'_m} = \vec{A'{m-1}} \times \frac{s_{m-1} \times e^{g_{m-1} - g_m}}{s_m} + \frac{e^{x_m-g_m}}{s_m} \times \vec{V_m} (22)

合并循环(1)和优化后的循环(2)

因为两者的思想和应用方法是一样的,即在循环中计算中间值,仅保证最终结果与目标一致即可,因此我们可以将其合并为一个循环。同时我们将其表达为矩阵乘法形式。以下为同一循环中的四个步骤:

(1) Score

X[n,m]=Q[n,:]KT[:,m]23 \mathbb{X}[n,m] = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (23)

(2) Max

G[n,1]=X[n,1]G[n,m]=max(G[n,m1],X[n,m])24 \mathbb{G}[n,1] = \mathbb{X}[n,1] \newline \mathbb{G}[n,m] = max(\mathbb{G}[n,m-1], \mathbb{X}[n,m]) (24)

(3) Sum

S[n,1]=1S[n,m]=S[n,m1]×eG[n,m1]G[n,m]+eX[n,m]G[n,m]25 \mathbb{S}[n,1] = 1 \newline \mathbb{S}[n,m] = \mathbb{S}[n,m-1] \times e^{\mathbb{G}[n,m-1] - \mathbb{G}[n,m]} + e^{\mathbb{X}[n,m] - \mathbb{G}[n,m]} (25)

(4) Attention

A1[n,:]=eX[n,1]G[n,1]S[n,1]×V[1,:]Am[n,:]=Am1[n,:]×S[n,m1]×eG[n,m1]G[n,m]S[n,m]+eX[n,m]G[n,m]S[n,m]×V[m,:]26 \mathbb{A}1 [n,:] = \frac{e^{\mathbb{X}[n,1] - \mathbb{G}[n,1]}}{\mathbb{S}[n,1]} \times \mathbb{V}[1,:] \newline \mathbb{A}_m [n,:] = \mathbb{A}{m-1} [n,:] \times \frac{\mathbb{S}[n,m-1] \times e^{\mathbb{G}[n,m-1] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} + \frac{e^{\mathbb{X}[n,m] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} \times \mathbb{V}[m,:] (26)

用Python表示:

matrix_q = [[random.random() for _ in range(K)] for _ in range(N)]
matrix_k = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_v = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_a = [[0.0 for _ in range(K)] for _ in range(M)]

for n in range(N):
  for m in range(M):
    x = 0.0
    for k in range(K):
      x += matrix_q[n,k] * matrix_k[m,k]

    g = max(g_prev, x) if m > 0 else x

    e = math.exp(x - g) if m > 0 else 1.0
    s = s_prev * math.exp(g_prev - g) + e if m > 0 else 1.0

    for k in range(K):
      if m > 0:
        matrix_a[n,k] = (matrix_a[n,k] * s_prev * math.exp(g_prev - g) + e * matrix_v[m,k] ) / s
      else:
        matrix_a[n,k] = e * matrix_v[0,k] / s

    g_prev = g
    s_prev = s
Enter fullscreen mode Exit fullscreen mode

Top comments (0)