动态专家搜索:在测试时增强 MoE LLM 的推理能力

动态专家搜索:在测试时增强 MoE LLM 的推理能力

ArXiv ID: 2509.22572
作者: Yixuan Han, Fan Ma, Ruijie Quan, Yi Yang
机构: Zhejiang University
发布日期: 2025-09-26


摘要

测试时扩展(TTS)通过在推理期间分配额外计算来增强大型语言模型的推理能力。然而,现有方法主要依赖输出级采样,而忽略了模型架构的作用。本文提出 DES(Dynamic Experts Search),一种利用混合专家(MoE)架构在测试时增强推理的新方法。DES 在测试时动态搜索最优的专家组合,而不是依赖训练时固定的路由策略。实验表明,DES 在相同计算预算下比传统采样方法提升**10-15%**的准确率。


问题背景

MoE 架构的潜力与局限

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
传统 MoE 路由机制:

输入 → 路由器 → 选择 Top-K 专家 → 输出

问题:
┌─────────────────────────────────────────┐
│ 固定路由策略的局限 │
│ │
│ 1. 训练时学习的路由可能不适合推理任务 │
│ 2. 单一专家组合无法应对复杂推理 │
│ 3. 无法根据推理步骤动态调整 │
│ 4. 忽略了专家组合的可能性空间 │
└─────────────────────────────────────────┘

示例:
输入:"计算 3x + 5 = 20 的解"

传统路由:
→ 数学专家 (固定)

DES 动态搜索:
→ 步骤 1: 代数专家 (识别方程类型)
→ 步骤 2: 算术专家 (执行计算)
→ 步骤 3: 验证专家 (检查结果)

测试时扩展的架构盲区

方法 优化层级 局限
Self-Consistency 输出采样 忽略内部结构
Best-of-N 输出选择 计算效率低
Chain-of-Thought 提示工程 依赖模型能力
DES 架构优化 探索专家组合空间

DES 方法

整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
┌─────────────────────────────────────────────────────────┐
│ Dynamic Experts Search (DES) │
│ │
│ 输入问题 │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 专家适配度评估 │ │
│ │ Expert Fitness │ ← 评估专家组合质量 │
│ │ Evaluation │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 动态搜索策略 │ │
│ │ Dynamic Search │ ← 强化学习搜索最优组合 │
│ │ Strategy │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ 最优专家组合 → 推理 → 输出 │
└─────────────────────────────────────────────────────────┘

组件 1:专家适配度评估

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
from typing import List, Tuple, Dict

class ExpertFitnessEvaluator:
"""
专家适配度评估器

评估维度:
1. 推理连贯性 (Coherence)
2. 事实准确性 (Accuracy)
3. 逻辑一致性 (Consistency)
"""

def __init__(self, model, device='cuda'):
self.model = model
self.device = device

def evaluate(self, expert_combination: List[int],
input_text: str,
intermediate_output: str) -> float:
"""
评估专家组合的适配度

Args:
expert_combination: 专家索引列表
input_text: 输入问题
intermediate_output: 中间推理输出

Returns:
适配度分数 (0-1)
"""
scores = {}

# 1. 推理连贯性
scores['coherence'] = self._evaluate_coherence(
input_text, intermediate_output
)

# 2. 事实准确性
scores['accuracy'] = self._evaluate_accuracy(
intermediate_output
)

# 3. 逻辑一致性
scores['consistency'] = self._evaluate_consistency(
intermediate_output
)

# 综合分数
total_score = (
0.3 * scores['coherence'] +
0.4 * scores['accuracy'] +
0.3 * scores['consistency']
)

return total_score, scores

def _evaluate_coherence(self, input_text: str, output: str) -> float:
"""
评估推理连贯性

使用 NLI 模型判断输出是否与输入逻辑相关
"""
# 简化实现:使用嵌入相似度
input_embedding = self._embed(input_text)
output_embedding = self._embed(output)

similarity = torch.cosine_similarity(
input_embedding, output_embedding, dim=1
)

return similarity.item()

def _evaluate_accuracy(self, output: str) -> float:
"""
评估事实准确性

检查输出中的事实声明是否可验证
"""
# 提取事实声明
facts = self._extract_facts(output)

if not facts:
return 0.8 # 无事实声明,默认中等分数

# 验证事实(简化:使用内部一致性)
verified = 0
for fact in facts:
if self._is_self_consistent(fact):
verified += 1

return verified / len(facts)

def _evaluate_consistency(self, output: str) -> float:
"""
评估逻辑一致性

检查推理步骤之间是否矛盾
"""
steps = self._extract_reasoning_steps(output)

if len(steps) < 2:
return 1.0 # 单步推理,默认一致

# 检查步骤间一致性
inconsistencies = 0
for i in range(len(steps) - 1):
if not self._steps_compatible(steps[i], steps[i+1]):
inconsistencies += 1

return 1.0 - (inconsistencies / max(1, len(steps) - 1))

def _embed(self, text: str) -> torch.Tensor:
"""生成文本嵌入"""
# 使用模型的嵌入层
with torch.no_grad():
inputs = self.model.tokenizer(text, return_tensors="pt")
outputs = self.model(**inputs.to(self.device))
return outputs.last_hidden_state.mean(dim=1)

组件 2:动态搜索策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import random
from collections import defaultdict

class DynamicSearchStrategy:
"""
动态搜索策略

使用强化学习搜索最优专家组合
"""

def __init__(self, num_experts: int, max_experts_per_step: int = 4):
self.num_experts = num_experts
self.max_experts = max_experts_per_step

# Q 表:状态 - 动作值
self.q_table = defaultdict(lambda: defaultdict(float))

# 超参数
self.learning_rate = 0.1
self.discount_factor = 0.9
self.epsilon = 0.3 # 探索率

def search_best_combination(self, state: str,
fitness_evaluator,
num_iterations: int = 10) -> Tuple[List[int], float]:
"""
搜索最优专家组合

Args:
state: 当前推理状态(问题 + 上下文)
fitness_evaluator: 适配度评估器
num_iterations: 搜索迭代次数

Returns:
最佳专家组合及分数
"""
best_combination = None
best_score = -float('inf')

for iteration in range(num_iterations):
# ε-贪婪策略选择动作
if random.random() < self.epsilon:
# 探索:随机选择
combination = self._random_combination()
else:
# 利用:选择 Q 值最高的
combination = self._greedy_combination(state)

# 评估适配度
score, _ = fitness_evaluator.evaluate(
combination, state, ""
)

# 更新 Q 表
self._update_q_table(state, combination, score)

# 更新最佳
if score > best_score:
best_score = score
best_combination = combination

return best_combination, best_score

def _random_combination(self) -> List[int]:
"""随机选择专家组合"""
num_to_select = random.randint(1, self.max_experts)
return random.sample(range(self.num_experts), num_to_select)

def _greedy_combination(self, state: str) -> List[int]:
"""贪婪选择 Q 值最高的专家"""
q_values = self.q_table[state]

if not q_values:
return self._random_combination()

# 选择 Q 值最高的前 K 个专家
sorted_experts = sorted(
q_values.items(), key=lambda x: x[1], reverse=True
)
return [expert for expert, _ in sorted_experts[:self.max_experts]]

def _update_q_table(self, state: str, combination: List[int], score: float):
"""更新 Q 表"""
for expert in combination:
old_q = self.q_table[state][expert]
# Q 学习更新公式
new_q = old_q + self.learning_rate * (score - old_q)
self.q_table[state][expert] = new_q

DES 推理流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class DESReasoner:
"""DES 推理器"""

def __init__(self, model):
self.model = model
self.fitness_evaluator = ExpertFitnessEvaluator(model)
self.search_strategy = DynamicSearchStrategy(
num_experts=model.config.num_experts
)

def reason(self, question: str, max_steps: int = 5) -> str:
"""
使用 DES 进行推理

Args:
question: 问题
max_steps: 最大推理步数

Returns:
推理答案
"""
context = question
reasoning_steps = []

for step in range(max_steps):
# 步骤 1: 搜索最优专家组合
state = f"{question}\nContext: {context}"
best_experts, fitness = self.search_strategy.search_best_combination(
state, self.fitness_evaluator
)

# 步骤 2: 使用选定的专家进行推理
output = self._forward_with_experts(
context, best_experts
)

# 步骤 3: 更新上下文
reasoning_steps.append(output)
context = f"{context}\nStep {step + 1}: {output}"

# 步骤 4: 检查是否得出答案
if self._is_final_answer(output):
break

return self._format_answer(reasoning_steps)

def _forward_with_experts(self, input_text: str,
experts: List[int]) -> str:
"""使用指定专家进行前向传播"""
# MoE 模型的前向传播,强制使用指定专家
with torch.no_grad():
inputs = self.model.tokenizer(
input_text, return_tensors="pt"
).to(self.model.device)

# 强制路由到指定专家
outputs = self.model(
**inputs,
forced_experts=experts # 假设模型支持此参数
)

return self.model.tokenizer.decode(
outputs.logits.argmax(dim=-1)[0],
skip_special_tokens=True
)

def _is_final_answer(self, text: str) -> bool:
"""检查是否为最终答案"""
answer_markers = ['答案是', '因此', '所以', 'Answer:', 'Therefore']
return any(marker in text for marker in answer_markers)

实验结果

实验设置

基准任务

  • GSM8K:数学推理
  • MATH:数学竞赛
  • CommonsenseQA:常识推理
  • HumanEval:代码生成

对比方法

  • Greedy Decoding
  • Self-Consistency
  • Best-of-N
  • Chain-of-Thought

评估指标

  • 准确率(%)
  • 计算开销(相对值)

主要结果

GSM8K 数学推理

方法 准确率 相对开销
Greedy 58.2% 1.0x
CoT 65.5% 1.2x
Self-Consistency 72.3% 5.0x
Best-of-N 70.1% 4.5x
DES 78.5% 2.3x

关键发现:DES 在效率和准确性之间取得最佳平衡

MATH 竞赛题

方法 简单 中等 困难 平均
CoT 52.3% 35.2% 18.5% 35.3%
Self-Consistency 58.5% 42.1% 25.3% 42.0%
DES 65.2% 48.5% 32.1% 48.6%

专家使用分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
不同类型问题的专家偏好:

数学问题:
├─ 代数专家 (45%)
├─ 算术专家 (30%)
├─ 几何专家 (15%)
└─ 验证专家 (10%)

代码问题:
├─ 语法专家 (40%)
├─ 算法专家 (35%)
├─ 调试专家 (15%)
└─ 验证专家 (10%)

常识问题:
├─ 语义专家 (50%)
├─ 推理专家 (30%)
└─ 事实验证专家 (20%)

消融实验

搜索策略对比

策略 GSM8K 搜索效率
随机搜索 68.5%
贪婪搜索 72.3%
ε-贪婪 75.8%
强化学习 78.5%

适配度评估组件

配置 GSM8K MATH
完整评估 78.5% 48.6%
- 连贯性 75.2% 45.1%
- 准确性 71.3% 42.5%
- 一致性 74.8% 46.2%

总结

DES 通过动态搜索最优专家组合,实现了架构感知的测试时优化:

核心贡献

  1. 专家适配度评估机制
  2. 强化学习驱动的动态搜索
  3. 架构感知的推理增强

实际价值

  • 10-15% 准确率提升
  • 适用于数学、代码、常识推理
  • 计算效率优于传统采样

资源


评分: 4.3/5.0 ⭐⭐⭐⭐

推荐度: 推荐。MoE 架构测试时优化的创新方法。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero