Softmax是Transformer模型架构中非常重要的一环。它所在的Attention模块虽然所需要的计算量不大,但也是不容忽视的一环。同时由于它本身的数学特性所造成的数据依赖,如果按照其原始方法来进行运算,会耗费大量的计算时间,因为它需要三次完整读取数据。
Online normalizer calculation for softmax 提出了online softmax,通过牺牲计算来节省数据读取次数,将三遍完整读取(3 passes)降低到两遍完整读取(2 passes)。
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 则应用了类似的思想,更进一步,利用NVIDIA GPU的本地存储,将读取次数减少为一遍。
那么就让我们详细了解一下其中奥秘吧。本文会出现数学公式,但不要慌张,仅仅是简单的数组知识而已。同时也会辅以可以运行的Python代码,以便理解。
Softmax的数学表达
首先,softmax作用在一维vector上,而非多维tensor。
我们需要(1)先计算每个x的e^x并进行加总得到总和,(2)然后计算除法。
Safe softmax
但是由于e^x的特性,当x相对较大时,e^x就容易溢出,尤其是使用float16甚至更低精度的浮点表达时。因此我们需要对X进行normalization(归一化),并称之为saft softmax。
因此,softmax的三遍读取为(1)统计最大值(2)加总(3)除法
如果用Python来表示的话:
vec_x = [random.random() for _ in range(N)]
# 1st pass
max_x = max(vec_x)
# 2nd pass
sum_ex = sum([math.exp(x - max_x) for x in vec_x])
# 3rd pass
softmax_x = [math.exp(x - max_x) / sum_ex for x in vec_x]
Online softmax
但是,我们其实可以将(1)和(2)合并,在寻找最大值的同时进行加总操作。
如果我们能够构造两个数组M和S,它们的单元数量和X一样:
即,数组M保存的是数组X到当前index为止的最大值,而数组S保存的是类似softmax的加总(除数)。数组S的所有数值都和softmax的第二步加总的出来的结果的不一样,除了最后一个。也就是说,如果我们一个个读取数组X,并计算数组M和数组S,在计算到最后一个的时候,数组S的d_N正好等于softmax的除数。这样我们只需要一遍读取就能完成步骤(1)和步骤(2)。我们选择符号m来表示max最大值;选择符号s来表示sum加总。
这种奇思妙想不得不令我们想起dynamic programming。
如果我们展开数组S的计算过程:
我们可以证明公式6,即当i=N的时候,s_i即为softmax中加总之后的结果。
如果用Python来表示的话:
vec_x = [random.random() for _ in range(N)]
s_prev = 1 # previous s_{i-1}
m_prev = vec_x[0] # previous m_{i-1}
for i in range(1, N): # starting from the 2nd element
x = vec_x[i] # read x
m = max(m_prev, x) # calculate m_i
s = s_prev * math.exp(m_prev - m) + math.exp(x - m) # calculate s_i
s_prev = s # save as s_{i-1}
Attention的数学表达
其中A, Q, K, V都是二维matrix,A, Q, K, V的shape分别依次是[N,K]
, [N,K]
, [M,K]
, [M,K]
。但对于Transformer架构而言,N == M
。mask()
函数旨在加载self-attention或者其他种类的masking,但对masking的探讨超越了本文范围,且因为食element-wise操作,所以不改变本文讨论的内容。
为什么要除以常数√K?
同样是为了归一化,这样在进行 Q·Kᵀ 时(1)不会溢出(2)分布过于极端导致softmax溢出(3)在训练过程中造成不稳定。
简化1:1/√K scaling通常是在进行点乘运算前对Q或者K进行,因此在后面的讨论中不再赘述。
简化2:以后的讨论我们还会简化masking的步骤。
经过简化,同时寻找X和Y符号来表示中间步骤的变量,attention可以表达为:
选择符号X和x来表示softmax的自变量,与之前对于softmax的讨论统一;选择y来表示softmax的结果(可惜s已经被用了);选择A来表示attention的结果。
注意:同时我们需要了解,在LLM中,N和M代表的是 batch_size * num_head * seq_length,而K代表的是head_dim。因此N和M是可以运行时动态变量(runtime dynamic variable),而K是编译时静态变量(compile-time static variable)。这个理解对于进行性能优化非常重要。
应用online softmax
如果将online softmax应用在attention上,同时将有如下两遍循环操作:
循环(1)
完全套用online softmax的操作:
注意:此处的符号数量限制,为避免误导,将 n
, m
, k
分别定义为在N
, M
, K
三个dimension上的index。
可以看出,虽然matrix S是的shape为[N,K]
,但是softmax只作用于其inner dimension。也就是说其outer dimension是只是重复的循环操作。因此为了简便,我们只看内循环操作,即矩阵X的一行操作:
其中的符号 n
指代外循环变量 for n in range(N)
,同时符号 m
指代内循环变量 for m in range(M)
此处由于符号m
被使用了,我们将max()
的结果从符号m
换成符号g
(greatest,因为符号x
, a
, i
都被使用了)
循环(2)
将类似思想套用在 Y•V 这个步骤上,则可以得到:
需要再次指出的是: m
指代内循环变量 for m in range(M)
,而 k
指代外循环变量 for k in range(K)
。因此(15)计算的是一个vector而不是scalar。这么做的好处就是每次计算出来的 y_m
可以用完就丢弃,同时这个表达方式可以引出我们之后对flash attention的讨论。
如果你跟我一样曾经觉得困惑的话,其实(15)并不难理解,只不过是将矩阵乘法的内外循环互换并将中间结果保存起来。
首先来看矩阵乘法:
matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
for k in range(K):
for m in range(M):
matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
如果我们调换循环K和循环M,结果一致,只是计算顺序变了。我们从一个一个计算A的element,变成了先计算A的一行的中间值(partial sum)然后加总。
matrix_a[:,:] = 0 # initialize matrix A to be zeros
for n in range(N):
for m in range(M):
for k in range(K):
matrix_a[n, k] += matrix_y[n, m] * matrix_v[m, k]
那么数学表示则变成:
而(15)中的 y_m
则是当前第n行的Y[n,m]
,因此:
所以循环(2)变成了:
优化循环(2)
我们利用oneline softmax同样的思想,将计算 A_M
的过程转变成一个循环即A_m = func(A_{m-1})
,然后构造一个新的A'_m
来去掉数据依赖,因为我们只需要保证A'_M == A_M
即可,而不需要中间过程相等。
当 m=M
时,则有:
如果比较(20)和(19),仅仅是将m
替换成了i
,将M
替换成了m
,但是本质原理是用循环计算的方式来避免写入内存(avoid materialize matrix)。
如果我们将(20)展开,则可以得到:
合并循环(1)和优化后的循环(2)
因为两者的思想和应用方法是一样的,即在循环中计算中间值,仅保证最终结果与目标一致即可,因此我们可以将其合并为一个循环。同时我们将其表达为矩阵乘法形式。以下为同一循环中的四个步骤:
(1) Score
(2) Max
(3) Sum
(4) Attention
用Python表示:
matrix_q = [[random.random() for _ in range(K)] for _ in range(N)]
matrix_k = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_v = [[random.random() for _ in range(K)] for _ in range(M)]
matrix_a = [[0.0 for _ in range(K)] for _ in range(M)]
for n in range(N):
for m in range(M):
x = 0.0
for k in range(K):
x += matrix_q[n,k] * matrix_k[m,k]
g = max(g_prev, x) if m > 0 else x
e = math.exp(x - g) if m > 0 else 1.0
s = s_prev * math.exp(g_prev - g) + e if m > 0 else 1.0
for k in range(K):
if m > 0:
matrix_a[n,k] = (matrix_a[n,k] * s_prev * math.exp(g_prev - g) + e * matrix_v[m,k] ) / s
else:
matrix_a[n,k] = e * matrix_v[0,k] / s
g_prev = g
s_prev = s
Top comments (0)