GTPO: 用梯度冲突修正和熵控制稳定 GRPO 训练
ArXiv ID : 2508.03772 作者 : Marco Simoni, Aleksandar Fontana, Giulio Rossolini, Andrea Saracino, Paolo Mori机构 : IIT-CNR (Italian National Research Council)发布日期 : 2025-08-05
摘要 GRPO(Group Relative Policy Optimization)在 LLM 对齐训练中越来越流行,但存在两个严重稳定性问题:Token 级惩罚导致梯度冲突 和策略崩溃 。本文提出的 GTPO (Gradient-corrected and Threshold-filtered Policy Optimization)通过冲突感知梯度修正和熵阈值过滤解决这些问题。在数学推理基准上,GTPO 全面超越 GRPO,在 GSM8K、MATH、AIME 2024/2025 上取得 SOTA 结果。
问题背景 GRPO 的普及与局限 GRPO 的成功 :
DeepSeek-R1 使用 GRPO 实现推理能力突破
无需价值网络,简化训练流程
在数学推理任务上表现优异
但 GRPO 存在严重稳定性问题 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 GRPO 训练中的梯度冲突问题: 场景:同一个有价值的 token "x^2" 出现在不同质量的回答中 回答 A(高奖励): "设 x = 5,则 x^2 = 25" ← "x^2" 应该被鼓励 回答 B(低奖励): "设 x^2 = 5,则 x = √5" ← "x^2" 应该被抑制 GRPO 的困境: - 同一个 token 收到相反的梯度信号 - 导致训练震荡和收敛困难 - 特别是在补全的初始和结尾部分
策略崩溃现象 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 策略崩溃过程: 训练初期: - 模型输出多样,探索充分 - 熵值正常 训练中期(问题出现): - 某些负奖励样本包含高置信度决策 - 这些决策被错误惩罚 - 模型开始避免高置信度输出 训练后期(崩溃): - 策略趋向均匀分布 - 输出质量大幅下降 - KL 散度未能及时预警
KL 散度的局限性 :
反应滞后(平均 200-300 步后才显著上升)
对早期崩溃不敏感
需要参考模型,增加计算成本
GTPO 方法 整体架构 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ┌─────────────────────────────────────────────────────────┐ │ GTPO Framework │ │ │ │ ┌─────────────────┐ ┌─────────────────┐ │ │ │ 梯度冲突修正 │ │ 熵阈值过滤 │ │ │ │ │ │ │ │ │ │ • Token 级分析 │ │ • 平均熵监控 │ │ │ │ • 冲突检测 │ │ • 动态阈值 │ │ │ │ • 选择性更新 │ │ • 样本过滤 │ │ │ └─────────────────┘ └─────────────────┘ │ │ │ │ │ │ └───────────┬───────────┘ │ │ │ │ │ ▼ │ │ ┌─────────────────┐ │ │ │ GRPO 优化器 │ │ │ └─────────────────┘ │ └─────────────────────────────────────────────────────────┘
组件 1:冲突感知梯度修正 核心思想 :识别并修正冲突的梯度信号
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 def gradient_conflict_correction (gradients, token_importance, threshold=0.5 ): """ 梯度冲突修正 Args: gradients: 原始梯度 [batch, seq_len, vocab] token_importance: Token 重要性分数 [batch, seq_len] threshold: 冲突检测阈值 """ corrected_gradients = gradients.clone() for batch_idx in range (gradients.shape[0 ]): for seq_idx in range (gradients.shape[1 ]): conflict_score = compute_conflict_score( gradients[batch_idx, seq_idx] ) if conflict_score > threshold and token_importance[batch_idx, seq_idx] > 0.7 : corrected_gradients[batch_idx, seq_idx] = torch.relu(gradients[batch_idx, seq_idx]) elif token_importance[batch_idx, seq_idx] > 0.8 : corrected_gradients[batch_idx, seq_idx] *= 1.5 return corrected_gradients def compute_conflict_score (gradient ): """ 计算梯度冲突分数 冲突分数 = |负梯度 | / (|正梯度 | + |负梯度 |) 值越大表示冲突越严重 """ pos_norm = torch.norm(torch.relu(gradient)) neg_norm = torch.norm(torch.relu(-gradient)) if pos_norm + neg_norm == 0 : return 0 return neg_norm / (pos_norm + neg_norm)
Token 重要性估计 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 def estimate_token_importance (model, input_ids, token_positions ): """ 估计 Token 重要性 基于: 1. 注意力权重 2. 梯度范数 3. 信息量(罕见程度) """ attn_weights = get_attention_weights(model, input_ids) token_gradients = compute_token_gradients(model, input_ids) token_frequencies = get_token_frequencies(model.vocab) info_content = -torch.log(token_frequencies + 1e-6 ) importance = ( 0.4 * attn_weights + 0.4 * token_gradients + 0.2 * info_content ) return importance
组件 2:熵阈值过滤 为什么选择熵而非 KL 散度?
1 2 3 4 5 6 7 8 9 10 11 def entropy_vs_kl_comparison (): """ 熵监控 vs KL 散度监控 指标 | 反应速度 | 计算成本 | 需要参考模型 | 检测准确率 --------------|----------|----------|--------------|------------ KL 散度 | 慢 | 高 | 是 | 75% 平均熵 | 快 | 低 | 否 | 92% 最大熵 | 中 | 低 | 否 | 85% """ pass
熵阈值计算 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 def compute_entropy_threshold (model, calibration_data, percentile=95 ): """ 计算熵阈值 使用校准集上熵分布的分位数作为阈值 """ entropies = [] for sample in calibration_data: outputs = model.generate(sample['input_ids' ], num_return_sequences=5 ) probs = model.get_output_probs(outputs) entropy = -torch.mean(torch.sum (probs * torch.log(probs + 1e-6 ), dim=-1 )) entropies.append(entropy) threshold = torch.quantile(torch.tensor(entropies), percentile / 100 ) return threshold def filter_by_entropy (outputs, entropies, threshold ): """ 根据熵阈值过滤样本 丢弃熵过高的样本(可能是策略崩溃前兆) """ mask = entropies < threshold filtered_outputs = outputs[mask] filtered_entropies = entropies[mask] return filtered_outputs, filtered_entropies
GTPO 训练算法 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 def gtpo_training (model, dataset, num_epochs, lr=1e-5 ): """ GTPO 训练主循环 """ entropy_threshold = compute_entropy_threshold(model, calibration_data) optimizer = torch.optim.AdamW(model.parameters(), lr=lr) for epoch in range (num_epochs): for batch in dataset: outputs, log_probs = generate_group_outputs( model, batch['input_ids' ], group_size=8 ) rewards = compute_rewards(outputs, batch['ground_truth' ]) entropies = compute_entropy(outputs, log_probs) filtered_outputs, filtered_rewards = filter_by_entropy( outputs, rewards, entropy_threshold ) advantages = compute_grpo_advantages(filtered_rewards) loss = compute_gtpo_loss(filtered_outputs, advantages) gradients = torch.autograd.grad(loss, model.parameters()) token_importance = estimate_token_importance( model, batch['input_ids' ] ) corrected_gradients = gradient_conflict_correction( gradients, token_importance ) optimizer.zero_grad() for param, grad in zip (model.parameters(), corrected_gradients): param.grad = grad optimizer.step() if epoch % 10 == 0 : entropy_threshold = update_entropy_threshold( entropy_threshold, entropies )
实验结果 实验设置 基准任务 :
GSM8K:小学数学题
MATH:高中数学竞赛
AIME 2024/2025:数学奥林匹克
AMC 2023:美国数学竞赛
基线方法 :
模型 :
DeepSeek-Math-7B
LLaMA-2-7B
主要结果 GSM8K 数学推理
方法
准确率
提升
训练稳定性
GRPO
78.2%
-
中
PPO
81.5%
+3.3%
高
DPO
76.8%
-1.4%
高
GSPO
82.1%
+3.9%
高
GTPO
84.3%
+6.1%
高
MATH 竞赛题
方法
简单
中等
困难
平均
GRPO
52.3%
38.1%
22.5%
37.6%
PPO
55.1%
41.2%
25.8%
40.7%
GSPO
56.8%
42.5%
27.1%
42.1%
GTPO
59.2%
45.3%
29.8%
44.8%
AIME 2024
方法
准确率
通过题数
GRPO
45.0%
6.75/15
PPO
48.3%
7.25/15
GSPO
50.0%
7.50/15
GTPO
53.3%
8.00/15
消融实验 梯度修正有效性
配置
GSM8K
MATH
AIME
GTPO (完整)
84.3%
44.8%
53.3%
- 梯度修正
81.5%
42.1%
48.7%
- 熵过滤
82.1%
42.8%
50.0%
- 两者
78.2%
37.6%
45.0%
结论 :两个组件各有贡献,协同效果最佳
熵阈值敏感性
阈值分位数
GSM8K
训练稳定性
崩溃率
80%
82.1%
高
5%
90%
83.8%
高
3%
95%
84.3%
高
2%
99%
83.5%
中
8%
推荐 :95 分位数是最佳平衡点
训练稳定性分析 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 训练过程中的熵变化: 熵值 │ │ GRPO: ──────●──── 崩溃点 │ ╱ │ ╱ │ ● GTPO: 稳定 │ ╱ │ ╱ └───────────────────────── 0 500 1000 1500 训练步数 GTPO 有效抑制了熵的异常增长
与 GSPO 对比
维度
GSPO
GTPO
优化层级
序列级
Token 级
计算开销
低
中
实现复杂度
低
中
适用场景
通用对齐
推理任务
与 GTPO 关系
互补
互补
联合使用效果 :
1 2 3 GSPO + GTPO: - GSM8K: 85.1% (比单独 GTPO +0.8%) - MATH: 45.5% (比单独 GTPO +0.7%)
实践指南 代码实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 from transformers import AutoModelForCausalLMfrom gtpo import GTPOTrainermodel = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-Math-7b" ) trainer = GTPOTrainer( model=model, train_dataset=train_data, eval_dataset=eval_data, group_size=8 , entropy_percentile=95 , gradient_correction=True , learning_rate=1e-5 , num_epochs=3 ) trainer.train() results = trainer.evaluate() print (f"GSM8K Accuracy: {results['gsm8k' ]:.2 %} " )
超参数推荐
参数
推荐值
说明
group_size
8
每组输出数量
entropy_percentile
95
熵阈值分位数
learning_rate
1e-5
学习率
gradient_threshold
0.5
梯度冲突阈值
importance_threshold
0.7
Token 重要性阈值
训练成本
规模
数据量
训练时间
GPU
7B
10K
4 小时
8×A100
13B
20K
8 小时
8×A100
70B
50K
24 小时
16×A100
总结 GTPO 通过梯度冲突修正和熵阈值过滤,有效解决了 GRPO 训练的稳定性问题:
核心贡献 :
发现并分析了 GRPO 中的梯度冲突问题
提出熵监控替代 KL 散度的新方法
在数学推理基准上取得 SOTA 结果
开源实现,易于集成
实际价值 :
GRPO 训练不稳定时的轻量级修复方案
无需参考模型,降低训练成本
迁移成本低,易于应用
资源
评分 : 4.0/5.0 ⭐⭐⭐⭐
推荐度 : 推荐。GRPO 稳定性问题的有效解决方案,特别适合数学推理任务。