本文来自微信公众号: 歪睿老哥 ,作者:歪睿老哥,原文标题:《FlashAttention 论文精读:一个 IO 感知的注意力算法,如何改变了大模型的训练速度》
故事是这样的。
2022夏天,斯坦福的一帮人发了篇论文,标题叫《FlashAttention:Fast and Memory-Efficient Exact Attention with IO-Awareness》。
这名字听着挺无聊对吧,一堆缩写,一个技术名词,跟每天几百篇论文里随便哪篇长得都一样。
但我看了之后,觉得这件事挺有意思的。
因为它解决了一个几乎所有做Transformer的人都踩过坑的问题:大模型推理慢,到底慢在哪里。
很多人第一反应是"算不动",觉得算力不够,换更大的GPU就行。
但这篇文章的作者Tri Dao一帮人说了句不一样的话:
不是算不动,是数据搬不动。
他们管这个叫IO-Awareness——在写注意力算法的时候,不要只盯着FLOP数看,要把GPU显存层级之间的读写次数也算进去。
听起来像废话对吧,但就是这个"废话",让Transformer的训练速度提升了3倍,内存消耗降低了20倍,还让Transformer第一次在64K长度的序列上跑出了好结果。
64K。
之前所有方法,要么跑不了,要么跑出来跟随机猜差不多。
这篇文章,我想把这个论文的核心内容,用人话讲清楚。
1.背景:Transformer的瓶颈到底在哪里
问题:注意力机制的平方复杂度
先看最基本的注意力公式:
O=softmax(QK^T/sqrt(d))V
三个矩阵相乘,再加一个softmax。看起来很简单对吧。
但问题在于,Q和K的矩阵乘法会生成一个N×N的注意力矩阵,其中N是序列长度。
这意味着时间和空间复杂度都是O(N²)。
序列长度翻倍,计算量翻四倍。
这个平方复杂度是Transformer的天然缺陷,从2017年那篇"Attention Is All You Need"出来就没变过。
现有方案的局限
过去几年,一堆人想解决这个问题。
方案主要分两类:
第一类,近似注意力。用稀疏化、低秩分解、核函数近似等等手段,把注意力矩阵从N×N压缩到接近N×1。
这类方法理论上把计算复杂度降到了线性或近线性,但实际跑起来,墙钟时间(wall-clock time)并没有明显加速。
为什么?
因为很多方案只关注减少FLOP,忽略了内存访问的开销。
第二类,稀疏注意力。让每个token只关注有限的其他token,直接剪掉大量注意力连接。
这种方法确实减少了计算量,但稀疏模式本身也有内存访问的overhead,而且效果往往不如dense attention。
作者的判断
这篇文章的作者认为,现有方案没效果的根本原因不是算法不行,而是没有考虑GPU内存层级的IO特性。
现代GPU的内存层级是这样的:

DRAM(系统内存):容量最大(几十GB到几百GB),速度最慢(大约12.8 GB/s)
HBM(高带宽内存):GPU显存,容量中等(40-80 GB),速度中等(大约1.5-2.0 TB/s)
SRAM(片上缓存):容量最小(A100每个SM大约192 KB),速度最快(大约19 TB/s)
从HBM到SRAM,带宽差了一个数量级。
但现代GPU的计算速度已经超过内存速度了。操作越来越被内存访问(IO)而不是计算本身瓶颈住。
所以,关键问题不是FLOP多不多,而是有多少数据在HBM和SRAM之间来回搬。
这就是"IO-Awareness"的核心思想。
2.标准注意力实现的问题
标准算法:三步走
标准的注意力实现,通常是三步:
第一步:计算QK^T
把Q和K从HBM读到SRAM,在芯片上算QK^T,结果写回HBM。
这一步产生了一个N×N的注意力分数矩阵S。
第二步:Softmax
把S从HBM读出来,逐行做softmax,得到P矩阵。P再写回HBM。
第三步:PV
把P和V从HBM读出来,在芯片上算PV,结果写回HBM。
这三步看起来很自然对吧,但每步都在做一件事:把中间结果从HBM写出去,再从HBM读进来。
HBM访问次数的分析
让我们算一下HBM的访问总量。
前向传播:
第一步:读Q、K,写S→O(N²d+N²)次HBM访问
第二步:读S,写P→O(N²)次HBM访问
第三步:读P、V,写O→O(N²d+N²)次HBM访问
前向传播总共:O(N²d+N²)次HBM访问,是序列长度的平方级。
反向传播:
反向传播需要用到前向计算的S和P矩阵来计算梯度,所以同样的,要读S、P写回dQ、dK、dV。
反向传播也大约O(N²d+N²)次HBM访问。
核心矛盾
整个前向+反向传播,HBM访问总量大约是O(N²d)次。
但输入Q、K、V本身的总大小只有O(Nd),输出O也只有O(Nd)。
数据量是O(Nd)的东西,为什么要做O(N²d)的HBM访问?
多出来的O(N²)次访问,全是用在那个大得离谱的N×N注意力矩阵上。
这个矩阵太大,放不下SRAM,只能在HBM和SRAM之间反复搬。
这就是标准实现的根本问题。
3.FlashAttention算法:核心思路
两个关键技术
FlashAttention的思路很简单:用两个经典技术,避免把N×N注意力矩阵写到HBM上。
这两个技术是:
1.Tiling(分块计算)
2.Recomputation(重计算)
Tiling:分块做Softmax
标准softmax需要对整行做归一化,看起来必须把整行读进来才能算。
但math上有个技巧:softmax可以分块计算。
具体来说,如果我把一个向量x拆成两段x¹和x²,那么整个向量x的softmax结果,可以用x¹和x²各自的softmax统计量(最大值m和归一化因子ℓ)来逐步合并。

公式大概是:

这样,每次处理一个块,只需要记录两个小值(m和ℓ),就能把结果正确合并起来。
所以FlashAttention的做法是:
把K、V分成多个块
每次只把一个块加载到SRAM
对Q的每个块,和K的这个块算QK^T
在SRAM里算softmax,更新m和ℓ
逐步累积输出,最后写回HBM
关键:整个过程中,N×N注意力矩阵从来没有完整地出现在HBM上。
Recomputation:反向传播时不再读矩阵
那反向传播怎么办?
反向传播需要用到前向的S和P矩阵。标准做法是把前向的S和P存在HBM上,反向时直接读。
但FlashAttention说:不存了,反向时重新算。
它只存前向的输出O和softmax的统计量m、ℓ,这两个东西很小,O(Nd)的大小。
反向传播时,从HBM读Q、K、V,重新在SRAM里算S和P,然后再算梯度。
虽然多算了一些FLOP,但因为避免了从HBM读N×N矩阵的开销,实际运行时间反而更快。
这在学术上叫selective gradient checkpointing——梯度检查点的一种选择。
4.IO复杂度分析
理论保证
文章给出了一个严格分析。
标准注意力的HBM访问次数是Θ(N²d+N²)。
FlashAttention的HBM访问次数是Θ(N²d/M),其中M是SRAM的大小。
为什么是N²d/M?
因为SRAM能放下大小为Θ(M)的K、V块,每次能处理Θ(M/d)个K行。
对于N行的Q,需要N/(M/d)=Nd/M次扫描。每次扫描加载O(Nd)数据,所以总共O(N²d/M)次HBM访问。
实际差距
拿A100来算:
d=64(head维度)
M≈100KB(每个SM的SRAM)
标准注意力:O(N²×64)次HBM访问
FlashAttention:O(N²×64/100000)≈O(N²×0.00064)次HBM访问
HBM访问量减少了大约100倍。
虽然实际不可能完全达到理论极限(因为SRAM利用率、块大小选择等因素),但文章实验显示前向传播减少了约8倍,反向传播减少了约7倍,合计约9倍的HBM访问量降低。
下界证明
文章还证明了一个有意思的结论:
对于任何精确注意力算法,在所有可能的SRAM大小范围内,不可能渐近地优于O(N²d/M)的HBM访问下界。
换句话说,FlashAttention在这个意义上是最优的。
5.Block-Sparse FlashAttention
扩展:稀疏注意力
FlashAttention不只是精确注意力,还可以扩展到稀疏注意力。
思路很简单:如果注意力矩阵是块稀疏的(比如某些块全是零),那么在Tiling循环中直接跳过这些块就行。
算法跟FlashAttention几乎一样,只是加了一个if判断:如果当前块M_ij=0,跳过计算。
文章证明了Block-Sparse FlashAttention的HBM访问次数是Θ(N²d·s/M),其中s是非零块的比例。
s越小,加速越多。
实验显示,在LRA benchmark上,Block-Sparse FlashAttention相对于标准FlashAttention有2.8倍的加速,同时精度相当。

6.实验结果
训练速度
BERT-large:
在MLPerf 1.1上,FlashAttention比Nvidia记录快了15%(从20.0分钟降到17.4分钟)。
GPT-2 small:
比HuggingFace实现快3.5倍(从9.5天降到2.7天)
比Megatron-LM快2.0倍(从4.7天降到2.7天)
GPT-2 medium:
比HuggingFace实现快3.0倍(从21.0天降到6.9天)
比Megatron-LM快1.7倍(从11.5天降到6.9天)
Long-Range Arena:
平均加速2.4倍
模型质量提升
FlashAttention不只是更快,还能训练出更好的模型。
GPT-2长上下文:
用FlashAttention训练GPT-2 small,上下文长度从1K提升到4K,仍然比Megatron的1K版本快30%,且perplexity低了0.7。
长文档分类:
在MIMIC-III(医疗文本分类)和ECtHR(法律判决分类)上,增加序列长度带来显著提升:
MIMIC-III:16K序列比512序列提升4.3分
ECtHR:8K序列比512序列提升8.5分
PathFinder挑战:
Path-X(16K序列):FlashAttention的Transformer达到61.4%准确率,是第一个在这个任务上超过随机猜测的Transformer
Path-256(64K序列):Block-Sparse FlashAttention达到63.1%准确率
基准测试
在不同序列长度下:
序列长度128-512:FlashAttention比PyTorch标准实现快2-3倍
序列长度1024-2048:FlashAttention比所有近似注意力方法都快
内存占用:FlashAttention比PyTorch标准实现低20倍,比Linformer低2倍
不同硬件上的表现
A100:2-4倍加速
RTX 3090:2.5-4.5倍加速(HBM带宽更低,加速效果更明显)
T4:加速较少(SRAM更小,块大小需要更小)
7.总结
这篇文章的核心贡献,可以用一句话概括:
写注意力算法的时候,要把GPU内存层级的读写开销也算进去。
这个"IO-Awareness"的思想听起来简单,但在深度学习这个领域里,很少有人真正认真对待过。
大家习惯了看FLOP数,看理论复杂度,看benchmark上的accuracy。
但FLOP不等于wall-clock time,不等于内存使用量,不等于实际训练出来的模型质量。
FlashAttention用两个经典技术——Tiling和Recomputation——把注意力机制的HBM访问量从O(N²)降到了O(N²/M),在保持精确计算的同时实现了3倍的加速和20倍的内存节省。
而且它不只是更快,还让Transformer第一次真正具备了建模64K长度上下文的能力。
这就是IO-Awareness的力量。
