Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。
Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。
那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。
Softmax的数学表达
首先,softmax作用在一维vector上,而非多维tensor。
s o f t m a x ( X ⃗ [ 1 : N ] ) = e x i ∑ j − 1 N e x j ( 1 )
softmax( \vec{X} [1:N] ) = \frac {e^{x_i}} {\sum_{j-1}^{N} e^{x_j} }(1)
so f t ma x ( X [ 1 : N ]) = ∑ j − 1 N e x j e x i ( 1 )
我们需要(1)先计算每个x 的e^x 并进行加总得到总和,(2)然后计算除法。
Safe softmax
但是由于e^x 的特性,当x 相对较大时,e^x 就容易溢出,尤其是使用float16甚至更低精度的浮点表达时。因此我们需要对X 进行normalization(归一化),并称之为saft softmax。
s o f t m a x ( X ⃗ [ 1 : N ] ) = e x i − max ( X ⃗ ) ∑ j − 1 N e x j − max ( X ⃗ ) ( 2 )
softmax( \vec{X} [1:N] ) = \frac {e^{x_i - \max(\vec{X})}} {\sum_{j-1}^{N} e^{x_j - \max(\vec{X})} } (2)
so f t ma x ( X [ 1 : N ]) = ∑ j − 1 N e x j − m a x ( X ) e x i − m a x ( X ) ( 2 )
因此,softmax的三遍读取为(1)统计最大值(2)加总(3)除法
如果用Python来表示的话:
vec_x = [ random . random () for _ in range ( N )]
# 1st pass
max_x = max ( vec_x )
# 2nd pass
sum_ex = sum ([ math . exp ( x - max_x ) for x in vec_x ])
# 3rd pass
softmax_x = [ math . exp ( x - max_x ) / sum_ex for x in vec_x ]
Enter fullscreen mode
Exit fullscreen mode
Online softmax
但是,我们其实可以将(1)和(2)合并,在寻找最大值的同时进行加总操作。
如果我们能够构造两个数组M 和S ,它们的单元数量和X 一样:
m i = max ( x ⃗ [ 1 : i ] ) s i = ∑ j = 1 i e x j − m i ( 3 )
m_i = \max(\vec{x}[1:i])
\newline
s_i = \sum_{j=1}^{i}{e^{x_j - m_i}} (3)
m i = max ( x [ 1 : i ]) s i = j = 1 ∑ i e x j − m i ( 3 )
即,数组M 保存的是数组X 到当前index为止的最大值,而数组S 保存的是类似 softmax的加总(除数)。数组S 的所有数值都和softmax的第二步加总的出来的结果的不一样,除了最后一个。也就是说,如果我们一个个读取数组X,并计算数组M 和数组S ,在计算到最后一个的时候,数组S 的d_N 正好等于softmax的除数。这样我们只需要一遍读取就能完成步骤(1)和步骤(2)。我们选择符号m来表示max最大值;选择符号s来表示sum加总。
这种奇思妙想不得不令我们想起dynamic programming。
如果我们展开数组S 的计算过程:
s 1 = 1 ( 4 )
s_1 = 1 (4)
s 1 = 1 ( 4 )
s i = s i − 1 × e m i − 1 − m i + e x i − m i ( 5 )
s_i = s_{i-1} \times e^{m_{i-1} - m_i} + e^{x_i - m_i} (5)
s i = s i − 1 × e m i − 1 − m i + e x i − m i ( 5 )
s N = ∑ j = 1 N e x j − m N ( 6 )
s_N = \sum_{j=1}^{N} e^{x_j - m_N} (6)
s N = j = 1 ∑ N e x j − m N ( 6 )
我们可以证明公式6,即当i=N 的时候,s_i 即为softmax中加总之后的结果。
s N = s N − 1 × e m N − 1 − m N + e x N − m N = ∑ j = 1 N − 1 ( e x j − m N − 1 × e m N − 1 − m N ) + e x N − m N = ∑ j = 1 N − 1 e x j − m N + e x N − m N = ∑ j = 1 N e x j − m N ( 7 )
s_N = s_{N-1} \times e^{m_{N-1} - m_N} + e^{x_N - m_N} \newline
= \sum_{j=1}^{N-1}{(e^{x_j - m_{N-1}} \times e^{m_{N-1} - m_N})} + e^{x_N - m_N} \newline
= \sum_{j=1}^{N-1}{e^{x_j - m_{N}} } + e^{x_N - m_N} \newline
= \sum_{j=1}^{N} e^{x_j - m_N} (7)
s N = s N − 1 × e m N − 1 − m N + e x N − m N = j = 1 ∑ N − 1 ( e x j − m N − 1 × e m N − 1 − m N ) + e x N − m N = j = 1 ∑ N − 1 e x j − m N + e x N − m N = j = 1 ∑ N e x j − m N ( 7 )
如果用Python来表示的话:
vec_x = [ random . random () for _ in range ( N )]
s_prev = 1 # previous s_{i-1}
m_prev = vec_x [ 0 ] # previous m_{i-1}
for i in range ( 1 , N ): # starting from the 2nd element
x = vec_x [ i ] # read x
m = max ( m_prev , x ) # calculate m_i
s = s_prev * math . exp ( m_prev - m ) + math . exp ( x - m ) # calculate s_i
s_prev = s # save as s_{i-1}
Enter fullscreen mode
Exit fullscreen mode
Attention的数学表达
A = s o f t m a x ( m a s k ( Q ⋅ K T K ) ) ⋅ V ( 8 )
\mathbb{A} = softmax(mask(\frac{\mathbb{Q} \cdot \mathbb{K}^T}{\sqrt{K}})) \cdot \mathbb{V} (8)
A = so f t ma x ( ma s k ( K Q ⋅ K T )) ⋅ V ( 8 )
其中A, Q, K, V都是二维matrix,A, Q, K, V的shape分别依次是[N,K], [N,K], [M,K], [M,K]。但对于Transformer架构而言,N == M。mask() 函数旨在加载self-attention或者其他种类的masking,但对masking的探讨超越了本文范围,且因为食element-wise操作,所以不改变本文讨论的内容。
为什么要除以常数√K ?
同样是为了归一化,这样在进行 Q·Kᵀ 时(1)不会溢出(2)分布过于极端导致softmax溢出(3)在训练过程中造成不稳定。
简化1 :1/√K scaling通常是在进行点乘运算前对Q或者K进行,因此在后面的讨论中不再赘述。
简化2 :以后的讨论我们还会简化masking的步骤。
经过简化,同时寻找X和Y符号来表示中间步骤的变量,attention可以表达为:
s o f t m a x ( Q ⋅ K T ) ⋅ V = s o f t m a x ( X ) ⋅ V = Y ⋅ V = A ( 9 )
softmax(\mathbb{Q} \cdot \mathbb{K}^T) \cdot \mathbb{V} = softmax(\mathbb{X}) \cdot \mathbb{V} = \mathbb{Y} \cdot \mathbb{V} = \mathbb{A} (9)
so f t ma x ( Q ⋅ K T ) ⋅ V = so f t ma x ( X ) ⋅ V = Y ⋅ V = A ( 9 )
选择符号X和x来表示softmax的自变量,与之前对于softmax的讨论统一;选择y来表示softmax的结果(可惜s已经被用了);选择A来表示attention的结果。
注意 :同时我们需要了解,在LLM中,N和M代表的是 batch_size * num_head * seq_length,而K代表的是head_dim。因此N和M是可以运行时动态变量(runtime dynamic variable),而K是编译时静态变量(compile-time static variable)。这个理解对于进行性能优化非常重要。
应用online softmax
如果将online softmax应用在attention上,同时将有如下两遍循环操作:
循环(1)
完全套用online softmax的操作:
x n , m = X [ n , m ] = Q [ n , : ] ⋅ K T [ : , m ] ( 10 )
x_{n,m} = \mathbb{X}[n,m] = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (10)
x n , m = X [ n , m ] = Q [ n , : ] ⋅ K T [ : , m ] ( 10 )
注意:此处的符号数量限制,为避免误导,将 n, m, k 分别定义为在N, M, K 三个dimension上的index。
可以看出,虽然matrix S是的shape为[N,K],但是softmax只作用于其inner dimension。也就是说其outer dimension是只是重复的循环操作。因此为了简便,我们只看内循环操作,即矩阵X的一行操作:
x m = Q [ n , : ] ⋅ K T [ : , m ] ( 11 )
x_m = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (11)
x m = Q [ n , : ] ⋅ K T [ : , m ] ( 11 )
其中的符号 n 指代外循环变量 for n in range(N) ,同时符号 m 指代内循环变量 for m in range(M)
g m = m a x ( g m − 1 , x m ) ( 12 )
g_m = max(g_{m-1}, x_m) (12)
g m = ma x ( g m − 1 , x m ) ( 12 )
s m = s m − 1 × e g m − 1 − g m + e x m − g m ( 13 )
s_{m} = s_{m-1} \times e^{g_{m-1} - g_m} + e^{x_m - g_m} (13)
s m = s m − 1 × e g m − 1 − g m + e x m − g m ( 13 )
此处由于符号m被使用了,我们将max()的结果从符号m换成符号g(greatest,因为符号x, a, i都被使用了)
循环(2)
将类似思想套用在 Y•V 这个步骤上,则可以得到:
y m = e x m − g M s M ( 14 )
y_m = \frac{e^{x_m - g_M}}{s_M} (14)
y m = s M e x m − g M ( 14 )
A 0 ⃗ = 0 A m ⃗ = A m − 1 ⃗ + y m × V [ m , : ] ( 15 )
\vec{A_0} = 0 \newline
\vec{A_m} = \vec{A_{m-1}} + y_m \times \mathbb{V}[m,:] (15)
A 0 = 0 A m = A m − 1 + y m × V [ m , : ] ( 15 )
需要再次指出的是: m 指代内循环变量 for m in range(M),而 k 指代外循环变量 for k in range(K)。因此(15)计算的是一个vector而不是scalar。这么做的好处就是每次计算出来的 y_m 可以用完就丢弃,同时这个表达方式可以引出我们之后对flash attention的讨论。
如果你跟我一样曾经觉得困惑的话,其实(15)并不难理解,只不过是将矩阵乘法的内外循环互换并将中间结果保存起来。
首先来看矩阵乘法:
A = Y ⋅ V A [ n , k ] = Y [ n , : ] ⋅ V [ : , k ] = ∑ m = 1 M ( Y [ n , m ] × V [ m , k ] ) ( 16 )
\mathbb{A} = \mathbb{Y} \cdot \mathbb{V} \newline
\mathbb{A}[n,k] = \mathbb{Y}[n,:] \cdot \mathbb{V}[:,k] = \sum_{m=1}^M(\mathbb{Y}[n,m] \times \mathbb{V}[m,k]) (16)
A = Y ⋅ V A [ n , k ] = Y [ n , : ] ⋅ V [ : , k ] = m = 1 ∑ M ( Y [ n , m ] × V [ m , k ]) ( 16 )
matrix_a [:,:] = 0 # initialize matrix A to be zeros
for n in range ( N ):
for k in range ( K ):
for m in range ( M ):
matrix_a [ n , k ] += matrix_y [ n , m ] * matrix_v [ m , k ]
Enter fullscreen mode
Exit fullscreen mode
如果我们调换循环K和循环M,结果一致,只是计算顺序变了。我们从一个一个计算A的element,变成了先计算A的一行的中间值(partial sum)然后加总。
matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
for m in range(M):
for k in range(K):
matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
Enter fullscreen mode
Exit fullscreen mode
那么数学表示则变成:
A [ n , : ] = ∑ m = 1 M ( Y [ n , m ] × V [ m , : ] ) ( 17 )
\mathbb{A}[n,:] = \sum_{m=1}^M( \mathbb{Y}[n,m] \times \mathbb{V}[m,:]) (17)
A [ n , : ] = m = 1 ∑ M ( Y [ n , m ] × V [ m , : ]) ( 17 )
而(15)中的 y_m 则是当前第n行的Y[n,m],因此:
A [ n , : ] = A M ⃗ = ∑ m = 1 M ( y m × V [ m , : ] ) = ∑ m = 1 M ( e x m − g M s M × V m ⃗ ) ( 18 )
\mathbb{A}[n,:] = \vec{A_M} = \sum_{m=1}^M(y_m \times \mathbb{V}[m,:]) = \sum_{m=1}^M(\frac{e^{x_m - g_M}}{s_M} \times \vec{V_m}) (18)
A [ n , : ] = A M = m = 1 ∑ M ( y m × V [ m , : ]) = m = 1 ∑ M ( s M e x m − g M × V m ) ( 18 )
所以循环(2)变成了:
A m ⃗ = ∑ m = 1 M ( e x m − g M s M × V m ⃗ ) ( 19 )
\vec{A_m} = \sum_{m=1}^M(\frac{e^{x_m - g_M}}{s_M} \times \vec{V_m}) (19)
A m = m = 1 ∑ M ( s M e x m − g M × V m ) ( 19 )
优化循环(2)
我们利用oneline softmax同样的思想,将计算 A_M 的过程转变成一个循环即A_m = func(A_{m-1}),然后构造一个新的A'_m来去掉数据依赖,因为我们只需要保证A'_M == A_M即可,而不需要中间过程相等。
A ′ m ⃗ = ∑ i = 1 m ( e x i − g m s m × V i ⃗ ) ( 20 )
\vec{A'm} = \sum{i=1}^m(\frac{e^{x_i - g_m}}{s_m} \times \vec{V_i}) (20)
A ′ m = ∑ i = 1 m ( s m e x i − g m × V i ) ( 20 )
当 m=M 时,则有:
A ′ M ⃗ = ∑ i = 1 M ( e x i − g M s M × V i ⃗ ) = A M ⃗ ( 21 )
\vec{A'M} = \sum{i=1}^M(\frac{e^{x_i - g_M}}{s_M} \times \vec{V_i}) = \vec{A_M} (21)
A ′ M = ∑ i = 1 M ( s M e x i − g M × V i ) = A M ( 21 )
如果比较(20)和(19),仅仅是将m替换成了i,将M替换成了m,但是本质原理是用循环计算的方式来避免写入内存(avoid materialize matrix)。
如果我们将(20)展开,则可以得到:
A ′ 1 ⃗ = e x 1 − g 1 s 1 × V 1 ⃗ A m ′ ⃗ = A ′ m − 1 ⃗ × s m − 1 × e g m − 1 − g m s m + e x m − g m s m × V m ⃗ ( 22 )
\vec{A'1} = \frac{e^{x_1-g_1}}{s_1} \times \vec{V_1} \newline
\vec{A'_m} = \vec{A'{m-1}} \times \frac{s_{m-1} \times e^{g_{m-1} - g_m}}{s_m} + \frac{e^{x_m-g_m}}{s_m} \times \vec{V_m} (22)
A ′ 1 = s 1 e x 1 − g 1 × V 1 A m ′ = A ′ m − 1 × s m s m − 1 × e g m − 1 − g m + s m e x m − g m × V m ( 22 )
合并循环(1)和优化后的循环(2)
因为两者的思想和应用方法是一样的,即在循环中计算中间值,仅保证最终结果与目标一致即可,因此我们可以将其合并为一个循环。同时我们将其表达为矩阵乘法形式。以下为同一循环中的四个步骤:
(1) Score
X [ n , m ] = Q [ n , : ] ⋅ K T [ : , m ] ( 23 )
\mathbb{X}[n,m] = \mathbb{Q}[n,:] \cdot \mathbb{K}^T[:,m] (23)
X [ n , m ] = Q [ n , : ] ⋅ K T [ : , m ] ( 23 )
(2) Max
G [ n , 1 ] = X [ n , 1 ] G [ n , m ] = m a x ( G [ n , m − 1 ] , X [ n , m ] ) ( 24 )
\mathbb{G}[n,1] = \mathbb{X}[n,1]
\newline
\mathbb{G}[n,m] = max(\mathbb{G}[n,m-1], \mathbb{X}[n,m]) (24)
G [ n , 1 ] = X [ n , 1 ] G [ n , m ] = ma x ( G [ n , m − 1 ] , X [ n , m ]) ( 24 )
(3) Sum
S [ n , 1 ] = 1 S [ n , m ] = S [ n , m − 1 ] × e G [ n , m − 1 ] − G [ n , m ] + e X [ n , m ] − G [ n , m ] ( 25 )
\mathbb{S}[n,1] = 1
\newline
\mathbb{S}[n,m] = \mathbb{S}[n,m-1] \times e^{\mathbb{G}[n,m-1] - \mathbb{G}[n,m]} + e^{\mathbb{X}[n,m] - \mathbb{G}[n,m]} (25)
S [ n , 1 ] = 1 S [ n , m ] = S [ n , m − 1 ] × e G [ n , m − 1 ] − G [ n , m ] + e X [ n , m ] − G [ n , m ] ( 25 )
(4) Attention
A 1 [ n , : ] = e X [ n , 1 ] − G [ n , 1 ] S [ n , 1 ] × V [ 1 , : ] A m [ n , : ] = A m − 1 [ n , : ] × S [ n , m − 1 ] × e G [ n , m − 1 ] − G [ n , m ] S [ n , m ] + e X [ n , m ] − G [ n , m ] S [ n , m ] × V [ m , : ] ( 26 )
\mathbb{A}1 [n,:] = \frac{e^{\mathbb{X}[n,1] - \mathbb{G}[n,1]}}{\mathbb{S}[n,1]} \times \mathbb{V}[1,:]
\newline
\mathbb{A}_m [n,:] = \mathbb{A}{m-1} [n,:] \times \frac{\mathbb{S}[n,m-1] \times e^{\mathbb{G}[n,m-1] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} + \frac{e^{\mathbb{X}[n,m] - \mathbb{G}[n,m]}}{\mathbb{S}[n,m]} \times \mathbb{V}[m,:] (26)
A 1 [ n , : ] = S [ n , 1 ] e X [ n , 1 ] − G [ n , 1 ] × V [ 1 , : ] A m [ n , : ] = A m − 1 [ n , : ] × S [ n , m ] S [ n , m − 1 ] × e G [ n , m − 1 ] − G [ n , m ] + S [ n , m ] e X [ n , m ] − G [ n , m ] × V [ m , : ] ( 26 )
用Python表示:
matrix_q = [[ random . random () for _ in range ( K )] for _ in range ( N )]
matrix_k = [[ random . random () for _ in range ( K )] for _ in range ( M )]
matrix_v = [[ random . random () for _ in range ( K )] for _ in range ( M )]
matrix_a = [[ 0.0 for _ in range ( K )] for _ in range ( M )]
for n in range ( N ):
for m in range ( M ):
x = 0.0
for k in range ( K ):
x += matrix_q [ n , k ] * matrix_k [ m , k ]
g = max ( g_prev , x ) if m > 0 else x
e = math . exp ( x - g ) if m > 0 else 1.0
s = s_prev * math . exp ( g_prev - g ) + e if m > 0 else 1.0
for k in range ( K ):
if m > 0 :
matrix_a [ n , k ] = ( matrix_a [ n , k ] * s_prev * math . exp ( g_prev - g ) + e * matrix_v [ m , k ] ) / s
else :
matrix_a [ n , k ] = e * matrix_v [ 0 , k ] / s
g_prev = g
s_prev = s
Enter fullscreen mode
Exit fullscreen mode
Top comments (0)