跳转到内容
输入关键词后按 Enter 打开第一个结果。

Flash Attention:让大模型推理快 2-4 倍的"魔法"

Flash Attention = 不改变任何计算结果,只改变计算方式,就让注意力机制快了 2-4 倍。

如果你用过 ChatGPT 或 Claude,你经历过等待光标一个字一个字蹦出来的时刻。生成慢的瓶颈往往不在 GPU 算力,而在内存墙。Flash Attention 就是拆掉这堵墙的”魔法”。

注意力机制的计算很简单:Q × K^T → softmax → × V。但中间产生的 N×N 注意力矩阵是巨大的。

对于 128K 上下文的模型:128,000 × 128,000 × 2 字节 = 32 GB。仅仅一个中间矩阵。

更糟的是,标准实现需要把这个矩阵在 GPU 高速缓存(SRAM)和显存(HBM)之间反复搬运。现代 GPU 算力远超内存带宽,导致大量时间花在等数据搬运上。就像一个数学天才,每道题都要等别人把题目念完才能开始算。

Flash Attention 的关键洞察:你不需要一次性看到完整的 N×N 矩阵才能算出正确结果。

把 Q、K、V 切成小块,每块大小刚好放进 SRAM。在 SRAM 中逐块计算,累积结果,最终只写回 N×d 的输出矩阵(远小于 N×N)。

Softmax 需要知道所有分数的最大值,分块怎么算?答案——维护两个运行统计量:

  • m:当前已见分数的全局最大值
  • :当前已见分数的指数和

每处理新块时更新这两个值,数学上可以证明逐步累积的结果与一次性计算完全等价

# Online Softmax 更新(伪代码)
new_max = max(old_max, block_max)
new_sum = old_sum * exp(old_max - new_max) + sum(exp(block - new_max))
内存类型容量带宽延迟
SRAM(片上)~20 MB~19 TB/s极低
HBM(显存)~40-80 GB~1.5-2 TB/s较高

Flash Attention 显式地利用这个层级:在 SRAM 里做尽可能多的计算,减少 HBM 读写。就像厨师把所有食材搬到料理台上再开始做菜。

def standard_attention(Q, K, V):
scores = Q @ K.T # (N, N) → 写入 HBM
attn = softmax(scores) # (N, N) → 从 HBM 读取
output = attn @ V # (N, d)
return output
def flash_attention(Q, K, V, block_size=128):
output = zeros(N, d)
row_max = full((N,), -inf)
row_sum = zeros(N)
for j in range(0, N, block_size): # 分块遍历 K/V
for i in range(0, N, block_size): # 分块遍历 Q
scores = Q[i:i+B] @ K[j:j+B].T # SRAM 中计算
# Online Softmax 更新 row_max, row_sum
# 累积 output
return output

实际生产中用 CUDA kernel 实现,远比伪代码高效,但核心思想一致。

  1. 速度 2-4×:同硬件推理速度直接翻倍
  2. 内存 O(N²) → O(N):长上下文成为可能
  3. 精度无损:数学上完全等价
  4. 行业标配:GPT-4、Claude、Llama、GLM 等几乎所有现代 LLM 都在用

Flash Attention 教给我们一个重要的工程哲学:有时候最大的性能提升不来自更好的算法,而来自更好地组织已有的算法。 它没有发明新数学,只是重新思考了”数据应该在哪里、什么时候被处理”。

这种 IO 感知的思维方式远不止应用于注意力机制——它是一种通用的系统优化范式。下次你遇到性能瓶颈时,不妨想想:是不是瓶颈不在计算,而在数据搬运?


参考资料:FlashAttention (Dao et al., 2022)、FlashAttention-2 (Dao, 2023)、FlashAttention-3 (Sheng et al., 2024)