Chain of Preference Optimization: 用偏好学习蒸馏Tree-of-Thought推理能力

论文信息

  • 标题: Chain of Preference Optimization: Improving Chain-of-Thought Reasoning in LLMs
  • 作者: Xuan Zhang, Chao Du, Tianyu Pang, Qian Liu, Wei Gao, Min Lin
  • 机构: Sea AI Lab (SAIL), Nanyang Technological University
  • 发表: NeurIPS 2024
  • 链接: arXiv | GitHub | PDF

核心贡献

CPO通过偏好优化将Tree-of-Thought的搜索能力蒸馏到Chain-of-Thought推理中,实现了在推理时无需树搜索开销的情况下,达到甚至超越ToT的性能。核心创新在于利用树搜索过程中的隐含偏好信息,训练模型对齐优质推理路径。

研究动机

CoT的局限性

Chain-of-Thought prompting要求模型生成逐步推理过程,这种方法的效果高度依赖于:

  1. 单一路径依赖:模型一旦在某步推理出错,后续步骤会沿着错误方向继续
  2. 贪心解码的次优性:标准CoT使用贪心或采样解码,容易陷入局部最优
  3. 缺乏自我纠错能力:模型难以在推理过程中回溯和修正错误

ToT的计算困境

Tree-of-Thought通过搜索解决了CoT的问题,但计算成本难以承受:

  • 推理时开销:每个问题需要生成和评估数十到上百条候选路径
  • 延迟问题:在线服务场景下,树搜索导致响应时间过长
  • 资源消耗:API调用成本成倍增加(GPT-4每次调用0.03-0.06美元)

CPO的核心洞察

作者提出:能否在训练时使用ToT的搜索结果,让模型学会在推理时直接生成高质量路径?

这个想法的合理性在于:

  1. ToT搜索本质上是在探索推理空间,标注了哪些路径更优
  2. 这种偏好信息可以用于监督学习
  3. 如果模型学会了隐式搜索,就无需显式搜索

方法详解

整体框架

CPO采用两阶段流程:

阶段1:ToT数据收集

  • 对训练集中的每个问题执行ToT树搜索
  • 记录搜索过程中的所有推理路径及其评分
  • 构建偏好对:(优质路径, 劣质路径)

阶段2:DPO偏好优化

  • 使用收集的偏好数据微调基础模型
  • 优化目标:提升优质路径的生成概率,降低劣质路径的概率
  • 微调后的模型可以直接用CoT prompting,无需搜索

偏好数据构建

从ToT搜索树中提取偏好对的核心在于利用评估函数的打分

  1. 正样本(Chosen):搜索过程中评分最高的推理路径
  2. 负样本(Rejected):同一问题下评分较低的推理路径

关键细节:

  • 步级别偏好:不仅比较完整路径,还对齐每个推理步骤
  • 多样性保证:每个问题收集多对偏好数据,覆盖不同的错误模式
  • 质量过滤:仅保留评分差距明显的偏好对(避免噪声标注)

DPO训练目标

Direct Preference Optimization直接优化模型的生成策略:

1
L_DPO = -E[(log σ(β log(π_θ(y_w|x)/π_ref(y_w|x)) - β log(π_θ(y_l|x)/π_ref(y_l|x))))]

其中:

  • π_θ:当前训练的模型策略
  • π_ref:参考模型(通常是训练前的初始模型)
  • y_w:优质推理路径(winner)
  • y_l:劣质推理路径(loser)
  • β:温度系数,控制优化强度

直观理解

  • 最大化优质路径相对于参考模型的对数概率比
  • 最小化劣质路径的相对概率
  • β控制偏离参考模型的程度(防止过拟合)

实验结果

主要性能提升

论文在三类推理任务上验证了CPO的有效性:

1. 算术推理(Game of 24)

  • 基础CoT准确率:~45%
  • ToT (BFS)准确率:~74%
  • CPO准确率:~76% ✓ 超越ToT
  • 推理成本:CPO仅为ToT的1/20(无需搜索开销)

2. 常识推理(CSQA - CommonsenseQA)

  • 基础CoT:68.2%
  • ToT:71.5%
  • CPO:72.1% ✓ 超越ToT

3. 事实验证(HotpotQA)

  • 基础CoT:61.3%
  • ToT:64.8%
  • CPO:65.2% ✓ 超越ToT

关键发现

发现1:步级别对齐至关重要

消融实验对比:

  • 仅对齐最终答案:准确率提升3.2%
  • 对齐每个推理步骤:准确率提升8.7%

直观解释:CoT推理是一个序列决策过程,早期步骤的质量会影响后续所有步骤。步级别DPO确保模型在每一步都学会选择更优的推理方向。

发现2:偏好对质量>数量

固定训练样本数,对比不同min_score_gap:

  • gap=0.1(弱偏好):+4.3%
  • gap=0.3(中等偏好):+8.7% ✓
  • gap=0.5(强偏好,数据减少40%):+7.1%

效率对比

推理成本分析(以100个问题为例):

方法 API调用次数 总Token数 相对成本 延迟
CoT 100 ~50K 1x 1x
ToT (b=5, d=4) 2,500 ~1.2M 24x 20x
CPO 100 ~50K 1x 1x

关键优势:CPO在推理时与CoT完全相同的成本下,实现了ToT级别的性能。

实用价值分析

适用场景

CPO特别适合以下情况:

✓ 推荐使用

  1. 生产环境部署:需要低延迟、可控成本的推理任务
  2. 中等难度推理:2-5步推理问题(数学、逻辑、代码调试)
  3. 有标注数据:能够收集ToT搜索结果或人工偏好标注
  4. 模型可微调:有足够算力进行DPO训练(7B模型需1-2张A100)

✗ 不推荐

  1. 极端复杂推理:需要>10步推理或大量回溯的问题(此时ToT仍不可替代)
  2. 零样本场景:无法收集偏好数据时
  3. 快速原型:没有训练资源时,直接用ToT或CoT

实践建议

数据收集策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 高效收集ToT数据的技巧

# Tip 1: 使用更便宜的模型收集数据
# GPT-3.5收集偏好对 → 微调Llama(成本降低10x)
collect_tot_data(model="gpt-3.5-turbo") # 0.01$/问题
train_cpo(base_model="llama-2-7b") # 一次性成本

# Tip 2: 混合真实数据和合成数据
real_data = load_existing_dataset() # 真实标注
synthetic_data = generate_with_tot() # ToT合成
combined_data = merge_and_filter(real_data, synthetic_data)

# Tip 3: 主动学习选择最有价值的问题
uncertain_questions = select_high_uncertainty(model, question_pool)
tot_data = collect_tot_data(uncertain_questions) # 仅对难题执行ToT

与其他技术结合

1
2
3
4
5
6
7
8
9
10
11
# CPO + Self-Consistency:进一步提升鲁棒性
def cpo_with_self_consistency(model, question, n_samples=5):
# 生成多个推理路径
paths = [model.generate(question) for _ in range(n_samples)]

# 提取答案并投票
answers = [extract_answer(path) for path in paths]
final_answer = majority_vote(answers)

return final_answer
# 实验显示:CPO+SC比单独CPO再提升2-3个百分点

代码实现

官方仓库:github.com/sail-sg/CPO

快速开始

环境配置

1
2
3
# 安装依赖
pip install -r requirement.txt
# 主要依赖:transformers, trl, torch, datasets

数据收集阶段

1
2
3
4
5
6
7
8
# 在训练集上执行ToT搜索(以Game of 24为例)
python run_test.py \
--task game24 \
--model gpt-3.5-turbo \
--method bfs \
--branch_factor 5 \
--max_depth 4 \
--output_dir ./tot_data

DPO训练

1
2
3
4
5
6
7
8
9
10
11
12
# 微调模型
python dpo_training.py \
--base_model meta-llama/Llama-2-7b-hf \
--data_path ./preference_pairs.json \
--beta 0.1 \
--learning_rate 5e-6 \
--num_epochs 3 \
--output_dir ./cpo_model

# 硬件需求:
# - Llama-2-7B: 至少1张A100 (40GB) 或 4张RTX 4090
# - 训练时间:1000对偏好数据约需4-8小时

总结

Chain of Preference Optimization是prompt engineering领域的重要进展,它优雅地解决了ToT的计算困境:

核心贡献

  1. ✓ 将树搜索的质量提升与CoT的效率优势结合
  2. ✓ 首次系统性地用偏好学习蒸馏复杂推理能力
  3. ✓ 提供了完整的开源实现和复现指南

技术亮点

  • 步级别DPO对齐:不仅优化最终答案,更优化推理过程
  • 质量感知的偏好构建:通过min_score_gap过滤噪声
  • 工程友好:基于成熟的DPO框架,易于复现和扩展

实用价值

  • 在NeurIPS 2024被接收,学术认可度高
  • 已有多个follow-up工作(如用于代码生成、数学证明)
  • 特别适合需要部署的生产环境

CPO开启了「推理蒸馏」的新范式。核心思想:任何需要多次交互/搜索的能力,都可能通过偏好学习蒸馏到单步生成中。

参考文献

1
2
3
4
5
6
@inproceedings{zhang2024cpo,
title={Chain of Preference Optimization: Improving Chain-of-Thought Reasoning in LLMs},
author={Zhang, Xuan and Du, Chao and Pang, Tianyu and Liu, Qian and Gao, Wei and Lin, Min},
booktitle={NeurIPS},
year={2024}
}
© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero