注意力机制详解

4 分钟阅读

本文章属于专栏:AI知识

注意力机制详解

🔥 注意力机制的本质

注意力机制的核心思想:让模型学会"关注"输入的不同部分

Query: "我想关注什么?"
Key:   "每个位置能提供什么?"
Value: "每个位置的实际内容"

🔥 Self-Attention 深入理解

直观理解

想象一个句子:"The cat sat on the mat because it was tired"

当处理 "it" 这个词时:

  • Query: "it" 需要找到它的指代对象
  • Key: 每个词的"标签"(名词、动词等)
  • Value: 每个词的实际含义

注意力权重会显示 "it" 主要关注 "cat"(权重 0.8),其他词权重较低。

数学本质

# 标准 Self-Attention
Attention(Q, K, V) = softmax(QK^T / √d_k) V

# QK^T: 计算每对 token 的相似度
# softmax: 归一化为概率分布
# × V: 加权求和

本质上是一个加权平均,权重由 Q 和 K 的相似度决定。


🔥 多种注意力变体

1. Causal Attention (因果注意力)

用途: 自回归模型(GPT)

# 下三角掩码
mask = [[1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]]

# 每个位置只能看到左边(包括自己)

作用: 防止"看到未来",保证自回归特性

2. Cross-Attention (交叉注意力)

用途: Encoder-Decoder 连接

# Q: 来自 Decoder
# K, V: 来自 Encoder
CrossAttention(Q_dec, K_enc, V_enc)

应用场景:

  • 机器翻译:理解源语言,生成目标语言
  • 多模态:文本查询,视觉特征

3. Multi-Query Attention (MQA)

原理: 多个 Query 头共享一组 K、V

# 标准 Multi-Head
Q: [h, n, d_k]  # h 个头
K: [h, n, d_k]  # h 个头
V: [h, n, d_v]  # h 个头

# Multi-Query
Q: [h, n, d_k]  # h 个头
K: [1, n, d_k]  # 共享 1 组
V: [1, n, d_v]  # 共享 1 组

优势:

  • 大幅减少 KV Cache 大小(从 h 降到 1)
  • 推理速度提升 30-50%
  • 几乎不损失精度

使用模型: Falcon, PaLM

4. Grouped-Query Attention (GQA)

原理: 将 Q 头分组,每组共享 K、V

# 例如:8 个 Q 头,分成 2 组
Group 1: Q_head 0-3 共享 K_group1, V_group1
Group 2: Q_head 4-7 共享 K_group2, V_group2

优势: 平衡 MQA 和 MHA 的优缺点

使用模型: LLaMA-2 (70B), Qwen, Mistral


🔥 FlashAttention

问题背景

标准 Attention 的瓶颈不是计算,而是内存访问

GPU 内存层次:
HBM (高带宽内存): 大但慢
SRAM (片上内存): 小但快

标准实现:
1. 计算 S = QK^T (写入 HBM)
2. 计算 P = softmax(S) (读写 HBM)
3. 计算 O = PV (读写 HBM)

问题: 大量 HBM 读写操作

FlashAttention 核心思想

IO-aware: 优化内存访问模式,而不是计算本身

# 核心: 分块计算 (Tiling)
for 每个 Q 的块:
    for 每个 K,V 的块:
        在 SRAM 中计算局部注意力
        累加结果,更新统计量

# 关键: 从不将 n×n 的注意力矩阵写入 HBM

优势

指标标准 AttentionFlashAttention
内存O(n²)O(n)
HBM 访问O(n²d + n²)O(n²d²/M)
计算相同相同(精确)

注意: FlashAttention 不是近似,数学结果完全相同!

FlashAttention 2/3 改进

FlashAttention 2:

  • 更好的并行度(跨序列长度维度)
  • 减少非矩阵乘法运算
  • 速度提升 2x

FlashAttention 3 (H100):

  • 利用 H100 的异步特性
  • FP8 支持
  • 进一步优化流水线

🔥 Sparse Attention 稀疏注意力

动机

全注意力 O(n²) 对于长序列太昂贵,但很多注意力权重接近 0。

常见稀疏模式

1. 局部注意力 (Local/Sliding Window)

每个位置只关注窗口内的 token
窗口大小 w,复杂度 O(nw)

|-------|
  |-------|
    |-------|

使用模型: Longformer, Mistral

2. 全局注意力 (Global)

某些特殊 token 关注所有位置
如 [CLS] token

*---------*
|  local  |
|  local  |

3. 跨步注意力 (Strided)

关注固定间隔的位置
如每 8 个位置关注一个

x.......x.......x

4. Longformer 的组合

局部 + 全局 + 滑动窗口

全局 token (如 [CLS]) 关注所有
其他 token 关注局部窗口

🔥 注意力可视化与理解

注意力头的功能

研究发现不同头学习不同模式:

头类型功能示例
位置头关注相邻位置前一个/后一个词
语法头关注句法关系主谓、动宾
语义头关注语义相似同义词、上下位
分隔符头关注特殊标记[SEP], [CLS]

注意力的局限性

Q: 注意力机制有什么问题?

A:

  1. 二次复杂度: O(n²) 限制序列长度
  2. 注意力稀释: 长序列中注意力分散
  3. 位置偏差: 可能过度关注局部
  4. 不可解释: 难以理解注意力模式

🔥 位置编码与注意力的交互

RoPE 如何融入注意力

# 标准注意力
scores = Q @ K.T

# 应用 RoPE
Q_rotated = apply_rope(Q, pos_m)  # 位置 m
K_rotated = apply_rope(K, pos_n)  # 位置 n

scores = Q_rotated @ K_rotated.T
# 结果自然包含相对位置信息 (m-n)

为什么 RoPE 外推性好?

数学解释:

  • 旋转是正交变换,不改变向量范数
  • 相对位置信息通过角度差表示
  • 角度可以平滑外推

🎯 面试题精选

Q1: Self-Attention 和 CNN、RNN 的对比

A:

特性Self-AttentionCNNRNN
依赖范围全局局部(堆叠扩大)逐步传播
并行性
位置感知需位置编码天然感知天然感知
计算复杂度O(n²d)O(knd²)O(nd²)
长距离依赖直接建模需要多层梯度消失

Q2: 为什么 Attention 用点积而不是加法?

A:

  1. 计算效率: 点积可用矩阵乘法,GPU 加速
  2. 参数更少: 不需要额外参数
  3. 效果相当: 实验表明两者效果相似

Q3: Multi-Head Attention 的 head 数量如何选择?

A:

  1. 经验法则: 保证 d_k = d_model / h 在 32-128 之间
  2. 常见配置: d_model=512 用 8 头,d_model=1024 用 16 头
  3. 太少: 表达能力不足
  4. 太多: 计算开销大,可能过拟合

Q4: Attention 权重能否用作可解释性?

A: 部分可以,但需谨慎:

  1. 有用场景: 检查是否关注关键词
  2. 局限性: 注意力权重 ≠ 重要性
  3. 替代方案: 梯度方法、LIME、SHAP

📚 延伸阅读


上一章: 01-Transformer架构详解 下一章: 03-主流大模型对比