NSA:DeepSeek原生稀疏注意力机制——硬件对齐的高效长上下文方案

NSA:DeepSeek原生稀疏注意力机制——硬件对齐的高效长上下文方案

ArXiv ID: 2502.11089

作者: Jingyang Yuan, Huazuo Gao, Damai Dai, Junyu Luo, Liang Zhao等

机构: DeepSeek-AI, 北京大学, 华盛顿大学

发布日期: 2025年2月

摘要

随着大语言模型的上下文窗口不断扩大(64K甚至更长),标准的全注意力机制在解码阶段成为严重的性能瓶颈——理论估计显示,64K上下文长度下softmax attention计算占总延迟的70-80%。DeepSeek团队提出NSA(Native Sparse Attention),一种硬件对齐的、可原生训练的稀疏注意力机制。NSA通过动态层级稀疏策略,将粗粒度的token压缩与细粒度的token选择相结合,在保持全注意力模型精度的同时,在64K序列上实现了显著的训练和推理加速。更关键的是,NSA支持端到端的预训练,而非仅作为推理阶段的后处理优化。

核心问题:长上下文注意力的计算墙

1. 全注意力的代价

标准Multi-Head Attention的时间复杂度为O(N^2 * d),其中N为序列长度,d为隐藏维度。当N达到64K时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 64K上下文的计算量估算
seq_len = 65536
hidden_dim = 4096
num_heads = 32
head_dim = hidden_dim // num_heads # 128

# 单层attention的FLOPs
attention_flops = 2 * seq_len * seq_len * head_dim * num_heads
# = 2 * 65536 * 65536 * 128 * 32 = 3.5e13 FLOPs

# 80层Transformer
total_flops = attention_flops * 80 # = 2.8e15 FLOPs

# H100 GPU FP16峰值: ~2000 TFLOPs
# 理论最短时间: 2.8e15 / 2e15 = 1.4秒/step
# 实际由于内存带宽: 5-8秒/step

关键观察:在解码阶段,每生成一个token都需要与完整的KV cache做注意力计算,但实际上大部分attention权重极为稀疏——只有少数关键token贡献了主要的注意力分数。

2. 现有稀疏方法的局限

现有的稀疏注意力方法(如StreamingLLM、H2O等)多为推理时的后处理手段:

  • 无法训练:推理时才引入稀疏,训练时仍用全注意力
  • 精度损失大:启发式的token淘汰策略容易丢失关键信息
  • 硬件不友好:不规则的稀疏模式难以充分利用GPU Tensor Core

NSA的核心洞察:稀疏策略应从训练阶段就原生嵌入模型,让模型学会在稀疏条件下有效利用注意力资源。

NSA架构:三分支层级稀疏

1. 整体设计

NSA将全注意力替换为三个并行的稀疏分支,通过学习门控机制动态组合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
输入: Q, K, V (来自标准投影)

├─> 压缩分支 (Compression Branch)
│ 粗粒度:块级别token压缩,捕获全局上下文
│ 复杂度: O(N/B * d),B为块大小

├─> 选择分支 (Selection Branch)
│ 细粒度:动态选择top-k重要token块
│ 复杂度: O(k * B * d),k为选择块数

├─> 滑动窗口分支 (Sliding Window Branch)
│ 局部精度:保留近邻token的完整注意力
│ 复杂度: O(W * d),W为窗口大小

└─> 门控融合 (Gated Aggregation)
O = g_c * O_compress + g_s * O_select + g_w * O_window
门控权重由Q计算得到

2. 压缩分支:全局上下文捕获

压缩分支将连续的token块压缩为单个表示,实现对全局信息的高效概览:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def compression_branch(Q, K, V, block_size=32):
"""
将KV按块压缩,捕获全局上下文模式
"""
B, H, N, D = K.shape

# 将K, V按block_size分块
num_blocks = N // block_size
K_blocks = K.reshape(B, H, num_blocks, block_size, D)
V_blocks = V.reshape(B, H, num_blocks, block_size, D)

# 使用可学习的MLP压缩每个块
# 输入: (block_size, D) -> 输出: (1, D)
K_compressed = compress_mlp_k(K_blocks) # (B, H, num_blocks, D)
V_compressed = compress_mlp_v(V_blocks) # (B, H, num_blocks, D)

# 在压缩后的表示上做标准attention
attn_scores = torch.matmul(Q, K_compressed.transpose(-2, -1))
attn_scores = attn_scores / math.sqrt(D)
attn_weights = F.softmax(attn_scores, dim=-1)

output = torch.matmul(attn_weights, V_compressed)
return output

设计思考

  • 块大小B=32经验证在压缩率和信息保留之间取得最佳平衡
  • 压缩MLP是可学习的,模型自适应决定如何压缩
  • 复杂度从O(N)降至O(N/B),32倍压缩

3. 选择分支:关键token精确定位

选择分支在压缩分支提供的全局视图基础上,进一步精确选择最重要的token块:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def selection_branch(Q, K, V, K_compressed, top_k=16, block_size=32):
"""
基于压缩表示,选择top-k最重要的块
然后在原始粒度上计算精确attention
"""
B, H, N, D = K.shape

# Step 1: 使用压缩KV计算粗粒度重要性
coarse_scores = torch.matmul(Q, K_compressed.transpose(-2, -1))
# coarse_scores: (B, H, N_q, num_blocks)

# Step 2: 选择top-k个最重要的块
_, top_indices = torch.topk(coarse_scores, k=top_k, dim=-1)
# top_indices: (B, H, N_q, top_k)

# Step 3: 收集被选中块的原始K, V
K_blocks = K.reshape(B, H, -1, block_size, D)
V_blocks = V.reshape(B, H, -1, block_size, D)

# 使用gather获取选中的块
selected_K = gather_blocks(K_blocks, top_indices)
# selected_K: (B, H, N_q, top_k * block_size, D)

selected_V = gather_blocks(V_blocks, top_indices)

# Step 4: 在选中的原始token上做精确attention
attn_scores = torch.matmul(Q, selected_K.transpose(-2, -1))
attn_scores = attn_scores / math.sqrt(D)
attn_weights = F.softmax(attn_scores, dim=-1)

output = torch.matmul(attn_weights, selected_V)
return output

两阶段策略的优势

  1. 第一阶段用压缩表示快速扫描全局,复杂度O(N/B)
  2. 第二阶段仅在选中的k个块上做精确计算,复杂度O(k*B)
  3. 总复杂度O(N/B + k*B),远小于全注意力的O(N)

4. 门控融合

三个分支的输出通过可学习的门控网络动态融合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def gated_aggregation(Q, O_compress, O_select, O_window):
"""
基于查询Q动态决定各分支的权重
"""
# 门控权重由Q经过线性变换得到
gate_logits = gate_proj(Q) # (B, H, N, 3)
gates = torch.sigmoid(gate_logits)

g_c, g_s, g_w = gates.unbind(dim=-1)

# 加权融合
output = (g_c.unsqueeze(-1) * O_compress +
g_s.unsqueeze(-1) * O_select +
g_w.unsqueeze(-1) * O_window)

return output

门控学习的效果

  • 浅层倾向于更多使用滑动窗口(局部模式)
  • 深层倾向于更多使用选择分支(全局依赖)
  • 不同任务类型的门控模式有显著差异

硬件对齐的Triton Kernel

1. 内存访问模式优化

NSA的Kernel设计充分考虑了GPU架构特性:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Kernel设计 (简化描述):
┌───────────────────────────────────┐
│ Grid Loop: 按GQA组遍历查询 │
│ ┌───────────────────────────────┐│
│ │ Inner Loop: 加载稀疏KV块 ││
│ │ ┌───────────────────────────┐││
│ │ │ SRAM计算: │││
│ │ │ 1. 加载Q到寄存器 │││
│ │ │ 2. 加载选中K块到SRAM │││
│ │ │ 3. 计算QK^T (Tensor Core)│││
│ │ │ 4. 加载V块到SRAM │││
│ │ │ 5. 计算softmax + output │││
│ │ └───────────────────────────┘││
│ └───────────────────────────────┘│
└───────────────────────────────────┘

关键优化策略

  1. 块对齐访问:所有稀疏块大小为32的倍数,与Tensor Core的tile大小对齐
  2. GQA分组处理:同一组内的查询共享KV,减少重复加载
  3. SRAM驻留计算:整个attention计算在SRAM中完成,避免反复访问HBM
  4. 连续内存布局:选中的KV块重排为连续内存,最大化带宽利用

2. 前向和反向传播一体化

NSA同时提供了高效的反向传播Kernel,这是它能端到端训练的关键:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 反向传播也使用稀疏计算
# 梯度仅沿着选中的稀疏路径传播
# 门控权重的梯度通过直通估计器(STE)近似

class NSAFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, K, V, ...):
# 稀疏前向计算
output, metadata = nsa_forward_kernel(Q, K, V, ...)
ctx.save_for_backward(Q, K, V, metadata)
return output

@staticmethod
def backward(ctx, grad_output):
Q, K, V, metadata = ctx.saved_tensors
# 稀疏反向计算:仅在前向选中的块上计算梯度
grad_Q, grad_K, grad_V = nsa_backward_kernel(
grad_output, Q, K, V, metadata
)
return grad_Q, grad_K, grad_V

训练配置与实验

1. 预训练设置

NSA在DeepSeek的训练框架上验证,采用MoE架构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# 模型配置
architecture: Transformer + NSA + DeepSeekMoE
hidden_dim: 2048
num_layers: 27
num_attention_heads: 16
num_kv_heads: 2 # GQA
moe:
num_routed_experts: 72
num_shared_experts: 2
top_k: 6

# NSA配置
nsa:
compression_block_size: 32
selection_top_k: 16
sliding_window_size: 512
gate_type: "sigmoid"

# 训练配置
training_tokens: 260B
max_context_length: 64K
batch_size: variable (sequence packing)
learning_rate: 3e-4
optimizer: AdamW

2. 性能结果

预训练Loss对比

NSA的预训练loss曲线与全注意力基线保持一致,且在部分区间甚至略优于全注意力模型,说明稀疏策略没有损害模型的学习能力。

1
2
3
训练Loss (260B tokens):
Full Attention: 2.87 -> 2.34 (稳定收敛)
NSA: 2.87 -> 2.33 (稳定收敛,略优)

推理加速比 (64K序列长度):

操作 Full Attention NSA 加速比
Decoding 基线 - 显著加速
Forward 基线 - 显著加速
Backward 基线 - 显著加速

加速比随序列长度增长而更加明显,在64K上下文下优势最为突出。

下游任务精度

任务类别 Full Attention NSA 差异
通用基准 基线 持平或更优 ~0
长上下文任务 基线 持平或更优 ~0
指令遵循 基线 持平 ~0

3. Chain-of-Thought推理验证

为验证NSA与高级训练范式的兼容性,作者使用DeepSeek-R1的数学推理数据(10B tokens,32K长度)进行蒸馏训练:

1
2
3
AIME 24数学推理评估:
Full Attention-R: 基线
NSA-R: 持平或更优

这证明NSA不仅适用于基础预训练,也能完美支持复杂的推理能力训练。

技术对比分析

与其他稀疏注意力方法对比

方法 可训练 硬件对齐 前向加速 反向加速 精度保持
NSA 优秀
StreamingLLM 部分 中等
H2O 中等
Longformer 部分 良好
FlashAttention 否(全注意力) 基线 基线 完美

核心优势

  1. 端到端可训练:从预训练就嵌入稀疏策略,模型学会有效利用
  2. 硬件对齐:Triton实现,块对齐内存访问,Tensor Core友好
  3. 三分支互补:全局压缩+精确选择+局部窗口,覆盖所有注意力模式
  4. MoE兼容:已在DeepSeekMoE架构上验证

部署建议

适用场景

  1. 长上下文推理 – 推荐度: 极高

    • 64K+ token输入的代码理解、文档分析
    • 注意力计算占比超过70%时效果最显著
  2. 大规模预训练 – 推荐度: 高

    • 从头开始训练新模型时直接使用NSA
    • 节省训练计算成本的同时不损失精度
  3. 推理服务优化 – 推荐度: 高

    • 结合vLLM/SGLang等serving框架
    • 降低长上下文请求的延迟

注意事项

  1. 短序列场景收益有限:序列长度<4K时,全注意力本身开销不大
  2. 需要从头训练或微调:不能直接替换已训练好模型的注意力层
  3. 超参数调优:块大小、top-k选择数需要根据具体任务调整

个人评价

NSA代表了稀疏注意力研究的一个重要里程碑。与之前大量”推理时后处理”的稀疏方法不同,NSA从根本上解决了问题——让模型在训练阶段就学会稀疏注意力模式。三分支设计(压缩+选择+窗口)的架构既有理论优雅性,也有工程实用性。

DeepSeek团队在这篇论文中展示了从算法设计到硬件优化的端到端工程能力。特别是Triton Kernel的设计,充分考虑了GPU内存层次和Tensor Core特性,使得理论上的计算节省能真正转化为实际的速度提升。

NSA对后续研究的启发在于:稀疏注意力不应该是训练后的补丁,而应该是架构设计的一部分。随着上下文窗口不断增大(128K、1M甚至更长),这种思路将越来越重要。

不足之处

  • 需要从头预训练,无法直接应用于现有模型
  • 稀疏策略的超参数(块大小、top-k等)需要手动设定
  • 论文中的实验规模虽然可观但不是最大规模

评分: 4.75/5.0

论文: https://arxiv.org/abs/2502.11089

Hugging Face: https://huggingface.co/papers/2502.11089

© 2026 Generative AI Discovery All Rights Reserved.
Theme by hiero