Skip to content

参考FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
03-16更新:对思路、原理进行详细的展开描述。

面临问题

Transformer 框架由于核心组件self-attention对耗时及内存占用上都是序列长度N2复杂度,很难将其应用到较长的上下文中,FlashAttn使得Transformer能够建模长序列,这带来以下几个好处:

  • 扩展功能: 使得NLP不仅能够处理段落,同时可以理解书籍、说明书等。
  • 逼近现实: 例如CV上更高的分辨率意味着更好、更强的洞察力
  • 开拓新领域: audio.video,medical imaging data

GPU

TIP

补充GPU工作原理

计算过程中,首先将HBM中的数据加载和写入到SRAM中,在SRAM中完成计算将数据传回并写入HBM。这里SRAM理解为L1 cacheshared memory即可。

贡献

  • 节省显存:减少了额外数据的存储消耗。
  • 精准注意力:在使用稀疏计算时,能够保证结果的准确性。(未更新)
  • 设计计算块:Tilling, extra statistics, combine the results。

思路

  • 尽量使用SRAM,单次传输占满 分块计算
  • 减少内存搬运次数 融合计算

Forward

Attn 横向对比

对比标准Attn,Flash Attn在前向传播中使用l,m代替了中间值P,降低了额外内存的占用,同时,使用融合运算的技巧,减少了数据搬运的需求。

标准Attn


可以看到,在整个过程中,对SRAM

  • 读入: Q, K, S, P, V
  • 写出: S, P, O
  • 数据搬运量:4N2+4Nd, O(N2+Nd)
  • 额外内存消耗:O(N2)
  • 运算复杂度:O(N2d)

Flash Attn

  • 读入:
    • 外循环: K, V
    • 内循环:每次内循环读入一个完整的Q,为TrNd
  • 写出:O, m, l
  • 数据搬运量:O(N2d2M1) (M>>d2)
  • 额外内存消耗:O(N)
  • 运算复杂度:O(N2d)

主要思路

主要原则:充分利用SRAM高速计算能力,保证每次数据传输能够填满SRAM。根据SRAM-size(假定为M),设计Q分块大小Bc=dM4d,设计K,V的分块大小Br=M4d。对QRN×d,将其分大小为Br×d的若干块,对K,V同样进行分块。
事实上,观察Algorithm1可以发现,SRAM中常驻的变量为Kj,Vj,Qi,Oi,大小分别为Bc×d+Bc×d+Br×d+Br×Bc=M2+Br×(Bc+d),注意到,若d=M4d, 此时他们的内存加和恰好为SRAM的内存大小M。这就保证了我们每次循环能够充分利用SRAM的内存空间,实现高速计算。

分块计算面临问题

  • 例如,假设Q=[Q1Q2],K=[K1K2], 则根据S=QKT得到[Q1Q2][K1TK2T]=[S11S12S21S22]
    由于softmax需对整行数据执行操作,此时,分块后的每次循环中,不完全的input(真正的并行发生在串接concatenated上)对softmax操作带来了挑战。
  • O=PV,显然OP有依赖,标准情况下需要待P计算完成后返回到HBM后重载求解O,这导致了额外的显存消耗和数据传输。

问题解决

Safe-softmax

当数据值很大时,对于FP-16数据类型,exp可能会超出数值有效范围,故采用safe-softmax,对于xRB

  1. m(x)=max(xi)
  2. f(x)=[ex1m(x),...,exBm(x)],l(x)=if(x)i
  3. softmax(x)=f(x)l(x)

考虑对1,2进行融合,使得2步l(x)不再对m(x)产生依赖,数学上,需要获得关于l(x)的递推式。考虑假如x1,x2RB

m(x)=m([x1x2])=max(m(x1),m(x2))f(x)=[em(x1)m(x)f(x1),em(x2)m(x)f(x2)]l(x)=em(x1)m(x)l(x1)+em(x2)m(x)l(x2)

注意:这部分对应算法的第10到11行。

O = PV

由于最终需要的结果为O,而计算O依赖于整个P,那么能否像处理softmax一样改为递推式,使得每一个循环产生的Pi无需写出重载,在SRAM中完成对Oi的计算,并更新O
考虑在外循环为j时,此时SRAM中得到的数据有mj=rowmax(S:,:j)RNlj=rowsum(exp(S:,:jmj)), Oj=P:,:jV:,:jRN×d。其中,S:,:j代表在列上截断。
j+1循环,更新mj+1=max(mj,m~)=rowmax(S:,:j+1), lj+1=emjmj+1lj+em~mj+1l~=rowsom(exp(S:,:j+1mj+1))

Oj+1=P:,:j+1V:,:j+1=softmax(S:,:j+1)V:,:j+1=diag1lj+1[exp([S:,:j,S:,j+1mj+1])][V:,:jV:,j+1]=diag1lj+1[emj+1eS:,:jV:,:j+eS:,j+1mj+1V:,j+1]=diag1(lj+1)[diag(lj)emjmj+1Oj+eS:,j+1mj+1V:,j+1]=diag1(lj+1)[diag(lj)emjmj+1Oj+em~mj+1eS:,j+1m~V:,j+1]=diag1(lj+1)[diag(lj)emjmj+1Oj+em~mj+1P~:,j+1V:,j+1]

注意:

  • (5)到(6):凑Oj=diag1(lj)exp[S:,:jmj]V:,:j
  • (6)到(7):上式包含S:,j+1,计算过程绕不开P,故使用P替代,使得S可被释放。
  • 这部分对应算法第12行

Backward

如果不清楚基本的标量对向量,softmax求导,请参考这两篇文章CSDN, blog

横向对比

可以看到,在Backward过程,FlashAttn减少了数据搬运,增加了计算量(重计算),由于此时主要为Mem-bound,故有利于性能提升。

标准 Attn

  • HBM:Q, K, V, O, S, P
  • 读入:P,dO,V,P,dP,dS,K,dS,Q
  • 写出:dV,dP,dS,dQ,dK

Flash Attn

  • HBM:m,l,Q,K,V,O,dO
  • 重计算:对应算法11到15行。
    这里采用重计算的方式,即不直接搬运P,而是在反向传递过程中,经由S=QKT得到S后结合l,m得到P

Backward 过程分块梯度传递

  • V:对应算法16行
    当外循环为0时,V0P00,P10,P20相乘得到O0,O1,O2。则dV0=(P00T)dO0+(P10T)dO1+(P20T)dO2,进而
dVj=i(PijdOj)
  • P:对应算法17到18行
    对于Pij,在外循环为j时仅与Vj有关,内循环为iOi有关
dPij=dOiVjT
  • S:对应算法19到20行
    • 第一个等式为对softmax求导
    • 修改为点乘是为了扩展到块(多行) 设si,pi,oiS,P,O的某一行,注意,不表示分块。
dsi=dpi(diag(pi)piTpi)=dpidiag(pi)dpipiTpi=dpidiag(pi)doiVTpiTpi=dpidiag(pi)doioiTpi=pidpipirowsum(doioi)=pi[dpirowsum(doioi)]

则最终

dSij=Pij[dPijrowsum(doioi)]
  • Q:对应算法21行 S=QKT, 对于Qi,他与Sij有关,与Kj有关。则
dQi=jdSijKj
  • K:对应算法22行 对于外循环jKjSij有关,与Oi有关,则
dKj=idSijTQi

实验成果

Speed up

  • Fig2: 对比传统的Attn,尽管FlashAttn在增加计算(如后向传播中的重新计算),但HBM的读写仅为传统方法的19,速度上提升了6倍。
  • E.5: 在不同的GPU下,不同的组件(是否含有Mask,Dropout),不同序列长度的所有情况下,FlashAttention较基准情况加速24倍。
  • 4.1 BERT: 达到一定精度所需要的训练时间更短。比创下Nvida记录的MLPerf 1.1 加速了15%。
  • GPT-2: 在GPT-2 small 和 GPT-2 midium 数据集上与Huggingface 和 Megatron-LM 对比,保持同等精度且速度较Huggingface为2.03.5×

Longer Sequences

  • 4.2 LM with Long Context: 通常增长上下文的长度后训练速度会变慢但可以得到一个更好的模型(Table 5 展示了在更长的上下文训练的模型具备更高的分类精度)。列表展示了FlashAttn4k文本长度的情况下具备比1k文本长度下Megatron-LM更快的训练速度,更长的序列代表模型更高的质量。
  • 第一个解决Path-XTransformer