Expected Attention:基于未来查询分布估计的 KV Cache 压缩

Expected Attention:基于未来查询分布估计的 KV Cache 压缩

ArXiv ID: 2510.00636
作者: Alessio Devoto, Maximilian Jeblick, Simon Jegou
机构: Sapienza University of Rome, NVIDIA
发布日期: 2025 年 10 月
代码库: KVPress


摘要

KV Cache 压缩面临一个基本矛盾:判断 KV 对重要性需要看未来查询的注意力分布,但推理时未来查询尚未产生。Expected Attention 通过利用 LLM 激活值的分布特性,以闭式解估计每个 KV 对的期望注意力分数,实现无训练、高效的 KV Cache 压缩。该方法在 prefilling 和 decoding 阶段均可无缝运行,且在 LongBench 等基准上全面超越现有基线。


问题背景

KV Cache 内存挑战

1
2
3
4
5
6
7
8
LLM 推理时的 KV Cache 内存占用:

模型规模 | 序列长度 | KV Cache 内存 (FP16)
---------|----------|---------------------
7B | 4K | ~2 GB
7B | 32K | ~16 GB
70B | 8K | ~32 GB
70B | 128K | ~512 GB ← 超出单卡容量

问题:长序列场景下,KV Cache 成为内存瓶颈

现有方法的局限

方法 核心思想 局限
StreamingLLM 保留初始 + 滑动窗口 丢失中间信息
H2O 基于历史注意力 不代表未来需求
SnapKV 基于观察窗口 启发式选择
Expected Attention 基于未来分布估计 闭式解,无参数

核心方法

关键洞察

1
2
3
4
5
6
7
8
9
10
11
KV 对重要性判断的困境:

过去注意力分数:
[已知] ← 可以计算


KV 对重要性?


未来查询:
[未知] ← 推理时尚未产生

Expected Attention 的解决思路

  • 不依赖具体的未来查询
  • 建模未来查询的概率分布
  • 计算期望注意力分数

数学推导

步骤 1:建模未来查询分布

假设未来查询向量 q 服从多元正态分布:

1
2
3
4
5
q ~ N(μ_q, Σ_q)

其中:
- μ_q: 查询均值向量(从历史激活估计)
- Σ_q: 查询协方差矩阵(对角近似)

参数估计

1
2
3
# 从历史隐藏状态估计分布参数
mu_q = h_history.mean(dim=0) # 均值
var_q = h_history.var(dim=0) + eps # 方差(对角协方差)

步骤 2:闭式解计算期望注意力

对于每个 KV 对 (k_i, v_i),期望注意力分数为:

1
2
3
4
E[α_i] = E[softmax(q · k_i / √d)]

通过泰勒展开近似:
E[α_i] ≈ softmax(μ_q · k_i / √d + 0.5 * trace(Σ_q · k_i k_i^T) / d)

简化计算(对角协方差假设):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def expected_attention_score(k, mu_q, var_q, temperature=1.0):
"""
计算 KV 对的期望注意力分数

Args:
k: KV 对的键向量 [d]
mu_q: 查询均值 [d]
var_q: 查询方差 [d]
temperature: 温度参数
"""
d = len(k)

# 均值项
mean_term = torch.dot(mu_q, k) / sqrt(d)

# 方差项(对角近似)
var_term = 0.5 * torch.dot(var_q, k ** 2) / d

# 期望注意力分数
score = softmax((mean_term + var_term) / temperature)

return score

步骤 3:KV 对选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def select_important_kv_pairs(k, v, mu_q, var_q, budget=1024):
"""
选择重要的 KV 对

Args:
k: 所有键向量 [seq_len, d]
v: 所有值向量 [seq_len, d]
mu_q: 查询均值
var_q: 查询方差
budget: 保留的 KV 对数量
"""
# 计算所有 KV 对的期望分数
scores = expected_attention_score(k, mu_q, var_q)

# 选择 top-k
top_indices = torch.topk(scores, budget).indices

# 压缩 KV Cache
k_compressed = k[top_indices]
v_compressed = v[top_indices]

return k_compressed, v_compressed

算法特性

优势

特性 描述
无训练 不需要微调或校准数据
闭式解 解析表达式,计算高效
兼容 FlashAttention 不需要具体化完整注意力矩阵
Prefill + Decode 两个阶段均适用
理论保证 基于概率论的严格推导

计算复杂度

阶段 时间复杂度 空间复杂度
参数估计 O(n·d) O(d)
分数计算 O(n·d) O(n)
Top-K 选择 O(n log K) O(K)
总计 O(n·d) O(n)

KVPress 工具库

安装和使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 安装
pip install kvpress

# 使用示例
from kvpress import ExpectedAttentionPruning
from transformers import AutoModelForCausalLM

# 加载模型
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")

# 创建压缩钩子
pruner = ExpectedAttentionPruning(
compression_ratio=0.5, # 保留 50% KV 对
window_size=128, # 滑动窗口大小
use_expected_attention=True
)

# 注册钩子
pruner.register(model)

# 推理(自动压缩 KV Cache)
output = model.generate(**inputs, max_new_tokens=1000)

支持的压缩方法

KVPress 集成 20+ 种 KV Cache 压缩方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from kvpress import (
# 基于期望注意力
ExpectedAttentionPruning,

# 基于历史注意力
H2OCompressor,
SnapKVCompressor,

# 基于随机投影
RandomProjectionCompressor,

# 基于聚类
KMeansCompressor,

# 混合方法
HybridCompressor
)

支持的模型

模型系列 支持状态
LLaMA / Llama-2 / Llama-3 ✅ 完全支持
Mistral / Mixtral ✅ 完全支持
Qwen1.5 / Qwen2 / Qwen3 ✅ 完全支持
Gemma / Gemma-2 ✅ 完全支持
Phi-2 / Phi-3 ✅ 完全支持
Falcon ✅ 支持
Baichuan ✅ 支持

实验结果

LongBench 基准

方法 平均分数 压缩率 内存节省
Full Attention 65.2 1.0x -
StreamingLLM 58.3 0.3x 70%
H2O 61.5 0.5x 50%
SnapKV 62.8 0.5x 50%
Expected Attention 64.1 0.5x 50%

结论:在相同压缩率下,Expected Attention 性能最接近完整注意力

多文档问答

方法 SQuAD HotpotQA 2Wiki 平均
Full 78.5 62.3 55.1 65.3
H2O 72.1 56.8 49.2 59.4
SnapKV 74.3 58.9 51.6 61.6
Expected 76.8 60.5 53.8 63.7

代码理解任务

方法 RepoBench LCCEval 平均
Full 45.2 38.7 42.0
StreamingLLM 38.1 32.5 35.3
Expected 43.5 37.2 40.4

压缩率 vs 性能

1
2
3
4
5
6
7
8
9
10
11
12
性能分数

│ ● Full (100%)

│ ● Expected (50%)

│ ● Expected (30%)

│ ● H2O (50%)

└───────────────────────────
100 50 30 10 压缩率 (%)

实践指南

超参数选择

参数 推荐值 说明
compression_ratio 0.3-0.5 保留 30-50% KV 对
window_size 64-256 局部窗口大小
temperature 1.0 分数温度参数

不同场景的配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 场景 1:长文档阅读(保持全局信息)
config_long_doc = {
"compression_ratio": 0.3,
"window_size": 256,
"use_sinkhorn": True
}

# 场景 2:多轮对话(保持上下文)
config_dialog = {
"compression_ratio": 0.5,
"window_size": 128,
"preserve_initial": True
}

# 场景 3:代码理解(局部性重要)
config_code = {
"compression_ratio": 0.4,
"window_size": 64,
"use_sliding_window": True
}

性能监控

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from kvpress.utils import KVCacheStats

# 统计 KV Cache 使用
stats = KVCacheStats(model)

# 推理前
print(f"压缩前 KV 大小:{stats.get_kv_cache_size()}")

# 推理
output = model.generate(**inputs)

# 推理后
print(f"压缩后 KV 大小:{stats.get_kv_cache_size()}")
print(f"压缩率:{stats.get_compression_ratio()}")
print(f"内存节省:{stats.get_memory_saved()} MB")

与相关工作对比

方法分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
KV Cache 压缩方法

├── 固定策略
│ ├── Sliding Window
│ └── StreamingLLM (初始 + 滑动)

├── 基于历史启发式
│ ├── H2O (Heavy Hitter)
│ ├── SnapKV (观察窗口)
│ └── Retentive Network

├── 基于学习
│ ├── SoftKV (可学习掩码)
│ └── AutoKV (神经架构搜索)

└── 基于理论分析
└── **Expected Attention (本文)** ← 闭式解

对比总结

维度 H2O SnapKV Expected
理论基础 启发式 启发式 概率论
计算开销
闭式解
需要校准

总结

Expected Attention 提供了一种理论驱动的 KV Cache 压缩方法:

核心贡献

  1. 首次提出基于未来查询分布估计的压缩策略
  2. 推导闭式解的期望注意力分数公式
  3. 无训练、无校准、开箱即用
  4. 集成到 KVPress 工具库,支持 20+ 方法

实际价值

  • 50% 压缩率下保持 98% 性能
  • 支持多种模型架构
  • 适用于长序列推理任务

资源


评分: 4.0/5.0 ⭐⭐⭐⭐

推荐度: 推荐。理论优雅,实践有效,开源工具完善。

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero