Flash Attention:让大模型推理快 2-4 倍的"魔法"
Flash Attention = 不改变任何计算结果,只改变计算方式,就让注意力机制快了 2-4 倍。
如果你用过 ChatGPT 或 Claude,你经历过等待光标一个字一个字蹦出来的时刻。生成慢的瓶颈往往不在 GPU 算力,而在内存墙。Flash Attention 就是拆掉这堵墙的”魔法”。
问题:注意力机制的内存墙
Section titled “问题:注意力机制的内存墙”注意力机制的计算很简单:Q × K^T → softmax → × V。但中间产生的 N×N 注意力矩阵是巨大的。
对于 128K 上下文的模型:128,000 × 128,000 × 2 字节 = 32 GB。仅仅一个中间矩阵。
更糟的是,标准实现需要把这个矩阵在 GPU 高速缓存(SRAM)和显存(HBM)之间反复搬运。现代 GPU 算力远超内存带宽,导致大量时间花在等数据搬运上。就像一个数学天才,每道题都要等别人把题目念完才能开始算。
核心思想:分块 + Online Softmax
Section titled “核心思想:分块 + Online Softmax”Flash Attention 的关键洞察:你不需要一次性看到完整的 N×N 矩阵才能算出正确结果。
把 Q、K、V 切成小块,每块大小刚好放进 SRAM。在 SRAM 中逐块计算,累积结果,最终只写回 N×d 的输出矩阵(远小于 N×N)。
Online Softmax
Section titled “Online Softmax”Softmax 需要知道所有分数的最大值,分块怎么算?答案——维护两个运行统计量:
m:当前已见分数的全局最大值ℓ:当前已见分数的指数和
每处理新块时更新这两个值,数学上可以证明逐步累积的结果与一次性计算完全等价。
# Online Softmax 更新(伪代码)new_max = max(old_max, block_max)new_sum = old_sum * exp(old_max - new_max) + sum(exp(block - new_max))| 内存类型 | 容量 | 带宽 | 延迟 |
|---|---|---|---|
| SRAM(片上) | ~20 MB | ~19 TB/s | 极低 |
| HBM(显存) | ~40-80 GB | ~1.5-2 TB/s | 较高 |
Flash Attention 显式地利用这个层级:在 SRAM 里做尽可能多的计算,减少 HBM 读写。就像厨师把所有食材搬到料理台上再开始做菜。
标准 Attention
Section titled “标准 Attention”def standard_attention(Q, K, V): scores = Q @ K.T # (N, N) → 写入 HBM attn = softmax(scores) # (N, N) → 从 HBM 读取 output = attn @ V # (N, d) return outputFlash Attention 思想
Section titled “Flash Attention 思想”def flash_attention(Q, K, V, block_size=128): output = zeros(N, d) row_max = full((N,), -inf) row_sum = zeros(N)
for j in range(0, N, block_size): # 分块遍历 K/V for i in range(0, N, block_size): # 分块遍历 Q scores = Q[i:i+B] @ K[j:j+B].T # SRAM 中计算 # Online Softmax 更新 row_max, row_sum # 累积 output return output实际生产中用 CUDA kernel 实现,远比伪代码高效,但核心思想一致。
- 速度 2-4×:同硬件推理速度直接翻倍
- 内存 O(N²) → O(N):长上下文成为可能
- 精度无损:数学上完全等价
- 行业标配:GPT-4、Claude、Llama、GLM 等几乎所有现代 LLM 都在用
Flash Attention 教给我们一个重要的工程哲学:有时候最大的性能提升不来自更好的算法,而来自更好地组织已有的算法。 它没有发明新数学,只是重新思考了”数据应该在哪里、什么时候被处理”。
这种 IO 感知的思维方式远不止应用于注意力机制——它是一种通用的系统优化范式。下次你遇到性能瓶颈时,不妨想想:是不是瓶颈不在计算,而在数据搬运?
参考资料:FlashAttention (Dao et al., 2022)、FlashAttention-2 (Dao, 2023)、FlashAttention-3 (Sheng et al., 2024)