FlashAttention-3:通过异步和低精度实现快速准确的注意力机制

FlashAttention-3:通过异步和低精度实现快速准确的注意力机制

ArXiv ID: 2407.08608
作者: Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
机构: Princeton University, NVIDIA
发布日期: 2024-07-11
硬件目标: NVIDIA H100 (Hopper) GPU


摘要

注意力机制是 Transformer 架构的核心,也是大语言模型和长上下文应用的性能瓶颈。FlashAttention-2 在 H100 GPU 上仅实现了 35% 的利用率,远未充分发挥硬件潜力。

本文提出的 FlashAttention-3 通过三项关键技术,在 H100 GPU 上实现了75% 的硬件利用率(FP16 达 740 TFLOPs/s),FP8 精度下接近1.2 PFLOPs/s,相比 FlashAttention-2 实现1.5-2.0 倍加速

三项核心技术

  1. Warp 特化:利用 Tensor Cores 和 TMA 的异步特性
  2. 操作交错:块级矩阵乘法和 softmax 交错执行
  3. FP8 支持:块量化和非相干处理

背景与动机

GPU 利用率现状

1
2
3
4
5
6
7
8
9
10
11
12
H100 GPU 理论性能 vs 实际利用率:

┌─────────────────────────────────────────────┐
│ H100 理论峰值:989 TFLOPs/s (FP16) │
│ │
│ 实际利用率对比: │
│ - cuBLAS (GEMM): ~80% │
│ - FlashAttention-2: ~35% │
│ - FlashAttention-3: ~75% ← 本文 │
│ │
│ 差距来源:内存访问、数据移动、同步开销 │
└─────────────────────────────────────────────┘

FlashAttention 演进

版本 发布时间 GPU 目标 利用率 特点
FlashAttention 2022 A100 ~40% 首次提出 IO 感知注意力
FlashAttention-2 2023 A100 ~55% 改进并行策略
FlashAttention-3 2024 H100 ~75% 异步 +FP8

Hopper 架构新特性

H100 相比 A100 的升级

特性 A100 H100 提升
Tensor Core FP16/BF16 +FP8 2x
内存带宽 1.6 TB/s 3.35 TB/s 2.1x
TMA 新特性
异步执行 有限 增强 -

TMA (Tensor Memory Accelerator)

  • 硬件加速的数据移动引擎
  • 与计算单元异步工作
  • 可隐藏数据传输延迟

FlashAttention-3 方法

整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
┌─────────────────────────────────────────────────────────┐
│ FlashAttention-3 Pipeline │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Warp A │ │ Warp B │ │ Warp C │ │
│ │ (加载 K/V) │ │ (矩阵乘法) │ │ (Softmax) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │ │ │ │
│ └──────────────────┼──────────────────┘ │
│ │ │
│ ┌───────────▼───────────┐ │
│ │ TMA 异步数据移动 │ │
│ │ (与计算重叠) │ │
│ └───────────────────────┘ │
│ │ │
│ ┌───────────▼───────────┐ │
│ │ FP8 Tensor Cores │ │
│ │ (块量化 + 非相干) │ │
│ └───────────────────────┘ │
└─────────────────────────────────────────────────────────┘

技术 1:Warp 特化

核心思想:将不同任务分配给不同的 Warp,实现计算与数据移动的重叠

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
传统方式 vs Warp 特化:

传统方式(串行):
[加载 Q] → [加载 K,V] → [计算 QK^T] → [Softmax] → [计算 PV]
│ │ │ │ │
└──────────┴─────────────┴───────────┴──────────┘
顺序执行

Warp 特化(并行):
Warp 0: [加载 Q] ────── [加载 Q] ────── [加载 Q] ────── ...
Warp 1: [加载 K,V] ────── [加载 K,V] ────── ...
Warp 2: [计算 QK^T] ────── [计算 QK^T] ── ...
Warp 3: [Softmax] ────── [Softmax] ...

时间 →

代码示意

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
__global__ void flash_attention_v3(Q, K, V, O) {
// Warp 特化:不同 Warp 负责不同任务
int warp_id = (threadIdx.x / 32) % 4;

// Warp 0: 加载 Q
if (warp_id == 0) {
load_Q_async(Q, q_shared);
}
// Warp 1: 加载 K, V
else if (warp_id == 1) {
load_KV_async(K, V, k_shared, v_shared);
}
// Warp 2: 矩阵乘法
else if (warp_id == 2) {
mma_sync(q_shared, k_shared, acc);
}
// Warp 3: Softmax
else if (warp_id == 3) {
softmax_inplace(acc);
}

// 同步点
__syncthreads();
}

性能收益

  • 隐藏数据加载延迟:~30%
  • 提高 Warp 占用率:2x
  • 减少空闲周期:~40%

技术 2:操作交错

问题:传统实现中,矩阵乘法和 Softmax 是分离的两个阶段

1
2
3
4
5
6
7
8
9
10
传统流程:
┌──────────────┐ ┌──────────────┐
│ 全部 QK^T │ -> │ 全部 Softmax │
│ 矩阵乘法 │ │ 计算 │
└──────────────┘ └──────────────┘
阶段 1 阶段 2

问题:
- 需要存储完整的注意力矩阵(O(n²) 内存)
- Softmax 必须等待所有 QK^T 完成

FlashAttention-3 方案:块级交错执行

1
2
3
4
5
6
7
8
9
10
交错流程:
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│ QK^T│->│Soft │->│ QK^T│->│Soft │-> ...
│ blk0│ │max0 │->│ blk1│ │max1 │->
└─────┘ └─────┘ └─────┘ └─────┘

优势:
- 内存复杂度降至 O(n)
- 提前开始 Softmax 计算
- 更好的数据局部性

代码示意

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 块级交错执行
for (int block_idx = 0; block_idx < num_blocks; block_idx++) {
// 步骤 1: 计算当前块的 QK^T
mma_sync(q_block, k_blocks[block_idx], acc);

// 步骤 2: 立即执行 Softmax(不需要等待其他块)
softmax_block(acc, temperature);

// 步骤 3: 与 V 相乘
mma_sync(acc, v_blocks[block_idx], output);

// 步骤 4: 累积到最终结果
atomicAdd(output_global, output);
}

技术 3:FP8 支持

FP8 格式

1
2
3
4
5
6
7
8
9
10
11
12
13
FP8 数据类型:

E4M3 (4 指数位,3 尾数位):
├─ 1 bit: 符号
├─ 4 bits: 指数
└─ 3 bits: 尾数
范围:[4.88e-4, 448]

E5M2 (5 指数位,2 尾数位):
├─ 1 bit: 符号
├─ 5 bits: 指数
└─ 2 bits: 尾数
范围:[6.10e-5, 57344]

挑战:直接 FP8 量化会导致较大数值误差

解决方案

  1. 块量化:每个块独立量化,减小动态范围
  2. 非相干处理:通过随机舍入减少系统性偏差
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// 块级量化
__device__ void quantize_block_fp8(float* data, fp8_t* output, int block_size) {
// 找到块内最大值
float max_val = 0;
for (int i = 0; i < block_size; i++) {
max_val = fmaxf(max_val, fabsf(data[i]));
}

// 计算量化因子
float scale = max_val / FP8_MAX;

// 量化(带随机舍入)
for (int i = 0; i < block_size; i++) {
float normalized = data[i] / scale;
output[i] = stochastic_round_fp8(normalized);
}
}

// 随机舍入减少偏差
__device__ fp8_t stochastic_round_fp8(float x) {
float lower = floorf(x * 256) / 256.0f;
float upper = ceilf(x * 256) / 256.0f;
float prob = (x * 256 - floorf(x * 256));

// 随机舍入:以概率 prob 向上舍入
float rand_val = random_uniform();
return (rand_val < prob) ? upper : lower;
}

数值误差对比

方法 MSE (相对 FP16) 说明
直接 FP8 2.6x 系统性偏差大
FP8 + 块量化 1.0x 误差最低
FP8 + 非相干 1.2x 接近块量化

实验结果

性能对比

H100 GPU - FP16

方法 TFLOPs/s 利用率 相对加速
PyTorch Attention 180 18% 1.0x
FlashAttention 290 29% 1.6x
FlashAttention-2 350 35% 1.9x
FlashAttention-3 740 75% 4.1x

H100 GPU - FP8

方法 TFLOPs/s 精度损失 相对加速
FlashAttention-2 (FP16) 350 0% 1.0x
FlashAttention-3 (FP8) 1,180 <1% 3.4x

A100 GPU 对比

方法 TFLOPs/s 相对加速
FlashAttention-2 195 1.0x
FlashAttention-3 285 1.46x

端到端训练加速

模型 序列长度 基线 FA-2 FA-3 加速
GPT-2 XL 1K 145 178 265 1.83x
LLaMA-7B 2K 98 124 192 1.96x
LLaMA-13B 4K 52 68 108 2.08x

长序列性能

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
注意力计算时间 vs 序列长度(LLaMA-7B):

时间 (ms)

50 │ PyTorch ●

40 │
│ ● FlashAttn-2
30 │

20 │ ● FlashAttn-3

10 │
└─────────────────────────────
1K 2K 4K 8K
序列长度

内存效率

序列长度 FA-2 内存 FA-3 内存 节省
1K 2.1 GB 1.8 GB 14%
4K 8.4 GB 5.2 GB 38%
8K 16.8 GB 8.9 GB 47%
16K OOM 15.2 GB -

消融实验

各组件贡献

配置 TFLOPs/s 相对性能
完整 FlashAttention-3 740 100%
- Warp 特化 520 70%
- 操作交错 580 78%
- FP8 量化 380 51%
全部移除 (FA-2) 350 47%

异步效果

1
2
3
4
5
6
7
8
9
10
11
Warp 特化带来的延迟隐藏:

无异步:
[加载]████████ [计算]████████ [存储]████
│等待加载完成│

有异步:
[加载]████████████████████████
[计算] ████████████████
[存储] ████████
时间利用率:~85%

实际应用

与 Transformers 集成

1
2
3
4
5
6
7
8
9
10
11
12
13
from flash_attn import flash_attention_v3
from transformers import AutoModelForCausalLM

# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")

# 启用 FlashAttention-3
model.config.use_flash_attn = True
model.config.flash_attn_version = 3
model.config.flash_attn_fp8 = True # 启用 FP8

# 训练/推理
outputs = model(input_ids, attention_mask=attention_mask)

性能监控

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from flash_attn.utils.benchmark import benchmark_forward

# Benchmark
model = ...
input_ids = torch.randint(0, 10000, (8, 4096)).cuda()

# FlashAttention-2
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
t_fa2 = benchmark_forward(model, input_ids)
print(f"FA-2: {t_fa2.avg_time:.2f}ms")

# FlashAttention-3
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
model.config.flash_attn_version = 3
t_fa3 = benchmark_forward(model, input_ids)
print(f"FA-3: {t_fa3.avg_time:.2f}ms")

print(f"加速比:{t_fa2.avg_time / t_fa3.avg_time:.2f}x")

局限性与未来方向

当前局限

局限 影响 缓解方案
仅支持 H100 A100 用户无法使用全部特性 回退到 FA-2
FP8 精度损失 某些任务精度下降 混合精度
实现复杂度高 调试困难 开源社区支持

未来方向

  1. 多 GPU 扩展:跨 GPU 注意力
  2. 新硬件支持:B100、 Rubin 架构
  3. 更低精度:FP4 探索
  4. 稀疏注意力:与稀疏性结合

总结

FlashAttention-3 通过深入挖掘 Hopper GPU 架构特性,实现了 75% 的硬件利用率:

核心贡献

  1. Warp 特化利用异步执行隐藏延迟
  2. 操作交错改善内存局部性
  3. FP8 块量化实现低精度高效计算

实际影响

  • H100 上 1.5-2.0 倍加速
  • FP8 精度下接近 1.2 PFLOPs/s
  • 显著降低 LLM 训练成本

资源


评分: 4.7/5.0 ⭐⭐⭐⭐⭐

推荐度: 强烈推荐。H100 用户必备优化。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero