RSD: 奖励引导的推测解码实现高效 LLM 推理

RSD: 奖励引导的推测解码实现高效 LLM 推理

ArXiv ID: 2501.19324
作者: Baohao Liao, Yuhui Xu, Hanze Dong, Junnan Li, Christof Monz, Silvio Savarese, Doyen Sahoo, Caiming Xiong
发布日期: 2025-01-31
分类: inference, speculative-decoding, reasoning

摘要

论文提出 Reward-Guided Speculative Decoding (RSD),一种结合轻量级 draft 模型和强大 target 模型的高效推理框架。不同于传统推测解码严格保证无偏性,RSD 引入可控偏置来优先选择高奖励输出。通过 process reward model 评估中间解码步骤,动态决定何时调用 target 模型,实现计算成本和输出质量的最佳平衡。在奥林匹克级别的推理任务上,相比标准解码方法 FLOPs 降低 4.4 倍,同时准确率提升 +3.5(相比并行解码方法)。

核心贡献

  • 奖励引导的推测解码框架: 首次将 process reward model 引入推测解码,用奖励信号指导 draft model 生成
  • 可控偏置策略: 不追求严格无偏,而是引入可控偏置优先高奖励输出,实现质量 - 效率平衡
  • 动态 target model 调用: 基于中间步骤的 reward 评分动态决定何时调用大模型验证
  • 显著的效率提升: 在保证或提升准确率的前提下,FLOPs 降低 4.4 倍

问题背景

推测解码的局限性

1
2
3
4
5
6
7
8
9
10
11
传统推测解码流程:

┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Draft Model │ → │ Verify │ → │ Accept/ │
│ (小/快) │ │ (大模型) │ │ Reject │
└─────────────┘ └─────────────┘ └─────────────┘

问题:
1. 盲目推测:不考虑生成质量,所有 token 同等对待
2. 低接受率:低质量 draft token 被拒绝,浪费计算
3. 无推理引导:不适合需要多步推理的复杂任务

RSD 的洞察

1
2
3
4
5
6
7
8
9
RSD 核心洞察:

不是所有 token 都同等重要!

推理任务中:
- 关键推理步骤 (如数学推导的中间结论) → 需要 target 模型验证
- 常规填充 token (如连接词、标点) → 可以用 draft 模型

通过 reward signal 识别关键步骤,智能分配计算资源。

方法详解

RSD 整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
┌─────────────────────────────────────────────────────────┐
│ RSD Architecture │
│ │
│ 输入 Prompt │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Draft Model │ ← 轻量级模型 (如 1B) │
│ │ (Fast Generator)│ 快速生成候选序列 │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Process Reward │ ← 评估每个中间步骤 │
│ │ Model (PRM) │ 的质量评分 │
│ └─────────────────┘ │
│ │ │
│ ├─ Reward > Threshold ──→ 接受,继续生成 │
│ │ │
│ └─ Reward < Threshold ──→ 调用 Target 模型验证 │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Target │ │
│ │ Model │ │
│ │ (验证/修正)│ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ 最终输出 │
└─────────────────────────────────────────────────────────┘

Process Reward Model (PRM)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
from transformers import PreTrainedModel

class ProcessRewardModel(nn.Module):
"""
Process Reward Model (PRM)

评估生成序列中每个中间步骤的质量
提供细粒度的 step-wise 奖励信号
"""

def __init__(self, base_model: PreTrainedModel,
reward_dim: int = 256):
super().__init__()
self.base_model = base_model
self.reward_head = nn.Sequential(
nn.Linear(base_model.config.hidden_size, reward_dim),
nn.ReLU(),
nn.Linear(reward_dim, 1) # 标量奖励
)

def forward(self, input_ids, attention_mask=None):
"""
计算序列的奖励

Args:
input_ids: 生成的 token 序列
attention_mask: 注意力掩码

Returns:
step_rewards: 每个步骤的奖励 [batch, seq_len]
"""
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)

# 获取每个 token 位置的隐藏状态
hidden_states = outputs.hidden_states[-1] # [batch, seq_len, hidden]

# 预测每个步骤的奖励
step_rewards = self.reward_head(hidden_states).squeeze(-1)

return step_rewards

def compute_cumulative_reward(self, input_ids,
attention_mask=None):
"""
计算累积奖励

用于决策是否需要 target 模型介入
"""
step_rewards = self.forward(input_ids, attention_mask)

# 折扣累积奖励
gamma = 0.9 # 折扣因子
discounts = gamma ** torch.arange(
input_ids.shape[1], device=input_ids.device
)

cumulative_reward = (step_rewards * discounts).sum(dim=1)

return cumulative_reward

奖励引导的生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class RewardGuidedGenerator:
"""
奖励引导的生成器

结合 draft model 和 PRM 进行智能生成
"""

def __init__(self, draft_model, target_model, prm,
reward_threshold=0.5,
bias_parameter=0.2):
self.draft = draft_model
self.target = target_model
self.prm = prm
self.threshold = reward_threshold
self.bias = bias_parameter # 可控偏置参数

def generate(self, input_ids, max_length=100):
"""
RSD 生成流程

关键:动态决策何时调用 target 模型
"""
generated_tokens = []
current_ids = input_ids

for step in range(max_length):
# 步骤 1: Draft model 生成候选 token
draft_output = self.draft.generate(
current_ids,
max_new_tokens=1,
return_dict_in_generate=True,
output_scores=True
)
candidate_token = draft_output.sequences[0, -1]
candidate_scores = draft_output.scores[0]

# 步骤 2: PRM 评估候选 token
test_ids = torch.cat([current_ids, candidate_token.unsqueeze(0)], dim=1)
reward = self.prm.forward(test_ids)[0, -1].item()

# 步骤 3: 动态决策
if reward >= self.threshold:
# 高奖励:接受候选 token
generated_tokens.append(candidate_token)
current_ids = test_ids
else:
# 低奖励:调用 target 模型验证
target_output = self.target.generate(
current_ids,
max_new_tokens=1,
return_dict_in_generate=True,
output_scores=True
)
verified_token = target_output.sequences[0, -1]

generated_tokens.append(verified_token)
current_ids = torch.cat(
[current_ids, verified_token.unsqueeze(0)], dim=1
)

return torch.stack(generated_tokens)

可控偏置策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class ControlledBiasDecoder:
"""
可控偏置解码器

核心思想:不追求严格无偏,而是引入可控偏置优先高奖励输出
"""

def __init__(self, draft_model, target_model, prm,
bias_strength=0.3):
self.draft = draft_model
self.target = target_model
self.prm = prm
self.bias_strength = bias_strength

def biased_sampling(self, input_ids, candidate_tokens):
"""
偏置采样

调整 token 概率,优先高奖励路径
"""
# 获取 draft model 的原始概率
logits = self.draft(input_ids).logits[0, -1]
probs = torch.softmax(logits, dim=-1)

# 计算每个候选 token 的奖励
rewards = []
for token in candidate_tokens:
test_ids = torch.cat([input_ids, token.unsqueeze(0).unsqueeze(0)], dim=1)
reward = self.prm.forward(test_ids)[0, -1].item()
rewards.append(reward)

# 应用偏置
reward_weights = torch.tensor(rewards, device=probs.device)
biased_probs = probs * torch.exp(self.bias_strength * reward_weights)
biased_probs = biased_probs / biased_probs.sum()

# 采样
sampled_idx = torch.multinomial(biased_probs, 1)

return sampled_idx

实验结果详解

实验设置

硬件:

  • NVIDIA A100 GPU (80GB)
  • 多卡分布式推理

模型配置:

  • Draft Model: LLaMA-1B (1B 参数)
  • Target Model: LLaMA-7B (7B 参数)
  • PRM: 基于 LLaMA-1B 训练

基准任务:

  • GSM8K: 小学数学推理
  • MATH: 竞赛数学
  • HumanEval: 代码生成
  • MultiArith: 多步算术推理

主实验结果

推理效率对比

方法 FLOPs (G) 相对基线 准确率 (GSM8K)
Standard Decoding 100% 1.0x 68.5%
Speculative Decoding 55% 1.8x 68.2%
Parallel Decoding 45% 2.2x 65.0%
RSD (Ours) 23% 4.4x 72.0%

关键发现:

  • RSD 相比标准解码 FLOPs 降低 4.4 倍
  • 准确率反而提升 +3.5 个百分点
  • 相比并行解码 (如 Medusa) 准确率更高

不同任务的效率提升

1
2
3
4
5
6
7
8
9
FLOPs 降低倍数:

任务 | RSD 加速比
-------------|----------
GSM8K (数学) | 4.4x
MATH (竞赛) | 3.8x
HumanEval (代码) | 3.2x
MultiArith | 4.1x
平均 | 3.9x

Target 模型调用分析

1
2
3
4
5
6
7
Target 模型调用频率:

Token 类型 | 调用比例
---------------|----------
关键推理步骤 | ~70%
常规填充 token | ~20%
整体平均 | ~40%

关键洞察:RSD 智能地将 target 模型调用集中在关键推理步骤,而非常规 token 则主要由 draft model 处理。

消融实验

PRM 质量影响

PRM 配置 GSM8K 准确率 FLOPs 降低
无 PRM (基线) 68.5% 1.0x
随机 PRM 65.2% 2.1x
弱监督 PRM 69.8% 3.5x
全监督 PRM 72.0% 4.4x

偏置参数影响

Bias Strength 准确率 FLOPs 降低
0.0 (无偏置) 70.5% 3.8x
0.2 71.2% 4.1x
0.3 72.0% 4.4x
0.5 71.8% 4.2x
1.0 70.2% 3.5x

结论:适中的偏置强度 (0.3) 提供最佳平衡。

实践指南

集成 RSD

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from rsd import RewardGuidedGenerator, ProcessRewardModel

# 1. 加载模型
draft_model = AutoModelForCausalLM.from_pretrained("TinyLlama-1.1B")
target_model = AutoModelForCausalLM.from_pretrained("Llama-2-7B")
prm = ProcessRewardModel.from_pretrained("your-prm-checkpoint")

# 2. 创建 RSD 生成器
generator = RewardGuidedGenerator(
draft_model=draft_model,
target_model=target_model,
prm=prm,
reward_threshold=0.5,
bias_parameter=0.2
)

# 3. 推理
input_text = "解方程:2x + 5 = 13"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

output_ids = generator.generate(input_ids, max_length=200)
output_text = tokenizer.decode(output_ids[0])

print(output_text)

最佳实践

场景 推荐配置 预期收益
数学推理 高 reward 阈值 (0.6) 4-5x 加速
代码生成 中 reward 阈值 (0.4) 3-4x 加速
创意写作 低 reward 阈值 (0.3) 4-5x 加速
低延迟应用 高 bias 强度 (0.5) 更快响应

个人评价

RSD 是推测解码领域的重要创新。其核心贡献在于:

优势:

  1. 质量 - 效率平衡: 不盲目追求速度,而是保证输出质量
  2. 推理引导: 特别适合需要多步推理的复杂任务
  3. 动态决策: 智能分配计算资源,关键步骤用大模型
  4. 可控偏置: 允许根据应用需求调整策略

局限:

  1. PRM 训练成本: 需要额外的过程奖励模型
  2. 任务依赖: 最佳阈值和偏置参数因任务而异
  3. 系统集成: 需要同时部署三个模型,复杂度较高

适用场景:

  • 数学推理和问题求解
  • 代码生成和调试
  • 需要多步推理的复杂任务
  • 计算资源受限的高质量推理

评分: 4.2/5.0

技术亮点: reward-guided decoding, process reward model, speculative decoding, controlled bias

代码仓库: GitHub

相关资源:

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero