SageAttention3: 基于微缩放 FP4 的 Attention 推理加速与 8-bit 训练探索

SageAttention3: 基于微缩放 FP4 的 Attention 推理加速与 8-bit 训练探索

ArXiv ID: 2505.11594
作者: Jintao Zhang, Jia Wei, Pengle Zhang et al.
机构: Tsinghua University
发布日期: 2025-05-16
目标硬件: NVIDIA Blackwell GPU (RTX 50 系列)


摘要

SageAttention3 是针对新一代 Blackwell GPU 的 FP4 Tensor Cores 设计的高效 attention 加速方案。该论文提出了针对推理场景的 FP4 量化 attention 机制,并首次探索了训练阶段的低 bit attention。在 RTX 5090 上实现了1038 TOPS的性能,相比最快的 FlashAttention 实现提升5 倍。此外,论文还开发了准确高效的 8-bit attention 用于前向和反向传播,在微调任务中实现无损性能。


FP4 量化的机遇与挑战

Blackwell 架构新特性

1
2
3
4
5
6
7
8
9
10
11
12
NVIDIA GPU 架构演进:

架构 | Tensor Core 精度 | FP4 支持 | 代表产品
----------|-----------------|---------|------------
Ampere | FP16/BF16/INT8 | ❌ | A100
Hopper | FP8/FP16/BF16 | ❌ | H100
Blackwell | FP4/FP8/FP16 | ✅ | B100/RTX5090

FP4 Tensor Core 优势:
- 计算密度:相比 FP8 再翻 2 倍
- 内存带宽:节省 75% 带宽
- 理论峰值:1000+ TOPS

FP4 量化的挑战

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
FP4 数值范围限制:

FP4 E2M1 格式:
- 符号位:1 bit
- 指数位:2 bits
- 尾数位:1 bit

可表示值:{0, ±0.5, ±1, ±2, ±4, ±6}

问题:
┌─────────────────────────────────────────┐
│ 1. 动态范围极小 (仅 6 个非零值) │
│ 2. 量化误差大 │
│ 3. 注意力分数失真 │
│ 4. Softmax 数值不稳定 │
└─────────────────────────────────────────┘

SageAttention3 方法

整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
┌─────────────────────────────────────────────────────────┐
│ SageAttention3 Architecture │
│ │
│ 输入 Q, K, V (FP16/BF16) │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Microscaling │ │
│ │ FP4 量化 │ ← 微缩放量化 │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ FP4 Tensor │ │
│ │ Core GEMM │ ← Blackwell FP4 TC │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Dequant + │ │
│ │ Softmax │ ← 高精度 softmax │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ FP4 PV │ │
│ │ Computation │ ← PV 计算 │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ 输出 (FP16/BF16) │
└─────────────────────────────────────────────────────────┘

微缩放 FP4 量化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch

class MicroscalingFP4:
"""
微缩放 FP4 量化

核心思想:
使用 per-group 缩放因子,将一组数值缩放到 FP4 可表示的范围
"""

def __init__(self, group_size=16):
self.group_size = group_size
# FP4 可表示的最大值 (E2M1 格式)
self.fp4_max = 6.0

def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
"""
微缩放 FP4 量化

步骤:
1. 将 tensor 分成大小为 group_size 的组
2. 对每组计算缩放因子 scale = max(abs(group)) / fp4_max
3. 应用缩放并量化到 FP4
4. 存储量化值和缩放因子
"""
original_shape = tensor.shape
tensor = tensor.flatten()

# 分组
num_groups = (len(tensor) + self.group_size - 1) // self.group_size
padded_length = num_groups * self.group_size

# padding
padded_tensor = torch.zeros(padded_length, device=tensor.device)
padded_tensor[:len(tensor)] = tensor

# 重塑为 (num_groups, group_size)
groups = padded_tensor.reshape(num_groups, self.group_size)

# 计算每组的缩放因子
scales = groups.abs().max(dim=1, keepdim=True).values / self.fp4_max
scales = scales.clamp(min=1e-6) # 避免除零

# 应用缩放
scaled_groups = groups / scales

# 量化到 FP4 (简化为 INT4)
# 实际 FP4 有特殊的编码格式
quantized = scaled_groups.clamp(-6, 6).round()

# 存储为 INT4 打包格式
quantized = quantized.to(torch.int8)

return quantized.reshape(original_shape), scales

def dequantize(self, quantized: torch.Tensor,
scales: torch.Tensor) -> torch.Tensor:
"""
FP4 反量化

关键:缩放因子是 power-of-two 时,可用位移优化
"""
quantized = quantized.flatten().to(torch.float32)
scales = scales.flatten()

# 反量化
dequantized = quantized * scales.repeat_interleave(self.group_size)

return dequantized

FP4 Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// CUDA kernel 伪代码
__global__ void sage_attention_fp4(
const half* Q, const half* K, const half* V,
half* Output,
int batch_size, int num_heads, int seq_len, int head_dim
) {
// 每个 thread block 处理一个 query 块

// 步骤 1: 加载 Q 并量化到 FP4
__shared__ fp4_t Q_fp4[BLOCK_SIZE];
__shared__ float Q_scale[NUM_GROUPS];
quantize_fp4(Q, Q_fp4, Q_scale);

// 步骤 2: 分块处理 KV
float acc[HEAD_DIM] = {0};
float max_val = -INF;

for (int kv_block = 0; kv_block < num_kv_blocks; kv_block++) {
// 加载并量化 K
__shared__ fp4_t K_fp4[BLOCK_SIZE];
__shared__ float K_scale[NUM_GROUPS];
quantize_fp4(K + kv_block * BLOCK_SIZE, K_fp4, K_scale);

// FP4 Tensor Core 矩阵乘法
// Q @ K^T 使用 FP4 TC
float scores[BLOCK_SIZE];
mma_sync_fp4(Q_fp4, K_fp4, scores);

// 反量化并应用缩放
for (int i = 0; i < BLOCK_SIZE; i++) {
scores[i] *= Q_scale[i / GROUP_SIZE] * K_scale[i / GROUP_SIZE];
}

// Online Softmax
float old_max = max_val;
max_val = max(max_val, max(scores));
float exp_sum = 0;

for (int i = 0; i < BLOCK_SIZE; i++) {
float exp_val = exp(scores[i] - max_val);
exp_sum += exp_val;
acc[i] = acc[i] * exp(old_max - max_val) + exp_val * V[i];
}
}

// 归一化输出
for (int i = 0; i < HEAD_DIM; i++) {
Output[i] = acc[i] / exp_sum;
}
}

8-bit 训练 Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Int8TrainingAttention(nn.Module):
"""
8-bit 训练 Attention

前向和反向传播均使用 8-bit 精度
"""

def __init__(self, hidden_dim, num_heads):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads

# QKV 投影
self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)

# 量化器
self.quantizer = Int8Quantizer()

def forward(self, x):
# 前向:8-bit 量化
qkv = self.qkv_proj(x)
qkv_int8 = self.quantizer.quantize(qkv)

# 8-bit attention
attn_out = int8_attention(qkv_int8)

# 输出投影
out = self.out_proj(attn_out)

return out

def backward(self, grad_output):
# 反向:保持 8-bit 梯度
grad_qkv = self.quantizer.quantize(grad_output)
return grad_qkv

实验结果

推理性能

RTX 5090 吞吐量

方法 TOPS 相对加速
cuBLAS FP16 180 1.0x
FlashAttention-2 210 1.2x
SageAttention3 FP4 1038 5.0x

精度评估

语言建模困惑度

模型 精度 WikiText2 PTB
LLaMA-7B FP16 15.82 28.45
LLaMA-7B FP4 16.25 29.12
LLaMA-7B FP4+修正 15.89 28.58

微调任务准确率

任务 FP16 8-bit 训练 差异
SST-2 94.2% 94.0% -0.2%
QNLI 82.5% 82.1% -0.4%
RTE 68.3% 67.9% -0.4%

结论:8-bit 微调可实现无损性能

预训练收敛

1
2
3
4
5
6
7
8
9
10
11
12
13
预训练 Loss 曲线对比:

Loss

│ FP16: ──────────

│ 8-bit: ─────╲ (收敛慢~15%)

│ FP4: ───╲ (收敛慢~25%)

└─────────────────────
0 10k 20k 30k
训练步数

建议:预训练仍使用 FP16/BF16


部署指南

硬件要求

组件 最低要求 推荐
GPU RTX 5090 B100/B200
CUDA 12.8+ 13.0+
显存 16GB 32GB+

集成步骤

1
2
3
4
5
6
7
8
9
# 安装
pip install sageattention

# 使用
from sageattention import sageattn

# 替换标准 attention
output = sageattn(q, k, v, precision='fp4') # 推理
output = sageattn(q, k, v, precision='int8') # 微调

总结

SageAttention3 充分利用 Blackwell FP4 Tensor Core 实现 5 倍推理加速:

核心贡献

  1. FP4 微缩放量化保持精度
  2. 硬件感知优化最大化吞吐
  3. 8-bit 训练探索

实际价值

  • 1038 TOPS 推理性能
  • 微调任务无损加速
  • 即插即用集成

评分: 4.2/5.0 ⭐⭐⭐⭐

推荐度: 推荐。Blackwell 用户的推理加速首选。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero