MOSS: 用微缩放和自动缩放实现无损 FP8 训练,OLMo-7B 加速 34%

MOSS: 用微缩放和自动缩放实现无损 FP8 训练,OLMo-7B 加速 34%

ArXiv ID: 2511.05811
作者: Yu Zhang, Hui-Ling Zhen, Mingxuan Yuan, Bei Yu
机构: The Chinese University of Hong Kong, Huawei Noah’s Ark Lab
发布日期: 2025-11-08


FP8 训练的理想与现实

FP8 训练的潜力

1
2
3
4
5
6
7
8
9
10
FP8 vs BF16 理论对比:

指标 | BF16 | FP8 | 提升
--------------|-------|-------|------
计算密度 | 128 | 256 | 2x
内存占用 | 100% | 50% | 50%↓
通信开销 | 100% | 25-50%| 50-75%↓
能效比 | 1x | 1.8x | 80%↑

理论加速:2x 计算 + 2x 内存带宽 = 4x 潜力

现实挑战

问题 1:动态范围不足

1
2
3
4
5
6
7
8
9
FP8 数值范围对比:

BF16: [~10^±38] (宽动态范围)
FP8: [~10^±15] (窄动态范围)

结果:
- 大值溢出 → Inf
- 小值下溢 → 0
- 训练不稳定 → 崩溃

问题 2:量化开销

1
2
3
4
5
6
7
8
9
10
11
12
现有方案的性能损失:

COAT/NVIDIA TE 的 per-group 量化:
┌─────────────────────────────────────┐
│ GEMM Main Loop: │
│ • Tensor Core 计算 │
│ • + 每次反量化开销 │
│ (查表 + 乘法) │
│ │
│ 结果:计算加速被反量化抵消 │
│ 净加速:仅 15-20% │
└─────────────────────────────────────┘

MOSS 核心创新

创新 1:两级微缩放策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
MOSS 两级缩放架构:

┌─────────────────────────────────────────────────────────┐
│ Two-Level Micro-Scaling │
│ │
│ 原始激活值 (BF16) │
│ ↓ │
│ ┌───────────────────┐ │
│ │ Level 1: │ │
│ │ Per-tensor Scale │ ← 全局粗调 (FP32) │
│ │ (全局缩放因子) │ │
│ └───────────────────┘ │
│ ↓ │
│ ┌───────────────────┐ │
│ │ Level 2: │ │
│ │ Micro Scale │ ← 局部微调 (power-of-two) │
│ │ (微缩放因子) │ │
│ └───────────────────┘ │
│ ↓ │
│ FP8 量化值 │
└─────────────────────────────────────────────────────────┘

Power-of-Two 的关键优势

1
2
3
4
5
6
7
8
9
10
11
12
13
# 传统缩放:需要乘法反量化
scale = 1.73
value_fp8 = 50
value_bf16 = value_fp8 * scale # 乘法操作

# Power-of-Two 缩放:只需位移
scale = 2^k # k 为整数
value_bf16 = value_fp8 << k # 位移操作,零开销

# 示例
scale = 2^3 = 8
value_fp8 = 50
value_bf16 = 50 << 3 = 400 # 等同于 50 * 8

两级缩放实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch

class TwoLevelMicroScaling:
"""
两级微缩放量化器
"""

def __init__(self, num_micro_groups=8):
self.num_micro_groups = num_micro_groups
# Level 1: 全局缩放因子 (FP32)
self.tensor_scale = None
# Level 2: 微缩放因子表 (power-of-two)
self.micro_scales = None

def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
"""
两级量化

步骤:
1. 计算全局缩放因子
2. 应用全局缩放
3. 分組计算微缩放因子
4. 应用微缩放并量化到 FP8
"""
# Level 1: 全局缩放
max_val = tensor.abs().max()
self.tensor_scale = max_val / 448 # FP8 最大可表示值

# 应用全局缩放
scaled_tensor = tensor / self.tensor_scale

# Level 2: 分组微缩放
# 将 tensor 分成 num_micro_groups 组
groups = scaled_tensor.chunk(self.num_micro_groups, dim=0)

quantized_groups = []
self.micro_scales = []

for group in groups:
# 计算组内最优 power-of-two 缩放因子
group_max = group.abs().max()

# 找到最接近的 power-of-two
k = torch.round(torch.log2(group_max / 128)).int()
micro_scale = 2.0 ** k.item()

self.micro_scales.append(micro_scale)

# 应用微缩放并量化
micro_scaled = group / micro_scale
quantized = micro_scaled.clamp(-128, 127).to(torch.int8)
quantized_groups.append(quantized)

return torch.cat(quantized_groups, dim=0)

def dequantize(self, quantized_tensor: torch.Tensor) -> torch.Tensor:
"""
反量化

关键:微缩放因子的反量化只需位移
"""
groups = quantized_tensor.chunk(self.num_micro_groups, dim=0)
dequantized_groups = []

for i, group in enumerate(groups):
# 恢复微缩放 (位移操作)
micro_scale = self.micro_scales[i]
dequantized = group.to(torch.float32) * micro_scale

dequantized_groups.append(dequantized)

# 恢复全局缩放
dequantized_tensor = torch.cat(dequantized_groups, dim=0)
dequantized_tensor *= self.tensor_scale

return dequantized_tensor

GEMM Kernel 优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
MOSS GEMM 执行流程:

传统 per-group GEMM:
┌─────────────────────────────────────────┐
│ Main Loop (每个 block): │
│ for k in K: │
│ a = dequant(A[k]) ← 开销 │
│ b = dequant(B[k]) ← 开销 │
│ acc += matmul(a, b) │
└─────────────────────────────────────────┘

MOSS GEMM:
┌─────────────────────────────────────────┐
│ Main Loop (纯 Tensor Core): │
│ for k in K: │
│ acc += matmul(A[k], B[k]) ← 无开销 │
│ │
│ Epilogue (一次性反量化): │
│ output = dequant(acc) │
└─────────────────────────────────────────┘

性能提升:Main Loop 100% Tensor Core 利用

创新 2:权重自动缩放

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class WeightAutoScaling:
"""
权重自动缩放

核心洞察:
Adam 优化器的权重更新被学习率 bounded:
W_{t+1} = W_t - lr * gradient

因此:|W_{t+1}| <= |W_t| + lr * |gradient|

可以预测下一轮的缩放因子,无需 runtime 计算 max
"""

def __init__(self, lr=1e-3):
self.lr = lr
self.predicted_scale = None
self.current_scale = None

def predict_next_scale(self, current_weights: torch.Tensor,
gradient: torch.Tensor) -> float:
"""
预测下一轮缩放因子

基于 Adam 更新公式的上界估计
"""
# 估计最大可能变化
max_update = self.lr * gradient.abs().max()

# 预测新的最大值
current_max = current_weights.abs().max()
predicted_max = current_max + max_update

# 预测缩放因子
self.predicted_scale = predicted_max / 127.0

return self.predicted_scale

def quantize_weight(self, weight: torch.Tensor,
gradient: torch.Tensor) -> torch.Tensor:
"""
权重量化(无需 runtime max 计算)
"""
# 使用预测的缩放因子
if self.predicted_scale is None:
# 第一次:计算实际 max
scale = weight.abs().max() / 127.0
else:
# 后续轮次:使用预测值
scale = self.predicted_scale

# 量化
quantized = (weight / scale).clamp(-128, 127).to(torch.int8)

# 为下一轮预测
self.current_scale = scale
self.predict_next_scale(weight, gradient)

return quantized

实验结果

训练加速

框架 精度 相对 BF16 加速 模型质量
BF16 基线 BF16 - 基准
NVIDIA TE FP8 ~15% 轻微下降
COAT FP8 ~20% 无损
MOSS FP8 34% 无损

关键:MOSS 相比 COAT 快12.3%

训练稳定性

1
2
3
4
5
6
7
8
9
10
11
训练 Loss 曲线对比:

Loss

│ BF16: ──────────
│ MOSS: ────────── (与 BF16 重合)
│ COAT: ────╲ (轻微波动)
│ TE: ──╲ (明显偏差)
└─────────────────────
0 1k 2k 3k
训练步数

下游任务评估

任务 BF16 MOSS 差异
MMLU 35.2% 35.1% -0.1%
HellaSwag 62.5% 62.3% -0.2%
GSM8K 12.3% 12.1% -0.2%
HumanEval 18.5% 18.3% -0.2%

结论:无统计显著差异


技术对比

特性 NVIDIA TE COAT MOSS
激活量化 Per-tensor Per-group 两级微缩放
权重量化 Per-tensor Per-tensor 自动缩放
GEMM 反量化位置 Epilogue Main Loop Epilogue
Runtime 缩放计算 需要 需要 不需要(权重)
加速(vs BF16) ~15% ~20% 34%
精度损失

总结

MOSS 通过两级微缩放和自动缩放,同时优化了 FP8 训练的精度和效率:

核心贡献

  1. 两级微缩放实现接近 per-group 精度
  2. Power-of-two 缩放实现零开销反量化
  3. 权重自动缩放消除 runtime max 计算

实际价值

  • 34% 训练加速
  • 无损模型质量
  • 适用于大规模预训练

评分: 4.0/5.0 ⭐⭐⭐⭐

推荐度: 推荐。FP8 训练的高效解决方案。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero