Flash Attention:用 IO 感知重写注意力的底层算法
上一篇我们聊了 KV Cache,解决的是「别重复算」的问题。今天聊一个更底层的优化:就算你必须算,怎么让 Attention 本身快 3 倍?
答案是 Flash Attention —— Tri Dao 在 2022 年提出的 IO-aware exact attention 算法。注意关键词:exact,它不是近似,不丢精度,纯粹通过重新编排内存访问模式来加速。
问题:Attention 的瓶颈不是算力,是带宽#
标准 Self-Attention 的计算流程:
Q, K, V ∈ R^{N×d}
S = QK^T # N×N 矩阵
P = softmax(S) # N×N 矩阵
O = PV # N×d 矩阵看起来瓶颈是 O(N²) 的计算量?不完全对。
现代 GPU 的算力增长远快于内存带宽。以 A100 为例:
| 指标 | 数值 |
|---|---|
| FP16 算力 | 312 TFLOPS |
| HBM 带宽 | 2.0 TB/s |
| SRAM 容量 | 20 MB |
| HBM 容量 | 80 GB |
算术强度(Arithmetic Intensity) = FLOPS / Bytes。当一个 kernel 的算术强度低于 GPU 的 ops:byte 比时,它就是 memory-bound 的。
标准 Attention 需要把 N×N 的中间矩阵 S 和 P 写入 HBM 再读回来 —— 这两次 round-trip 就是瓶颈。当 N=4096, d=128 时:
- S 矩阵大小:4096² × 2 bytes = 32 MB
- 写入 HBM + 读回 = 64 MB 的带宽开销
- 实际计算只需要几毫秒,但搬数据要十几毫秒
Flash Attention 的核心洞察:如果我们能把 Attention 的所有计算都放在 SRAM 里完成,永远不把 N×N 矩阵写入 HBM,速度就能起飞。
核心算法:Tiling + Online Softmax#
问题在于 SRAM 只有 20MB,装不下完整的 N×N 矩阵。Flash Attention 的解决方案是 分块计算(Tiling):
分块策略#
把 Q 分成 T_r 块,K/V 分成 T_c 块,每块大小为 B_r × d 和 B_c × d,确保每个 block 的中间结果能放进 SRAM。
# 伪代码:Flash Attention 核心循环
for i in range(T_r): # 遍历 Q 的每个块
q_block = Q[i * B_r : (i+1) * B_r] # 从 HBM 加载 Q 块
# 初始化累加器
m_i = -inf # running max
l_i = 0 # running sum of exp
o_i = 0 # running output
for j in range(T_c): # 遍历 K/V 的每个块
k_block = K[j * B_c : (j+1) * B_c] # 从 HBM 加载 K 块
为什么 Softmax 是难点?#
Softmax 需要全局信息:softmax(x_i) = exp(x_i) / Σ exp(x_j),你得知道所有 x 才能算分母。
Online Softmax(Milakov & Gimelshein, 2018)解决了这个问题:维护一个 running max m 和 running sum l,每处理一个新块就用指数缩放因子修正之前的累积值。
数学上等价于:
当新块的 max 为 m_new 时:
旧的 exp 值全部乘以 exp(m_old - m_new) 来修正这个技巧让我们可以 一次遍历 就完成 softmax,不需要存储完整的 N×N 矩阵。
IO 复杂度分析#
这是 Flash Attention 论文最硬核的部分。
标准 Attention 的 HBM 访问量:
Θ(Nd + N²) # 读写 Q,K,V,O 是 O(Nd),读写 S,P 是 O(N²)Flash Attention 的 HBM 访问量:
Θ(N²d² / M) # M 是 SRAM 大小当 M 足够大时(通常成立),Flash Attention 的 IO 量远小于标准实现。论文还证明了这个界是 渐近最优 的 —— 你不可能做得更好了。
具体数字(N=4096, d=128, M=20MB):
- 标准 Attention:~96 MB HBM 访问
- Flash Attention:~12 MB HBM 访问
- 减少 8 倍 IO
Flash Attention 2:把 GPU 榨干#
Flash Attention 1 已经很快了,但 GPU 利用率只有约 50-70%。Flash Attention 2(2023)做了几个关键改进:
1. 减少非矩阵乘法运算#
现代 GPU 有专门的 Tensor Core 做矩阵乘法(GEMM),吞吐量比普通 FP32 算术高 16 倍。FA2 重新编排计算,把尽量多的工作交给 Tensor Core。
2. 优化并行策略#
FA1 在 batch 和 head 维度并行。当 batch × heads 不够大时,SM(Streaming Multiprocessor)利用率低。
FA2 额外在序列长度维度并行 —— 把不同的 Q 块分配到不同的 thread block。
3. 优化 warp 间通信#
FA1 中,一个 thread block 内的多个 warp 需要通过 shared memory 同步中间结果。FA2 让每个 warp 独立计算部分结果,最后才合并,减少了 warp 间的同步开销。
结果:FA2 在 A100 上达到 230 TFLOPS,约 73% 的理论峰值利用率。
Flash Attention 3:Hopper 架构的深度适配#
Flash Attention 3(2024)针对 NVIDIA H100(Hopper 架构)做了进一步优化:
核心技术:
1. Warp-specialization:生产者 warp 负责数据搬运,消费者 warp 负责计算
2. Pingpong scheduling:交替使用寄存器,隐藏 GEMM 延迟
3. FP8 支持:利用 H100 的 FP8 Tensor Core,配合 incoherent processing 保持精度在 H100 上,FA3 达到了约 740 TFLOPS(FP16),接近 75% 的峰值利用率。
工程实践:怎么用?#
PyTorch 原生支持#
从 PyTorch 2.0 开始,torch.nn.functional.scaled_dot_product_attention 会自动调用 Flash Attention:
import torch
import torch.nn.functional as F
# 自动选择最优 backend(Flash Attention / Memory-Efficient / Math)
q = torch.randn(2, 8, 4096, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8
直接使用 flash-attn 库#
from flash_attn import flash_attn_func
# q, k, v: (batch, seqlen, nheads, headdim)
output = flash_attn_func(q, k, v, causal=True)
# 带 window attention(sliding window)
output = flash_attn_func(q, k, v, causal=True, window_size=(512, 512))性能对比实测#
import time
import torch
def benchmark_attention(N, d, num_heads=8, batch=4):
q = torch.randn(batch, num_heads, N, d, device="cuda", dtype=torch.float16)
k = torch.randn(batch, num_heads, N, d, device="cuda", dtype=torch.float16)
v
和 KV Cache 的关系#
如果你看了上一篇 KV Cache 文章,可能会问:Flash Attention 和 KV Cache 冲突吗?
完全不冲突,它们解决不同阶段的问题:
- KV Cache:推理时避免重复计算历史 token 的 K、V
- Flash Attention:不管是训练还是推理,让 Attention 计算本身更快
实际上,vLLM、TensorRT-LLM 等推理框架同时使用了两者:KV Cache 减少计算量,Flash Attention(特别是 PagedAttention 变体)让剩余的计算尽可能高效。
延伸:Ring Attention 与分布式长序列#
Flash Attention 的 tiling 思想还催生了 Ring Attention(2023),把分块计算扩展到多 GPU 上:
- 每个 GPU 持有一段 Q,K/V 在 GPU 之间以 ring 拓扑传递
- 计算和通信完全重叠
- 理论上序列长度可以随 GPU 数量线性扩展
这是 Gemini 1.5(1M context)等超长上下文模型背后的关键技术之一。
总结#
| 版本 | 年份 | 核心创新 | 目标硬件 |
|---|---|---|---|
| FA1 | 2022 | IO-aware tiling + online softmax | A100 |
| FA2 | 2023 | 优化并行与 warp 通信 | A100 |
| FA3 | 2024 | Warp specialization + FP8 | H100 |
Flash Attention 的成功说明了一个深刻的道理:在现代硬件上,算法的 IO 复杂度比计算复杂度更重要。 理解你的硬件内存层级,然后围绕它设计算法 —— 这才是真正的系统级思维。
参考文献:
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness", 2022
- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", 2023
- Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision", 2024
- Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context", 2023