本文转自:Coggle数据科学
在 Transformer 架构中,注意力机制的计算复杂度与序列长度(即文本长度)呈平方关系()。这意味着,当模型需要处理更长的文本时(比如从几千个词到几万个词),计算时间和所需的内存会急剧增加。最开始的标准注意力机制存在两个主要问题:
FlashAttention 使得处理长达数万甚至数十万个 token 的超长文本成为可能。这解锁了新的应用场景,例如分析法律文档、总结长篇小说或处理整个代码库。
FlashAttention 使得模型的训练和推理速度更快,尤其是在长序列场景下。例如,FlashAttention-2 在长序列上比标准实现快 10 倍,使得训练成本更低,用户体验更好。
最新的 FlashAttention-3 利用了新硬件(如 NVIDIA H100)的 FP8 精度,进一步提升了性能,同时通过特殊的算法保持了计算的准确性,让模型训练更加高效。
许多研究提出了近似注意力方法,试图通过减少计算量(FLOPs)来提高效率。然而,这些方法通常忽略了GPU不同层级内存(如高速的片上SRAM和相对较慢的高带宽HBM)之间的I/O开销,导致它们在实际运行时并没有带来显著的加速。
FlashAttention的核心思想是I/O感知,即在设计算法时,将数据在不同层级内存之间的读写开销考虑在内。论文指出,在现代GPU上,计算速度已经远超内存访问速度,因此大多数操作都受限于内存访问。FlashAttention通过以下两个关键技术来解决这一问题:
由于GPU的计算速度增长快于内存速度,许多操作的性能瓶颈在于内存访问,而不是计算本身。因此,如何高效利用快速的SRAM变得至关重要。
根据算术强度(每字节内存访问的算术运算次数),操作可分为两类:
给定查询 Q、键 K 和值 V 矩阵,注意力的计算分三步:
标准实现(如“Algorithm 0”所示)将每一步都作为一个独立的GPU核函数,并物化(materialize)中间矩阵 S 和 P 到HBM中。
这种实现方式导致了两个主要问题:
FlashAttention旨在减少对GPU高带宽内存(HBM)的读写,实现对确切注意力(exact attention)的快速、内存高效的计算。为此,它采用了两种关键技术:
通过Tiling和Recomputation,FlashAttention能够将所有计算步骤(矩阵乘法、Softmax、可选的遮蔽和Dropout)融合成一个单一的CUDA核函数。这避免了在每个步骤之间反复地将数据写入HBM。
lashAttention在BERT-large模型上的训练速度超过了MLPerf 1.1的记录保持者。与Nvidia的实现相比,FlashAttention的训练时间缩短了15%,这证明了其在标准长序列任务上的卓越性能。
FlashAttention在训练GPT-2模型时,相比于流行的HuggingFace和Megatron-LM实现,实现了显著的端到端加速。
在Long-Range Arena基准测试中,FlashAttention相比于标准的Transformer实现,实现了2.4倍的加速。此外,块稀疏FlashAttention的表现甚至优于所有已测试的近似注意力方法,证明了其在处理超长序列时的优越性。
lashAttention的内存占用与序列长度呈线性关系,而标准实现是平方关系。这使得FlashAttention的内存效率比标准方法高出20倍。
第一代FlashAttention通过利用 GPU 内存层次结构的特性,显著降低了内存占用(从二次方降为线性)并实现了 2-4 倍的加速,且没有引入任何近似。
然而,FlashAttention 的效率仍然不如优化的矩阵乘法(GEMM)操作,其浮点运算性能(FLOPs/s)仅能达到理论峰值的 25-40%。这主要是因为 FlashAttention 存在不优化的工作划分(work partitioning),导致 GPU 线程块(thread blocks)和线程束(warps)之间的并行度不足、占用率低或产生不必要的共享内存读写。
为了解决这些问题,论文提出了FlashAttention-2,通过以下改进实现了更好的工作划分:
FlashAttention-2对在线 Softmax 技巧进行了两处微调:
第一代 FlashAttention 仅在批处理大小和注意力头数量上进行并行化。当序列长度很长时,批处理大小通常很小,导致 GPU 资源的利用率(occupancy)不高。FlashAttention-2 通过在序列长度维度上增加并行化来解决这个问题。
除了线程块级别的并行,FlashAttention-2 还优化了线程块内部线程束之间的工作分配,以减少共享内存的读写。
FlashAttention-2 比第一代 FlashAttention 快1.7-3.0 倍,比 Triton 实现的 FlashAttention 快1.3-2.5 倍。
在 A100 GPU 上,FlashAttention-2 在前向传播中达到了230 TFLOPs/s的峰值,相当于理论最大吞吐量的73%。在后向传播中,它达到了理论最大吞吐量的 63%。
虽然之前的 FlashAttention 通过减少内存读写来加速计算,但它未能充分利用现代硬件(如 Hopper GPU)的新特性。例如,FlashAttention-2 在 H100 GPU 上的利用率仅为 35%。
与 FlashAttention-2 类似,FlashAttention-3 也将任务并行化到不同的线程块(CTA),但其创新之处在于在单个线程块内部,将线程束(warps)划分为不同的角色。
生产者和消费者通过一个循环缓冲区(circular buffer)进行同步。生产者将数据放入缓冲区,消费者从中取出。当缓冲区中的一个“阶段”被消费后,生产者就可以继续向其中加载新数据。
在标准 FlashAttention 中,GEMM 和 Softmax 存在顺序依赖:Softmax 必须在第一个 GEMM 计算完成后才能开始,而第二个 GEMM 必须等待 Softmax 的结果。
FlashAttention-3 通过在寄存器中使用额外的缓冲区,打破了这种依赖关系。在每次循环中,它异步启动下一个 GEMM 的计算,而同时执行当前 GEMM 结果的 Softmax 和更新操作。这样,GEMM 和 Softmax 的执行就可以重叠,提高了效率。
FP8 的 WGMMA(Warp Group Matrix-Multiply-Accumulate)指令要求输入矩阵具有特定的k-major 布局,而输入张量通常是mn-major 布局。
FlashAttention-3 选择在 GPU 内核中(in-kernel)进行转置。它利用 LDSM/STSM 指令,这些指令能够高效地在 SMEM 和 RMEM(寄存器)之间进行数据传输,并在传输过程中完成布局转置,避免了代价高昂的 HBM 读写。
同于传统的逐张量(per-tensor)量化,FlashAttention-3 对每个块进行单独量化。这使得每个块可以有自己的缩放因子,从而更有效地处理离群值,减少量化误差。
FlashAttention-3 的前向传播速度比 FlashAttention-2 快1.5-2.0 倍,后向传播快1.5-1.75 倍。FP16 版本的 FlashAttention-3 达到了740 TFLOPs/s的峰值,相当于 H100 GPU 理论最大吞吐量的 **75%**。
在处理中长序列(1k 及以上)时,FlashAttention-3 的性能甚至超过了 NVIDIA 自家闭源、针对 H100 优化的cuDNN库。
免责声明:本文为转载,非本网原创内容,不代表本网观点。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,请读者仅作参考,并请自行核实相关内容。
如有疑问请发送邮件至:bangqikeconnect@gmail.com