FlashAttention-3:通过异步和低精度实现快速准确的注意力机制
ArXiv ID: 2407.08608
作者: Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
机构: Princeton University, NVIDIA
发布日期: 2024-07-11
硬件目标: NVIDIA H100 (Hopper) GPU
摘要
注意力机制是 Transformer 架构的核心,也是大语言模型和长上下文应用的性能瓶颈。FlashAttention-2 在 H100 GPU 上仅实现了 35% 的利用率,远未充分发挥硬件潜力。
本文提出的 FlashAttention-3 通过三项关键技术,在 H100 GPU 上实现了75% 的硬件利用率(FP16 达 740 TFLOPs/s),FP8 精度下接近1.2 PFLOPs/s,相比 FlashAttention-2 实现1.5-2.0 倍加速。
三项核心技术:
- Warp 特化:利用 Tensor Cores 和 TMA 的异步特性
- 操作交错:块级矩阵乘法和 softmax 交错执行
- FP8 支持:块量化和非相干处理
背景与动机
GPU 利用率现状
1 2 3 4 5 6 7 8 9 10 11 12
| H100 GPU 理论性能 vs 实际利用率:
┌─────────────────────────────────────────────┐ │ H100 理论峰值:989 TFLOPs/s (FP16) │ │ │ │ 实际利用率对比: │ │ - cuBLAS (GEMM): ~80% │ │ - FlashAttention-2: ~35% │ │ - FlashAttention-3: ~75% ← 本文 │ │ │ │ 差距来源:内存访问、数据移动、同步开销 │ └─────────────────────────────────────────────┘
|
FlashAttention 演进
| 版本 |
发布时间 |
GPU 目标 |
利用率 |
特点 |
| FlashAttention |
2022 |
A100 |
~40% |
首次提出 IO 感知注意力 |
| FlashAttention-2 |
2023 |
A100 |
~55% |
改进并行策略 |
| FlashAttention-3 |
2024 |
H100 |
~75% |
异步 +FP8 |
Hopper 架构新特性
H100 相比 A100 的升级:
| 特性 |
A100 |
H100 |
提升 |
| Tensor Core |
FP16/BF16 |
+FP8 |
2x |
| 内存带宽 |
1.6 TB/s |
3.35 TB/s |
2.1x |
| TMA |
❌ |
✅ |
新特性 |
| 异步执行 |
有限 |
增强 |
- |
TMA (Tensor Memory Accelerator):
- 硬件加速的数据移动引擎
- 与计算单元异步工作
- 可隐藏数据传输延迟
FlashAttention-3 方法
整体架构
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| ┌─────────────────────────────────────────────────────────┐ │ FlashAttention-3 Pipeline │ │ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ │ Warp A │ │ Warp B │ │ Warp C │ │ │ │ (加载 K/V) │ │ (矩阵乘法) │ │ (Softmax) │ │ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │ │ │ │ │ │ └──────────────────┼──────────────────┘ │ │ │ │ │ ┌───────────▼───────────┐ │ │ │ TMA 异步数据移动 │ │ │ │ (与计算重叠) │ │ │ └───────────────────────┘ │ │ │ │ │ ┌───────────▼───────────┐ │ │ │ FP8 Tensor Cores │ │ │ │ (块量化 + 非相干) │ │ │ └───────────────────────┘ │ └─────────────────────────────────────────────────────────┘
|
技术 1:Warp 特化
核心思想:将不同任务分配给不同的 Warp,实现计算与数据移动的重叠
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| 传统方式 vs Warp 特化:
传统方式(串行): [加载 Q] → [加载 K,V] → [计算 QK^T] → [Softmax] → [计算 PV] │ │ │ │ │ └──────────┴─────────────┴───────────┴──────────┘ 顺序执行
Warp 特化(并行): Warp 0: [加载 Q] ────── [加载 Q] ────── [加载 Q] ────── ... Warp 1: [加载 K,V] ────── [加载 K,V] ────── ... Warp 2: [计算 QK^T] ────── [计算 QK^T] ── ... Warp 3: [Softmax] ────── [Softmax] ...
时间 →
|
代码示意:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| __global__ void flash_attention_v3(Q, K, V, O) { // Warp 特化:不同 Warp 负责不同任务 int warp_id = (threadIdx.x / 32) % 4;
// Warp 0: 加载 Q if (warp_id == 0) { load_Q_async(Q, q_shared); } // Warp 1: 加载 K, V else if (warp_id == 1) { load_KV_async(K, V, k_shared, v_shared); } // Warp 2: 矩阵乘法 else if (warp_id == 2) { mma_sync(q_shared, k_shared, acc); } // Warp 3: Softmax else if (warp_id == 3) { softmax_inplace(acc); }
// 同步点 __syncthreads(); }
|
性能收益:
- 隐藏数据加载延迟:~30%
- 提高 Warp 占用率:2x
- 减少空闲周期:~40%
技术 2:操作交错
问题:传统实现中,矩阵乘法和 Softmax 是分离的两个阶段
1 2 3 4 5 6 7 8 9 10
| 传统流程: ┌──────────────┐ ┌──────────────┐ │ 全部 QK^T │ -> │ 全部 Softmax │ │ 矩阵乘法 │ │ 计算 │ └──────────────┘ └──────────────┘ 阶段 1 阶段 2
问题: - 需要存储完整的注意力矩阵(O(n²) 内存) - Softmax 必须等待所有 QK^T 完成
|
FlashAttention-3 方案:块级交错执行
1 2 3 4 5 6 7 8 9 10
| 交错流程: ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │ QK^T│->│Soft │->│ QK^T│->│Soft │-> ... │ blk0│ │max0 │->│ blk1│ │max1 │-> └─────┘ └─────┘ └─────┘ └─────┘
优势: - 内存复杂度降至 O(n) - 提前开始 Softmax 计算 - 更好的数据局部性
|
代码示意:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| // 块级交错执行 for (int block_idx = 0; block_idx < num_blocks; block_idx++) { // 步骤 1: 计算当前块的 QK^T mma_sync(q_block, k_blocks[block_idx], acc);
// 步骤 2: 立即执行 Softmax(不需要等待其他块) softmax_block(acc, temperature);
// 步骤 3: 与 V 相乘 mma_sync(acc, v_blocks[block_idx], output);
// 步骤 4: 累积到最终结果 atomicAdd(output_global, output); }
|
技术 3:FP8 支持
FP8 格式:
1 2 3 4 5 6 7 8 9 10 11 12 13
| FP8 数据类型:
E4M3 (4 指数位,3 尾数位): ├─ 1 bit: 符号 ├─ 4 bits: 指数 └─ 3 bits: 尾数 范围:[4.88e-4, 448]
E5M2 (5 指数位,2 尾数位): ├─ 1 bit: 符号 ├─ 5 bits: 指数 └─ 2 bits: 尾数 范围:[6.10e-5, 57344]
|
挑战:直接 FP8 量化会导致较大数值误差
解决方案:
- 块量化:每个块独立量化,减小动态范围
- 非相干处理:通过随机舍入减少系统性偏差
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| // 块级量化 __device__ void quantize_block_fp8(float* data, fp8_t* output, int block_size) { // 找到块内最大值 float max_val = 0; for (int i = 0; i < block_size; i++) { max_val = fmaxf(max_val, fabsf(data[i])); }
// 计算量化因子 float scale = max_val / FP8_MAX;
// 量化(带随机舍入) for (int i = 0; i < block_size; i++) { float normalized = data[i] / scale; output[i] = stochastic_round_fp8(normalized); } }
// 随机舍入减少偏差 __device__ fp8_t stochastic_round_fp8(float x) { float lower = floorf(x * 256) / 256.0f; float upper = ceilf(x * 256) / 256.0f; float prob = (x * 256 - floorf(x * 256));
// 随机舍入:以概率 prob 向上舍入 float rand_val = random_uniform(); return (rand_val < prob) ? upper : lower; }
|
数值误差对比:
| 方法 |
MSE (相对 FP16) |
说明 |
| 直接 FP8 |
2.6x |
系统性偏差大 |
| FP8 + 块量化 |
1.0x |
误差最低 |
| FP8 + 非相干 |
1.2x |
接近块量化 |
实验结果
性能对比
H100 GPU - FP16
| 方法 |
TFLOPs/s |
利用率 |
相对加速 |
| PyTorch Attention |
180 |
18% |
1.0x |
| FlashAttention |
290 |
29% |
1.6x |
| FlashAttention-2 |
350 |
35% |
1.9x |
| FlashAttention-3 |
740 |
75% |
4.1x |
H100 GPU - FP8
| 方法 |
TFLOPs/s |
精度损失 |
相对加速 |
| FlashAttention-2 (FP16) |
350 |
0% |
1.0x |
| FlashAttention-3 (FP8) |
1,180 |
<1% |
3.4x |
A100 GPU 对比
| 方法 |
TFLOPs/s |
相对加速 |
| FlashAttention-2 |
195 |
1.0x |
| FlashAttention-3 |
285 |
1.46x |
端到端训练加速
| 模型 |
序列长度 |
基线 |
FA-2 |
FA-3 |
加速 |
| GPT-2 XL |
1K |
145 |
178 |
265 |
1.83x |
| LLaMA-7B |
2K |
98 |
124 |
192 |
1.96x |
| LLaMA-13B |
4K |
52 |
68 |
108 |
2.08x |
长序列性能
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| 注意力计算时间 vs 序列长度(LLaMA-7B):
时间 (ms) │ 50 │ PyTorch ● │ 40 │ │ ● FlashAttn-2 30 │ │ 20 │ ● FlashAttn-3 │ 10 │ └───────────────────────────── 1K 2K 4K 8K 序列长度
|
内存效率
| 序列长度 |
FA-2 内存 |
FA-3 内存 |
节省 |
| 1K |
2.1 GB |
1.8 GB |
14% |
| 4K |
8.4 GB |
5.2 GB |
38% |
| 8K |
16.8 GB |
8.9 GB |
47% |
| 16K |
OOM |
15.2 GB |
- |
消融实验
各组件贡献
| 配置 |
TFLOPs/s |
相对性能 |
| 完整 FlashAttention-3 |
740 |
100% |
| - Warp 特化 |
520 |
70% |
| - 操作交错 |
580 |
78% |
| - FP8 量化 |
380 |
51% |
| 全部移除 (FA-2) |
350 |
47% |
异步效果
1 2 3 4 5 6 7 8 9 10 11
| Warp 特化带来的延迟隐藏:
无异步: [加载]████████ [计算]████████ [存储]████ │等待加载完成│
有异步: [加载]████████████████████████ [计算] ████████████████ [存储] ████████ 时间利用率:~85%
|
实际应用
1 2 3 4 5 6 7 8 9 10 11 12 13
| from flash_attn import flash_attention_v3 from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")
model.config.use_flash_attn = True model.config.flash_attn_version = 3 model.config.flash_attn_fp8 = True
outputs = model(input_ids, attention_mask=attention_mask)
|
性能监控
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| import torch from flash_attn.utils.benchmark import benchmark_forward
model = ... input_ids = torch.randint(0, 10000, (8, 4096)).cuda()
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): t_fa2 = benchmark_forward(model, input_ids) print(f"FA-2: {t_fa2.avg_time:.2f}ms")
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False): model.config.flash_attn_version = 3 t_fa3 = benchmark_forward(model, input_ids) print(f"FA-3: {t_fa3.avg_time:.2f}ms")
print(f"加速比:{t_fa2.avg_time / t_fa3.avg_time:.2f}x")
|
局限性与未来方向
当前局限
| 局限 |
影响 |
缓解方案 |
| 仅支持 H100 |
A100 用户无法使用全部特性 |
回退到 FA-2 |
| FP8 精度损失 |
某些任务精度下降 |
混合精度 |
| 实现复杂度高 |
调试困难 |
开源社区支持 |
未来方向
- 多 GPU 扩展:跨 GPU 注意力
- 新硬件支持:B100、 Rubin 架构
- 更低精度:FP4 探索
- 稀疏注意力:与稀疏性结合
总结
FlashAttention-3 通过深入挖掘 Hopper GPU 架构特性,实现了 75% 的硬件利用率:
核心贡献:
- Warp 特化利用异步执行隐藏延迟
- 操作交错改善内存局部性
- FP8 块量化实现低精度高效计算
实际影响:
- H100 上 1.5-2.0 倍加速
- FP8 精度下接近 1.2 PFLOPs/s
- 显著降低 LLM 训练成本
资源
评分: 4.7/5.0 ⭐⭐⭐⭐⭐
推荐度: 强烈推荐。H100 用户必备优化。