Expected Attention:基于未来查询分布估计的 KV Cache 压缩
ArXiv ID: 2510.00636
作者: Alessio Devoto, Maximilian Jeblick, Simon Jegou
机构: Sapienza University of Rome, NVIDIA
发布日期: 2025 年 10 月
代码库: KVPress
摘要
KV Cache 压缩面临一个基本矛盾:判断 KV 对重要性需要看未来查询的注意力分布,但推理时未来查询尚未产生。Expected Attention 通过利用 LLM 激活值的分布特性,以闭式解估计每个 KV 对的期望注意力分数,实现无训练、高效的 KV Cache 压缩。该方法在 prefilling 和 decoding 阶段均可无缝运行,且在 LongBench 等基准上全面超越现有基线。
问题背景
KV Cache 内存挑战
1 2 3 4 5 6 7 8
| LLM 推理时的 KV Cache 内存占用:
模型规模 | 序列长度 | KV Cache 内存 (FP16) ---------|----------|--------------------- 7B | 4K | ~2 GB 7B | 32K | ~16 GB 70B | 8K | ~32 GB 70B | 128K | ~512 GB ← 超出单卡容量
|
问题:长序列场景下,KV Cache 成为内存瓶颈
现有方法的局限
| 方法 |
核心思想 |
局限 |
| StreamingLLM |
保留初始 + 滑动窗口 |
丢失中间信息 |
| H2O |
基于历史注意力 |
不代表未来需求 |
| SnapKV |
基于观察窗口 |
启发式选择 |
| Expected Attention |
基于未来分布估计 |
闭式解,无参数 |
核心方法
关键洞察
1 2 3 4 5 6 7 8 9 10 11
| KV 对重要性判断的困境:
过去注意力分数: [已知] ← 可以计算 │ ▼ KV 对重要性? ▲ │ 未来查询: [未知] ← 推理时尚未产生
|
Expected Attention 的解决思路:
- 不依赖具体的未来查询
- 建模未来查询的概率分布
- 计算期望注意力分数
数学推导
步骤 1:建模未来查询分布
假设未来查询向量 q 服从多元正态分布:
1 2 3 4 5
| q ~ N(μ_q, Σ_q)
其中: - μ_q: 查询均值向量(从历史激活估计) - Σ_q: 查询协方差矩阵(对角近似)
|
参数估计:
1 2 3
| mu_q = h_history.mean(dim=0) var_q = h_history.var(dim=0) + eps
|
步骤 2:闭式解计算期望注意力
对于每个 KV 对 (k_i, v_i),期望注意力分数为:
1 2 3 4
| E[α_i] = E[softmax(q · k_i / √d)]
通过泰勒展开近似: E[α_i] ≈ softmax(μ_q · k_i / √d + 0.5 * trace(Σ_q · k_i k_i^T) / d)
|
简化计算(对角协方差假设):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| def expected_attention_score(k, mu_q, var_q, temperature=1.0): """ 计算 KV 对的期望注意力分数
Args: k: KV 对的键向量 [d] mu_q: 查询均值 [d] var_q: 查询方差 [d] temperature: 温度参数 """ d = len(k)
mean_term = torch.dot(mu_q, k) / sqrt(d)
var_term = 0.5 * torch.dot(var_q, k ** 2) / d
score = softmax((mean_term + var_term) / temperature)
return score
|
步骤 3:KV 对选择
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| def select_important_kv_pairs(k, v, mu_q, var_q, budget=1024): """ 选择重要的 KV 对
Args: k: 所有键向量 [seq_len, d] v: 所有值向量 [seq_len, d] mu_q: 查询均值 var_q: 查询方差 budget: 保留的 KV 对数量 """ scores = expected_attention_score(k, mu_q, var_q)
top_indices = torch.topk(scores, budget).indices
k_compressed = k[top_indices] v_compressed = v[top_indices]
return k_compressed, v_compressed
|
算法特性
优势
| 特性 |
描述 |
| 无训练 |
不需要微调或校准数据 |
| 闭式解 |
解析表达式,计算高效 |
| 兼容 FlashAttention |
不需要具体化完整注意力矩阵 |
| Prefill + Decode |
两个阶段均适用 |
| 理论保证 |
基于概率论的严格推导 |
计算复杂度
| 阶段 |
时间复杂度 |
空间复杂度 |
| 参数估计 |
O(n·d) |
O(d) |
| 分数计算 |
O(n·d) |
O(n) |
| Top-K 选择 |
O(n log K) |
O(K) |
| 总计 |
O(n·d) |
O(n) |
KVPress 工具库
安装和使用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| pip install kvpress
from kvpress import ExpectedAttentionPruning from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")
pruner = ExpectedAttentionPruning( compression_ratio=0.5, window_size=128, use_expected_attention=True )
pruner.register(model)
output = model.generate(**inputs, max_new_tokens=1000)
|
支持的压缩方法
KVPress 集成 20+ 种 KV Cache 压缩方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| from kvpress import ( ExpectedAttentionPruning,
H2OCompressor, SnapKVCompressor,
RandomProjectionCompressor,
KMeansCompressor,
HybridCompressor )
|
支持的模型
| 模型系列 |
支持状态 |
| LLaMA / Llama-2 / Llama-3 |
✅ 完全支持 |
| Mistral / Mixtral |
✅ 完全支持 |
| Qwen1.5 / Qwen2 / Qwen3 |
✅ 完全支持 |
| Gemma / Gemma-2 |
✅ 完全支持 |
| Phi-2 / Phi-3 |
✅ 完全支持 |
| Falcon |
✅ 支持 |
| Baichuan |
✅ 支持 |
实验结果
LongBench 基准
| 方法 |
平均分数 |
压缩率 |
内存节省 |
| Full Attention |
65.2 |
1.0x |
- |
| StreamingLLM |
58.3 |
0.3x |
70% |
| H2O |
61.5 |
0.5x |
50% |
| SnapKV |
62.8 |
0.5x |
50% |
| Expected Attention |
64.1 |
0.5x |
50% |
结论:在相同压缩率下,Expected Attention 性能最接近完整注意力
多文档问答
| 方法 |
SQuAD |
HotpotQA |
2Wiki |
平均 |
| Full |
78.5 |
62.3 |
55.1 |
65.3 |
| H2O |
72.1 |
56.8 |
49.2 |
59.4 |
| SnapKV |
74.3 |
58.9 |
51.6 |
61.6 |
| Expected |
76.8 |
60.5 |
53.8 |
63.7 |
代码理解任务
| 方法 |
RepoBench |
LCCEval |
平均 |
| Full |
45.2 |
38.7 |
42.0 |
| StreamingLLM |
38.1 |
32.5 |
35.3 |
| Expected |
43.5 |
37.2 |
40.4 |
压缩率 vs 性能
1 2 3 4 5 6 7 8 9 10 11 12
| 性能分数 │ │ ● Full (100%) │ │ ● Expected (50%) │ │ ● Expected (30%) │ │ ● H2O (50%) │ └─────────────────────────── 100 50 30 10 压缩率 (%)
|
实践指南
超参数选择
| 参数 |
推荐值 |
说明 |
| compression_ratio |
0.3-0.5 |
保留 30-50% KV 对 |
| window_size |
64-256 |
局部窗口大小 |
| temperature |
1.0 |
分数温度参数 |
不同场景的配置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| config_long_doc = { "compression_ratio": 0.3, "window_size": 256, "use_sinkhorn": True }
config_dialog = { "compression_ratio": 0.5, "window_size": 128, "preserve_initial": True }
config_code = { "compression_ratio": 0.4, "window_size": 64, "use_sliding_window": True }
|
性能监控
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| from kvpress.utils import KVCacheStats
stats = KVCacheStats(model)
print(f"压缩前 KV 大小:{stats.get_kv_cache_size()}")
output = model.generate(**inputs)
print(f"压缩后 KV 大小:{stats.get_kv_cache_size()}") print(f"压缩率:{stats.get_compression_ratio()}") print(f"内存节省:{stats.get_memory_saved()} MB")
|
与相关工作对比
方法分类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| KV Cache 压缩方法 │ ├── 固定策略 │ ├── Sliding Window │ └── StreamingLLM (初始 + 滑动) │ ├── 基于历史启发式 │ ├── H2O (Heavy Hitter) │ ├── SnapKV (观察窗口) │ └── Retentive Network │ ├── 基于学习 │ ├── SoftKV (可学习掩码) │ └── AutoKV (神经架构搜索) │ └── 基于理论分析 └── **Expected Attention (本文)** ← 闭式解
|
对比总结
| 维度 |
H2O |
SnapKV |
Expected |
| 理论基础 |
启发式 |
启发式 |
概率论 |
| 计算开销 |
低 |
中 |
低 |
| 闭式解 |
❌ |
❌ |
✅ |
| 需要校准 |
❌ |
✅ |
❌ |
总结
Expected Attention 提供了一种理论驱动的 KV Cache 压缩方法:
核心贡献:
- 首次提出基于未来查询分布估计的压缩策略
- 推导闭式解的期望注意力分数公式
- 无训练、无校准、开箱即用
- 集成到 KVPress 工具库,支持 20+ 方法
实际价值:
- 50% 压缩率下保持 98% 性能
- 支持多种模型架构
- 适用于长序列推理任务
资源
评分: 4.0/5.0 ⭐⭐⭐⭐
推荐度: 推荐。理论优雅,实践有效,开源工具完善。