目录

Flash-Attention学习笔记

Flash Attention学习笔记

https://i-blog.csdnimg.cn/direct/bd913ef0fc0146d29f7016c496406c99.png

fast可以增加模型训练速度,memory efficient 显存高效的,exact:和标准attention得到的结果完全一致,并不降低attention精度。IO-Awareness:通过对IO感知的方式来进行训练的整个算法是以改进IO效率达到的。

首先传统的transformer计算过程:

https://i-blog.csdnimg.cn/direct/68dbccd9db1a4fb189d1ae2ede3dd8f7.png

pytorch写的代码在实际显卡上的attention是如何计算的呢。

https://i-blog.csdnimg.cn/direct/6d7e16b3f89a4542b47ca616692ca326.png

  • SRAM:特点是极快,但容量小、成本高、占地方。它主要用在处理器芯片内部,作为缓存。
  • HBM:特点是带宽极高,但成本非常高。它主要用在高端GPU和AI加速卡上,作为主内存,为大规模并行计算提供海量数据吞吐。
  • 由于权重矩阵和输入数据量非常大,它们通常存储在HBM中
  • 一旦数据被加载到SRAM上,计算核心就会高速访问SRAM,执行密集的矩阵乘法运算
    1. 从HBM加载Q,K到SRAM
    2. 计算出S= QK^T
    3. 将S写到HBM
    4. 将S加载到SRAM
    5. 计算P=softmax(S)
    6. 将P写出到HBM
    7. 从HBM加载P和V到SRAM
    8. 计算O = PV
    9. 把0 写出到HBM
  • 矩阵Q,K,V维度都是N*d,N是序列长度,d是特征维度。
    可以看出中间有很多临时变量的读写,比如S和P矩阵,他们大小都是随着序列长度的平方增长的。中间临时矩阵占用的显存非常大。比如保留中间结果比如SP会占用显存但是还是需要的,因为反向传播需要他们来计算梯度。在模型训练时制约训练速度有两种情况。

https://i-blog.csdnimg.cn/direct/64cd9026c39042338a65c5cba8e838ff.png

在模型训练时制约训练速度有两种情况。
一种情况是compute-Bound,训练速度的瓶颈在于运算,比如对于大的矩阵乘法还有多channel的卷积操作。这些操作都是需要的数据量不大但是计算很复杂。
第二种情况是memory-bound,训练速度的瓶颈在于对HBM数据的读取速度。
从HBM读取数据的速度跟不上运算的速度,算力在等待数据。主要操作有两类:一位是按位的操作比如relu和dropout,还有一类是规约操作比如sum, softmax这些操作都是需要数据很多但是计算相对简单。

attention计算操作主要是memory bound的计算。

上面右侧的图可以看到compute bound的操作比如矩阵乘法占用的时间很短,但是memory bound占据了很长时间。对于memory bound的优化主要通过融合多个操作。

https://i-blog.csdnimg.cn/direct/3afc80bfe272408ea764baabfc98d16e.png

对于Memory-Bound的优化一般是进行fusion融合操作。不对中间结果缓存,减少HBM访问,节约了原来多个操作之间要存取HBM的时间,让多个操作只要存取一次HBM。我们不保存中间结果在反向传播中重新计算。

https://i-blog.csdnimg.cn/direct/8a03354404e44f159f27a01004e77fa3.png

显存中的存储是分级的,有芯片内的缓存SRAM(缓存容量小但是访问快),还有芯片外的HBM缓存(容量大但是访问慢),所以对于优化来说应该尽可能让计算访问芯片内的缓存,尽可能减少访问芯片外HBM的显存。

flash attention 着眼于减少IO量。以及通过访问芯片内缓存而加快IO的速度。

当Q和K矩阵很大时,不分块的传统方法会把大部分时间浪费在等待数据从HBM搬运到SRAM上,GPU强大的计算单元大部分时间在“饿着肚子”等数据。

https://i-blog.csdnimg.cn/direct/407f55fe8a274a6bb906db09c1a54b46.png

为了实现避免attention matrix从HBM读写通过以下两点实现的:

  1. 通过分块计算,融合多个操作,减少中间结果缓存(到HBM)
  2. 反向传播时, 重新计算中间结果。

实现了2-4倍速度提升,10-20倍显存占用的节省(从原来的随序列长度平方增长减小到随序列长度线性增长)。

下面我们来看如何通过矩阵分块和融合多个计算来减少对HBM的访问。
暂时先跳过softmax的操作比较特殊后面单独讨论。

https://i-blog.csdnimg.cn/direct/cbabd18ccbc24c0d96348abaa78adfb2.png

  1. 从HBM中读取Q的前两行、K转置的前三列、V的前三行,然后传入到SRAM上对他们进行计算。

https://i-blog.csdnimg.cn/direct/16b0160b676840608c3ef696044881f0.png

在SRAM中Q和K的转置得到S并不存入HBM,直接和V的分块进行计算。得到了O的前两行:

https://i-blog.csdnimg.cn/direct/dc97ccf3da4b4bd48dc6fcc8d76b6ef4.png

因为O是对所有V的一个加权平均,目前得到的结果就是对V的前三行进行加权平均。O用浅色表示因为还只是一个中间结果,后面还需要更新。

  1. 接下来K和V的分块还保留在SRAM里,从HBM里读取Q的中间两行,经过同样的计算得到O的第三行和第四行的中间结果。

https://i-blog.csdnimg.cn/direct/815b912a265b41b8830d51020993466b.png

然后仍然保留K和V的分块在SRAM里

3, 从HBM中读取Q的最后两行经过同样的计算得到O的最后两行的结果

https://i-blog.csdnimg.cn/direct/3ef936f92cde4e17b8bf848fd138c136.png

  1. 接下来读取K转置的后三列,V的后三行, Q的前两行,得到结果S之后再凑个HBM里读取O之前的保存结果,也就是对V的前三行的加权平均值进行加和。

https://i-blog.csdnimg.cn/direct/28a26ebffdb54c1fbf215bd391caf285.png

得到了O前两行的最终结果。

https://i-blog.csdnimg.cn/direct/0e342912ad854c99a6ec074c14ab87df.png

同样保持K和V分块不变,从HBM里读取下一个分块的Q进行计算,从HBM里读取之前的计算中间结果O和更新后存入HBM。最后继续保持SRAM里的K和V分块不变。最后从HBM里读取Q的最后两行进行计算,继续保持SRAM里面的K和V分块不变。加和更新最终存入HBM。

https://i-blog.csdnimg.cn/direct/46890a89404a473c9eb69108fc0fbb90.png

https://i-blog.csdnimg.cn/direct/1cfa44e5f40f48e1a686897236674768.png

以上完成了attention计算。

通过将矩阵分块以及将多步计算进行融合,中途没有将中间计算结果S存入HBM。大大减少了IO的时间。

接下来看softmax:

https://i-blog.csdnimg.cn/direct/481ec7105d7f4582862ae37eb1b1bddf.png

softmax是按行进行的,只有一行所有的数据都计算完成后才能进行这里的求和计算。所以我们想要让我们之前的矩阵分块对attention多步进行融合计算得以进行的前提必须解决softmax分块计算问题。

softmax的分块计算:

现在我们训练都是混合精度,在FP16下进行,如果X= 12,则E的X次方就大于FP16所能表示的最大的数了。

https://i-blog.csdnimg.cn/direct/bbf3afa3d5c544aab2bbb55745c4c96b.png

为了解决这个数值溢出的问题,人们提出了一种叫做safe softmax的算法。
首先找出从X1到Xn里最大的值m,然后将softmax的分子和分母同时除以E的m次方,softmax结果不变,得到的式子中可以看出e的指数部分就都小于等于0了,这时候用fp16表示就不会有数值溢出的问题了。

在看一下safe softmax的过程:

有一组X通过max(x)求出X里的最大值,通过p(x)将x变化成e的(xi-m(x)),如下图

https://i-blog.csdnimg.cn/direct/c0b8476e29314d56b850f83921278c6e.png

https://i-blog.csdnimg.cn/direct/685657f47959436b9653a848ed9e5ea2.png

对于原始2N个X的正确的softmax的值如上图右侧计算过程。

其中p(x)拼接起来的公式是https://i-blog.csdnimg.cn/direct/8dbe31467e8a4ac6a3832695775ad60f.png

需要分别给https://latex.csdn.net/eq?p%28x%5E%7B1%7D%29https://latex.csdn.net/eq?p%28x%5E%7B2%7D%29一个系数。因为m(x)是https://latex.csdn.net/eq?m%28x%5E%7B1%7D%29https://latex.csdn.net/eq?m%28x%5E%7B2%7D%29的最大值,所以m(x)肯定等于其中一个。假设m(x)=https://latex.csdn.net/eq?m%28x%5E%7B2%7D%29那么后面e的指数项就等于0,那么https://latex.csdn.net/eq?p%28x%29%20%3D%20%5Be%5E%7Bm%28x%5E%7B1%7D%29-m%28x%29%7Dp%28x%5E%7B1%7D%29%2C%20p%28x%5E%7B2%7D%29%5D

因为https://latex.csdn.net/eq?p%28x%5E%7B2%7D%29分块计算的时候减去就是全局最大值,所以此时不需要再进行调整。https://latex.csdn.net/eq?p%28x%5E%7B1%7D%29在分块计算的时候减去的是局部最大值不是全局最大值,那么它和全局最大值比少了多少呢?就是这里的https://latex.csdn.net/eq?m%28x%5E%7B1%7D%29-m%28x%29给他补回来。

所以softmax也可以通过分块来计算了,只是我们需要额外补充几个变量。

https://i-blog.csdnimg.cn/direct/6258792ce5fc44ef9a87f7f490cc8eb4.png

https://i-blog.csdnimg.cn/direct/dd09ef18895f4ddf80831088eff5ce95.png

可以看到flash attention计算量增加了 但是对HBM访问量大幅度减小,训练时间也是大幅度减小。

flash attention2大致思想相似增加了一些工程优化,

https://i-blog.csdnimg.cn/direct/11d4aa255a654d2cb912036734db95cf.png

如果只看单次传输,把一大块数据分成很多小块传输,总的数据量不变,总的理论带宽时间应该是一样的。但是,Flash Attention 分块计算能极大减少 IO 时间的根本原因在于:它通过巧妙的分块,避免了将整个庞大的中间结果矩阵反复从慢速内存(如 HBM)读取和写入