论文信息
- 标题: Chain of Preference Optimization: Improving Chain-of-Thought Reasoning in LLMs
- 作者: Xuan Zhang, Chao Du, Tianyu Pang, Qian Liu, Wei Gao, Min Lin
- 机构: Sea AI Lab (SAIL), Nanyang Technological University
- 发表: NeurIPS 2024
- 链接: arXiv | GitHub | PDF
核心贡献
CPO通过偏好优化将Tree-of-Thought的搜索能力蒸馏到Chain-of-Thought推理中,实现了在推理时无需树搜索开销的情况下,达到甚至超越ToT的性能。核心创新在于利用树搜索过程中的隐含偏好信息,训练模型对齐优质推理路径。
研究动机
CoT的局限性
Chain-of-Thought prompting要求模型生成逐步推理过程,这种方法的效果高度依赖于:
- 单一路径依赖:模型一旦在某步推理出错,后续步骤会沿着错误方向继续
- 贪心解码的次优性:标准CoT使用贪心或采样解码,容易陷入局部最优
- 缺乏自我纠错能力:模型难以在推理过程中回溯和修正错误
ToT的计算困境
Tree-of-Thought通过搜索解决了CoT的问题,但计算成本难以承受:
- 推理时开销:每个问题需要生成和评估数十到上百条候选路径
- 延迟问题:在线服务场景下,树搜索导致响应时间过长
- 资源消耗:API调用成本成倍增加(GPT-4每次调用0.03-0.06美元)
CPO的核心洞察
作者提出:能否在训练时使用ToT的搜索结果,让模型学会在推理时直接生成高质量路径?
这个想法的合理性在于:
- ToT搜索本质上是在探索推理空间,标注了哪些路径更优
- 这种偏好信息可以用于监督学习
- 如果模型学会了隐式搜索,就无需显式搜索
方法详解
整体框架
CPO采用两阶段流程:
阶段1:ToT数据收集
- 对训练集中的每个问题执行ToT树搜索
- 记录搜索过程中的所有推理路径及其评分
- 构建偏好对:(优质路径, 劣质路径)
阶段2:DPO偏好优化
- 使用收集的偏好数据微调基础模型
- 优化目标:提升优质路径的生成概率,降低劣质路径的概率
- 微调后的模型可以直接用CoT prompting,无需搜索
偏好数据构建
从ToT搜索树中提取偏好对的核心在于利用评估函数的打分:
- 正样本(Chosen):搜索过程中评分最高的推理路径
- 负样本(Rejected):同一问题下评分较低的推理路径
关键细节:
- 步级别偏好:不仅比较完整路径,还对齐每个推理步骤
- 多样性保证:每个问题收集多对偏好数据,覆盖不同的错误模式
- 质量过滤:仅保留评分差距明显的偏好对(避免噪声标注)
DPO训练目标
Direct Preference Optimization直接优化模型的生成策略:
1 | L_DPO = -E[(log σ(β log(π_θ(y_w|x)/π_ref(y_w|x)) - β log(π_θ(y_l|x)/π_ref(y_l|x))))] |
其中:
π_θ:当前训练的模型策略π_ref:参考模型(通常是训练前的初始模型)y_w:优质推理路径(winner)y_l:劣质推理路径(loser)β:温度系数,控制优化强度
直观理解:
- 最大化优质路径相对于参考模型的对数概率比
- 最小化劣质路径的相对概率
- β控制偏离参考模型的程度(防止过拟合)
实验结果
主要性能提升
论文在三类推理任务上验证了CPO的有效性:
1. 算术推理(Game of 24)
- 基础CoT准确率:~45%
- ToT (BFS)准确率:~74%
- CPO准确率:~76% ✓ 超越ToT
- 推理成本:CPO仅为ToT的1/20(无需搜索开销)
2. 常识推理(CSQA - CommonsenseQA)
- 基础CoT:68.2%
- ToT:71.5%
- CPO:72.1% ✓ 超越ToT
3. 事实验证(HotpotQA)
- 基础CoT:61.3%
- ToT:64.8%
- CPO:65.2% ✓ 超越ToT
关键发现
发现1:步级别对齐至关重要
消融实验对比:
- 仅对齐最终答案:准确率提升3.2%
- 对齐每个推理步骤:准确率提升8.7% ✓
直观解释:CoT推理是一个序列决策过程,早期步骤的质量会影响后续所有步骤。步级别DPO确保模型在每一步都学会选择更优的推理方向。
发现2:偏好对质量>数量
固定训练样本数,对比不同min_score_gap:
- gap=0.1(弱偏好):+4.3%
- gap=0.3(中等偏好):+8.7% ✓
- gap=0.5(强偏好,数据减少40%):+7.1%
效率对比
推理成本分析(以100个问题为例):
| 方法 | API调用次数 | 总Token数 | 相对成本 | 延迟 |
|---|---|---|---|---|
| CoT | 100 | ~50K | 1x | 1x |
| ToT (b=5, d=4) | 2,500 | ~1.2M | 24x | 20x |
| CPO | 100 | ~50K | 1x ✓ | 1x ✓ |
关键优势:CPO在推理时与CoT完全相同的成本下,实现了ToT级别的性能。
实用价值分析
适用场景
CPO特别适合以下情况:
✓ 推荐使用:
- 生产环境部署:需要低延迟、可控成本的推理任务
- 中等难度推理:2-5步推理问题(数学、逻辑、代码调试)
- 有标注数据:能够收集ToT搜索结果或人工偏好标注
- 模型可微调:有足够算力进行DPO训练(7B模型需1-2张A100)
✗ 不推荐:
- 极端复杂推理:需要>10步推理或大量回溯的问题(此时ToT仍不可替代)
- 零样本场景:无法收集偏好数据时
- 快速原型:没有训练资源时,直接用ToT或CoT
实践建议
数据收集策略:
1 | # 高效收集ToT数据的技巧 |
与其他技术结合
1 | # CPO + Self-Consistency:进一步提升鲁棒性 |
代码实现
快速开始
环境配置:
1 | # 安装依赖 |
数据收集阶段:
1 | # 在训练集上执行ToT搜索(以Game of 24为例) |
DPO训练:
1 | # 微调模型 |
总结
Chain of Preference Optimization是prompt engineering领域的重要进展,它优雅地解决了ToT的计算困境:
核心贡献:
- ✓ 将树搜索的质量提升与CoT的效率优势结合
- ✓ 首次系统性地用偏好学习蒸馏复杂推理能力
- ✓ 提供了完整的开源实现和复现指南
技术亮点:
- 步级别DPO对齐:不仅优化最终答案,更优化推理过程
- 质量感知的偏好构建:通过min_score_gap过滤噪声
- 工程友好:基于成熟的DPO框架,易于复现和扩展
实用价值:
- 在NeurIPS 2024被接收,学术认可度高
- 已有多个follow-up工作(如用于代码生成、数学证明)
- 特别适合需要部署的生产环境
CPO开启了「推理蒸馏」的新范式。核心思想:任何需要多次交互/搜索的能力,都可能通过偏好学习蒸馏到单步生成中。
参考文献
1 | @inproceedings{zhang2024cpo, |