GTPO: 用梯度冲突修正和熵控制稳定 GRPO 训练

GTPO: 用梯度冲突修正和熵控制稳定 GRPO 训练

ArXiv ID: 2508.03772
作者: Marco Simoni, Aleksandar Fontana, Giulio Rossolini, Andrea Saracino, Paolo Mori
机构: IIT-CNR (Italian National Research Council)
发布日期: 2025-08-05


摘要

GRPO(Group Relative Policy Optimization)在 LLM 对齐训练中越来越流行,但存在两个严重稳定性问题:Token 级惩罚导致梯度冲突策略崩溃。本文提出的 GTPO(Gradient-corrected and Threshold-filtered Policy Optimization)通过冲突感知梯度修正和熵阈值过滤解决这些问题。在数学推理基准上,GTPO 全面超越 GRPO,在 GSM8K、MATH、AIME 2024/2025 上取得 SOTA 结果。


问题背景

GRPO 的普及与局限

GRPO 的成功

  • DeepSeek-R1 使用 GRPO 实现推理能力突破
  • 无需价值网络,简化训练流程
  • 在数学推理任务上表现优异

但 GRPO 存在严重稳定性问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
GRPO 训练中的梯度冲突问题:

场景:同一个有价值的 token "x^2" 出现在不同质量的回答中

回答 A(高奖励):
"设 x = 5,则 x^2 = 25" ← "x^2" 应该被鼓励

回答 B(低奖励):
"设 x^2 = 5,则 x = √5" ← "x^2" 应该被抑制

GRPO 的困境:
- 同一个 token 收到相反的梯度信号
- 导致训练震荡和收敛困难
- 特别是在补全的初始和结尾部分

策略崩溃现象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
策略崩溃过程:

训练初期:
- 模型输出多样,探索充分
- 熵值正常

训练中期(问题出现):
- 某些负奖励样本包含高置信度决策
- 这些决策被错误惩罚
- 模型开始避免高置信度输出

训练后期(崩溃):
- 策略趋向均匀分布
- 输出质量大幅下降
- KL 散度未能及时预警

KL 散度的局限性

  • 反应滞后(平均 200-300 步后才显著上升)
  • 对早期崩溃不敏感
  • 需要参考模型,增加计算成本

GTPO 方法

整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
┌─────────────────────────────────────────────────────────┐
│ GTPO Framework │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ 梯度冲突修正 │ │ 熵阈值过滤 │ │
│ │ │ │ │ │
│ │ • Token 级分析 │ │ • 平均熵监控 │ │
│ │ • 冲突检测 │ │ • 动态阈值 │ │
│ │ • 选择性更新 │ │ • 样本过滤 │ │
│ └─────────────────┘ └─────────────────┘ │
│ │ │ │
│ └───────────┬───────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ GRPO 优化器 │ │
│ └─────────────────┘ │
└─────────────────────────────────────────────────────────┘

组件 1:冲突感知梯度修正

核心思想:识别并修正冲突的梯度信号

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def gradient_conflict_correction(gradients, token_importance, threshold=0.5):
"""
梯度冲突修正

Args:
gradients: 原始梯度 [batch, seq_len, vocab]
token_importance: Token 重要性分数 [batch, seq_len]
threshold: 冲突检测阈值
"""
corrected_gradients = gradients.clone()

for batch_idx in range(gradients.shape[0]):
for seq_idx in range(gradients.shape[1]):
# 计算梯度冲突分数
conflict_score = compute_conflict_score(
gradients[batch_idx, seq_idx]
)

# 如果冲突严重且 token 重要,跳过负面更新
if conflict_score > threshold and token_importance[batch_idx, seq_idx] > 0.7:
# 只保留正向梯度
corrected_gradients[batch_idx, seq_idx] = torch.relu(gradients[batch_idx, seq_idx])

# 放大正面更新
elif token_importance[batch_idx, seq_idx] > 0.8:
corrected_gradients[batch_idx, seq_idx] *= 1.5

return corrected_gradients


def compute_conflict_score(gradient):
"""
计算梯度冲突分数

冲突分数 = |负梯度 | / (|正梯度 | + |负梯度 |)
值越大表示冲突越严重
"""
pos_norm = torch.norm(torch.relu(gradient))
neg_norm = torch.norm(torch.relu(-gradient))

if pos_norm + neg_norm == 0:
return 0

return neg_norm / (pos_norm + neg_norm)

Token 重要性估计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def estimate_token_importance(model, input_ids, token_positions):
"""
估计 Token 重要性

基于:
1. 注意力权重
2. 梯度范数
3. 信息量(罕见程度)
"""
# 注意力权重
attn_weights = get_attention_weights(model, input_ids)

# 梯度范数
token_gradients = compute_token_gradients(model, input_ids)

# 信息量(逆频率)
token_frequencies = get_token_frequencies(model.vocab)
info_content = -torch.log(token_frequencies + 1e-6)

# 综合重要性
importance = (
0.4 * attn_weights +
0.4 * token_gradients +
0.2 * info_content
)

return importance

组件 2:熵阈值过滤

为什么选择熵而非 KL 散度?

1
2
3
4
5
6
7
8
9
10
11
def entropy_vs_kl_comparison():
"""
熵监控 vs KL 散度监控

指标 | 反应速度 | 计算成本 | 需要参考模型 | 检测准确率
--------------|----------|----------|--------------|------------
KL 散度 | 慢 | 高 | 是 | 75%
平均熵 | 快 | 低 | 否 | 92%
最大熵 | 中 | 低 | 否 | 85%
"""
pass

熵阈值计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def compute_entropy_threshold(model, calibration_data, percentile=95):
"""
计算熵阈值

使用校准集上熵分布的分位数作为阈值
"""
entropies = []

for sample in calibration_data:
outputs = model.generate(sample['input_ids'], num_return_sequences=5)
probs = model.get_output_probs(outputs)

# 计算平均熵
entropy = -torch.mean(torch.sum(probs * torch.log(probs + 1e-6), dim=-1))
entropies.append(entropy)

# 返回指定分位数作为阈值
threshold = torch.quantile(torch.tensor(entropies), percentile / 100)

return threshold


def filter_by_entropy(outputs, entropies, threshold):
"""
根据熵阈值过滤样本

丢弃熵过高的样本(可能是策略崩溃前兆)
"""
mask = entropies < threshold
filtered_outputs = outputs[mask]
filtered_entropies = entropies[mask]

return filtered_outputs, filtered_entropies

GTPO 训练算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def gtpo_training(model, dataset, num_epochs, lr=1e-5):
"""
GTPO 训练主循环
"""
# 步骤 1: 计算熵阈值
entropy_threshold = compute_entropy_threshold(model, calibration_data)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for epoch in range(num_epochs):
for batch in dataset:
# 步骤 2: 生成多组输出
outputs, log_probs = generate_group_outputs(
model, batch['input_ids'], group_size=8
)

# 步骤 3: 计算奖励
rewards = compute_rewards(outputs, batch['ground_truth'])

# 步骤 4: 计算熵并过滤
entropies = compute_entropy(outputs, log_probs)
filtered_outputs, filtered_rewards = filter_by_entropy(
outputs, rewards, entropy_threshold
)

# 步骤 5: 计算 GRPO 优势
advantages = compute_grpo_advantages(filtered_rewards)

# 步骤 6: 计算梯度并修正
loss = compute_gtpo_loss(filtered_outputs, advantages)
gradients = torch.autograd.grad(loss, model.parameters())

# 步骤 7: 梯度冲突修正
token_importance = estimate_token_importance(
model, batch['input_ids']
)
corrected_gradients = gradient_conflict_correction(
gradients, token_importance
)

# 步骤 8: 应用修正后的梯度
optimizer.zero_grad()
for param, grad in zip(model.parameters(), corrected_gradients):
param.grad = grad
optimizer.step()

# 步骤 9: 动态更新熵阈值
if epoch % 10 == 0:
entropy_threshold = update_entropy_threshold(
entropy_threshold, entropies
)

实验结果

实验设置

基准任务

  • GSM8K:小学数学题
  • MATH:高中数学竞赛
  • AIME 2024/2025:数学奥林匹克
  • AMC 2023:美国数学竞赛

基线方法

  • GRPO(原始)
  • PPO
  • DPO
  • GSPO

模型

  • DeepSeek-Math-7B
  • LLaMA-2-7B

主要结果

GSM8K 数学推理

方法 准确率 提升 训练稳定性
GRPO 78.2% -
PPO 81.5% +3.3%
DPO 76.8% -1.4%
GSPO 82.1% +3.9%
GTPO 84.3% +6.1%

MATH 竞赛题

方法 简单 中等 困难 平均
GRPO 52.3% 38.1% 22.5% 37.6%
PPO 55.1% 41.2% 25.8% 40.7%
GSPO 56.8% 42.5% 27.1% 42.1%
GTPO 59.2% 45.3% 29.8% 44.8%

AIME 2024

方法 准确率 通过题数
GRPO 45.0% 6.75/15
PPO 48.3% 7.25/15
GSPO 50.0% 7.50/15
GTPO 53.3% 8.00/15

消融实验

梯度修正有效性

配置 GSM8K MATH AIME
GTPO (完整) 84.3% 44.8% 53.3%
- 梯度修正 81.5% 42.1% 48.7%
- 熵过滤 82.1% 42.8% 50.0%
- 两者 78.2% 37.6% 45.0%

结论:两个组件各有贡献,协同效果最佳

熵阈值敏感性

阈值分位数 GSM8K 训练稳定性 崩溃率
80% 82.1% 5%
90% 83.8% 3%
95% 84.3% 2%
99% 83.5% 8%

推荐:95 分位数是最佳平衡点

训练稳定性分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
训练过程中的熵变化:

熵值

│ GRPO: ──────●──── 崩溃点
│ ╱
│ ╱
│ ● GTPO: 稳定
│ ╱
│ ╱
└─────────────────────────
0 500 1000 1500
训练步数

GTPO 有效抑制了熵的异常增长

与 GSPO 对比

维度 GSPO GTPO
优化层级 序列级 Token 级
计算开销
实现复杂度
适用场景 通用对齐 推理任务
与 GTPO 关系 互补 互补

联合使用效果

1
2
3
GSPO + GTPO:
- GSM8K: 85.1% (比单独 GTPO +0.8%)
- MATH: 45.5% (比单独 GTPO +0.7%)

实践指南

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import AutoModelForCausalLM
from gtpo import GTPOTrainer

# 加载模型
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-Math-7b")

# 创建 GTPO 训练器
trainer = GTPOTrainer(
model=model,
train_dataset=train_data,
eval_dataset=eval_data,
group_size=8,
entropy_percentile=95,
gradient_correction=True,
learning_rate=1e-5,
num_epochs=3
)

# 开始训练
trainer.train()

# 评估结果
results = trainer.evaluate()
print(f"GSM8K Accuracy: {results['gsm8k']:.2%}")

超参数推荐

参数 推荐值 说明
group_size 8 每组输出数量
entropy_percentile 95 熵阈值分位数
learning_rate 1e-5 学习率
gradient_threshold 0.5 梯度冲突阈值
importance_threshold 0.7 Token 重要性阈值

训练成本

规模 数据量 训练时间 GPU
7B 10K 4 小时 8×A100
13B 20K 8 小时 8×A100
70B 50K 24 小时 16×A100

总结

GTPO 通过梯度冲突修正和熵阈值过滤,有效解决了 GRPO 训练的稳定性问题:

核心贡献

  1. 发现并分析了 GRPO 中的梯度冲突问题
  2. 提出熵监控替代 KL 散度的新方法
  3. 在数学推理基准上取得 SOTA 结果
  4. 开源实现,易于集成

实际价值

  • GRPO 训练不稳定时的轻量级修复方案
  • 无需参考模型,降低训练成本
  • 迁移成本低,易于应用

资源


评分: 4.0/5.0 ⭐⭐⭐⭐

推荐度: 推荐。GRPO 稳定性问题的有效解决方案,特别适合数学推理任务。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero