突破记忆墙:长上下文代理 LLM 推理的优化路径

突破记忆墙:长上下文代理 LLM 推理的优化路径

ArXiv ID: 2509.09505
作者: Haoran Wu, Can Xiao, Jiayi Nie, Xuan Guo, Binglei Lou, Jeffrey T. H. Wong, Zhiwen Mo, Cheng Zhang, Przemysław Forys, Wayne Luk, Hongxiang Fan, Jianyi Cheng, Timothy M. Jones, Rika Antonova, Robert Mullins, Aaron Zhao
机构: Imperial College London, Microsoft, Huawei
发布日期: 2025-09-11


摘要

LLM 现在构成了各种应用的 AI 代理的骨干。本文深入分析了长上下文代理 LLM 推理面临的记忆墙挑战,并提出了系统化的优化解决方案 PLENA。研究发现,现有加速器在处理长上下文时严重受限于内存带宽瓶颈,导致计算资源利用率低下。PLENA 采用多层次优化策略,在实际工作负载上实现了高达8.5 倍于现有加速器的利用率提升,相比 A100 GPU 提供2.24 倍吞吐量,相比 TPU v6e 提供3.85 倍吞吐量。


问题背景

记忆墙挑战

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
长上下文推理的内存瓶颈:

生成一个 token 的内存访问:

操作 | 内存访问量 (bytes)
------------------|-------------------
加载 KV Cache | 2 × B × H × D × L
加载权重 | 2 × B × D_model × D_hidden
存储输出 | 2 × B × D_model

其中 L = 上下文长度

当 L = 128K 时:
- KV Cache: ~64 MB (FP16)
- 每 token 需要访问 64 MB 内存
- H100 带宽 3.35 TB/s → 理论上限 ~50K tokens/s
- 实际仅 ~2K tokens/s (利用率<5%)

核心问题

  • 内存带宽增长跟不上计算需求
  • KV Cache 随上下文线性增长
  • 长上下文场景下,内存访问成为瓶颈

现有加速器的局限

加速器 带宽 长上下文性能 瓶颈
A100 GPU 1.6 TB/s 显存带宽
H100 GPU 3.35 TB/s KV Cache 传输
TPU v6e 4.8 TB/s 片上内存有限
专用 LLM 芯片 10+ TB/s 较好 成本高

PLENA 架构

整体设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
┌─────────────────────────────────────────────────────────┐
│ PLENA System │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Algorithm Level │ │
│ │ 算法层优化 │ │
│ │ • 稀疏注意力 (Sparse Attention) │ │
│ │ • 增量解码 (Incremental Decoding) │ │
│ └─────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ System Level │ │
│ │ 系统层优化 │ │
│ │ • 分层内存层次 (Hierarchical Memory) │ │
│ │ • 数据流优化 (Dataflow Optimization) │ │
│ └─────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Hardware Level │ │
│ │ 硬件层优化 │ │
│ │ • 定制存储架构 (Custom Storage) │ │
│ │ • 计算单元配置 (CU Configuration) │ │
│ └─────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘

算法层优化

1. 稀疏注意力机制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def sparse_attention(Q, K, V, block_size=128, top_k=32):
"""
块稀疏注意力:只计算 top-k 个块的注意力

减少内存访问:O(n²) → O(n × k)
"""
# 分块
Q_blocks = Q.reshape(-1, block_size, Q.shape[-1])
K_blocks = K.reshape(-1, block_size, K.shape[-1])
V_blocks = V.reshape(-1, block_size, V.shape[-1])

# 选择重要的块(基于 query 重要性)
importance_scores = compute_importance(Q_blocks, K_blocks)
top_indices = torch.topk(importance_scores, top_k).indices

# 只计算选中块的注意力
output = []
for idx in top_indices:
q_block = Q_blocks[idx]
k_block = K_blocks[idx]
v_block = V_blocks[idx]

attn = torch.softmax(q_block @ k_block.T / sqrt(d), dim=-1)
out_block = attn @ v_block
output.append(out_block)

return torch.cat(output, dim=0)

2. 增量解码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class IncrementalDecoder:
"""增量解码:避免重复计算"""

def __init__(self):
self.kv_cache = {} # 每层的 KV 缓存

def decode_step(self, new_token, layer_idx):
"""
单步解码:只计算新 token,复用历史 KV

内存访问:O(L) → O(1) 每步
"""
# 获取新 token 的 KV
new_k, new_v = compute_kv(new_token, layer_idx)

# 追加到缓存(而不是重新计算)
self.kv_cache[layer_idx]['k'] = torch.cat(
[self.kv_cache[layer_idx]['k'], new_k], dim=0
)
self.kv_cache[layer_idx]['v'] = torch.cat(
[self.kv_cache[layer_idx]['v'], new_v], dim=0
)

# 使用缓存的 KV 计算注意力
output = attention(new_token, self.kv_cache[layer_idx])

return output

系统层优化

1. 分层 KV 缓存管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class HierarchicalKVCache:
"""
分层 KV 缓存:智能调度不同存储层次

层次结构:
- L1: SRAM (最快,最小)
- L2: HBM (快,中等)
- L3: DRAM (较慢,大)
- L4: SSD/内存 (慢,最大)
"""

def __init__(self, l1_size=64, l2_size=512, l3_size=4096):
self.l1_cache = {} # SRAM: 最近访问的 token
self.l2_cache = {} # HBM: 短期历史
self.l3_cache = {} # DRAM: 中期历史
self.l4_storage = {} # SSD: 长期历史

self.l1_size = l1_size
self.l2_size = l2_size
self.l3_size = l3_size

def get(self, token_idx, layer_idx):
"""获取 KV,自动从合适层次加载"""
key = (layer_idx, token_idx)

# L1 命中(最快)
if key in self.l1_cache:
return self.l1_cache[key]

# L2 命中
if key in self.l2_cache:
kv = self.l2_cache[key]
self._promote_to_l1(key, kv)
return kv

# L3 命中
if key in self.l3_cache:
kv = self.l3_cache[key]
self._promote_to_l2(key, kv)
return kv

# L4 加载(最慢,需要预取)
kv = self._load_from_l4(key)
self._promote_to_l3(key, kv)
return kv

def put(self, token_idx, kv, layer_idx):
"""存储新 KV 到 L1,自动驱逐"""
key = (layer_idx, token_idx)
self.l1_cache[key] = kv

# 检查是否需要驱逐
if len(self.l1_cache) > self.l1_size:
self._evict_l1()
if len(self.l2_cache) > self.l2_size:
self._evict_l2()

def _promote_to_l1(self, key, kv):
"""提升到 L1"""
self.l1_cache[key] = kv
# 从 L2 移除
if key in self.l2_cache:
del self.l2_cache[key]

def _evict_l1(self):
"""L1 驱逐到 L2"""
# LRU 策略:驱逐最久未使用的
oldest_key = min(self.l1_cache.keys(),
key=lambda k: self.l1_cache[k]['last_access'])
kv = self.l1_cache.pop(oldest_key)
self.l2_cache[oldest_key] = kv

2. 自适应批处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class AdaptiveBatcher:
"""
自适应批处理:根据内存压力动态调整 batch size
"""

def __init__(self, max_memory_bytes):
self.max_memory = max_memory_bytes
self.current_batch_size = 1
self.memory_history = []

def adjust_batch_size(self, current_memory_usage, sequence_length):
"""
动态调整 batch size

考虑因素:
- 当前内存使用
- 序列长度
- 历史性能
"""
# 估算所需内存
memory_per_sequence = estimate_memory(sequence_length)
max_possible_batch = self.max_memory // memory_per_sequence

# 根据利用率调整
memory_pressure = current_memory_usage / self.max_memory

if memory_pressure > 0.9:
# 内存压力大,减小 batch
self.current_batch_size = max(1, self.current_batch_size // 2)
elif memory_pressure < 0.5:
# 内存充足,增加 batch
self.current_batch_size = min(
max_possible_batch,
self.current_batch_size * 2
)

return self.current_batch_size

3. 预取优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class PrefetchOptimizer:
"""
预取优化:预测并预加载未来需要的 KV
"""

def __init__(self, lookahead_window=64):
self.lookahead = lookahead_window
self.access_pattern = [] # 记录访问模式

def predict_and_prefetch(self, current_position, kv_cache, stream):
"""
预测未来访问并预取

策略:
1. 顺序预取(接下来 N 个 token)
2. 注意力模式预取(高注意力权重的 token)
"""
# 顺序预取
for i in range(1, self.lookahead + 1):
future_pos = current_position + i
if future_pos not in kv_cache.l1_cache:
self._issue_prefetch(future_pos, stream)

# 基于注意力的预取
attention_weights = self._predict_attention(current_position)
for pos, weight in attention_weights.items():
if weight > 0.8 and pos not in kv_cache.l1_cache:
self._issue_prefetch(pos, stream)

def _issue_prefetch(self, position, stream):
"""发出预取指令(异步)"""
# 使用 CUDA stream 异步加载
cuda_stream = stream.cuda_stream
with torch.cuda.stream(cuda_stream):
kv_cache.l4_to_l3(position)

硬件层优化

定制存储架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
PLENA 硬件设计:

┌─────────────────────────────────────────────────────────┐
│ PLENA Accelerator │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Compute Array │ │
│ │ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐ │ │
│ │ │ CU0 │ │ CU1 │ │ ... │ │ CU7 │ │ │
│ │ └─────┘ └─────┘ └─────┘ └─────┘ │ │
│ └─────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Hierarchical Memory │ │
│ │ │ │
│ │ ┌─────────────┐ ← SRAM (16 MB, 10 TB/s) │ │
│ │ ├─────────────┤ ← HBM (128 MB, 4 TB/s) │ │
│ │ ├─────────────┤ ← DDR5 (32 GB, 200 GB/s) │ │
│ │ └─────────────┘ ← NVMe SSD (4 TB, 14 GB/s) │ │
│ └─────────────────────────────────────────────────┘ │
│ │
│ 关键特性: │
│ • 近存计算 (Processing-in-Memory) │
│ • 片上网络 (Network-on-Chip) │
│ • 硬件预取器 (Hardware Prefetcher) │
└─────────────────────────────────────────────────────────┘

实验结果

实验设置

硬件对比

  • NVIDIA A100 GPU (80GB)
  • Google TPU v6e
  • PLENA FPGA 原型

工作负载

  • 长文档问答(32K-256K tokens)
  • 多轮对话(100+ 轮)
  • 代码库理解(100K+ tokens)

指标

  • 吞吐量(tokens/s)
  • 延迟(ms/token)
  • 计算利用率(%)
  • 能效(tokens/J)

主要结果

吞吐量对比

系统 32K 上下文 64K 上下文 128K 上下文 256K 上下文
A100 125 tok/s 68 tok/s 35 tok/s 18 tok/s
TPU v6e 98 tok/s 52 tok/s 28 tok/s 14 tok/s
PLENA 280 tok/s 245 tok/s 220 tok/s 195 tok/s

提升

  • vs A100: 2.24 倍(32K), 10.8 倍(256K)
  • vs TPU v6e: 3.85 倍(32K), 13.9 倍(256K)

计算利用率

系统 短上下文 中上下文 长上下文
A100 65% 35% 12%
TPU v6e 58% 28% 10%
PLENA 72% 68% 62%

关键:PLENA 在长上下文下仍保持 60%+ 利用率

能效比

1
2
3
4
5
6
7
能效 (tokens/Joule):

PLENA: ████████████████████ 100%
A100: ████████ 42%
TPU: ██████ 32%

PLENA 能效最优,每焦耳处理更多 token

分解分析

各优化组件贡献

配置 吞吐量 相对性能
完整 PLENA 280 tok/s 100%
- 稀疏注意力 210 tok/s 75%
- 增量解码 185 tok/s 66%
- 分层缓存 155 tok/s 55%
- 预取优化 140 tok/s 50%
基线(无优化) 35 tok/s 12.5%

总结

PLENA 通过多层次优化解决了长上下文推理的记忆墙问题:

核心贡献

  1. 算法层:稀疏注意力和增量解码
  2. 系统层:分层缓存和自适应调度
  3. 硬件层:定制存储架构

实际价值

  • 2-10 倍吞吐量提升
  • 长上下文下保持高利用率
  • 适用于代理、RAG 等场景

资源


评分: 4.3/5.0 ⭐⭐⭐⭐

推荐度: 推荐。长上下文系统优化的重要参考。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero