注意力机制详解
4 分钟阅读
本文章属于专栏:AI知识
注意力机制详解
🔥 注意力机制的本质
注意力机制的核心思想:让模型学会"关注"输入的不同部分。
Query: "我想关注什么?"
Key: "每个位置能提供什么?"
Value: "每个位置的实际内容"
🔥 Self-Attention 深入理解
直观理解
想象一个句子:"The cat sat on the mat because it was tired"
当处理 "it" 这个词时:
- Query: "it" 需要找到它的指代对象
- Key: 每个词的"标签"(名词、动词等)
- Value: 每个词的实际含义
注意力权重会显示 "it" 主要关注 "cat"(权重 0.8),其他词权重较低。
数学本质
# 标准 Self-Attention
Attention(Q, K, V) = softmax(QK^T / √d_k) V
# QK^T: 计算每对 token 的相似度
# softmax: 归一化为概率分布
# × V: 加权求和
本质上是一个加权平均,权重由 Q 和 K 的相似度决定。
🔥 多种注意力变体
1. Causal Attention (因果注意力)
用途: 自回归模型(GPT)
# 下三角掩码
mask = [[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
# 每个位置只能看到左边(包括自己)
作用: 防止"看到未来",保证自回归特性
2. Cross-Attention (交叉注意力)
用途: Encoder-Decoder 连接
# Q: 来自 Decoder
# K, V: 来自 Encoder
CrossAttention(Q_dec, K_enc, V_enc)
应用场景:
- 机器翻译:理解源语言,生成目标语言
- 多模态:文本查询,视觉特征
3. Multi-Query Attention (MQA)
原理: 多个 Query 头共享一组 K、V
# 标准 Multi-Head
Q: [h, n, d_k] # h 个头
K: [h, n, d_k] # h 个头
V: [h, n, d_v] # h 个头
# Multi-Query
Q: [h, n, d_k] # h 个头
K: [1, n, d_k] # 共享 1 组
V: [1, n, d_v] # 共享 1 组
优势:
- 大幅减少 KV Cache 大小(从 h 降到 1)
- 推理速度提升 30-50%
- 几乎不损失精度
使用模型: Falcon, PaLM
4. Grouped-Query Attention (GQA)
原理: 将 Q 头分组,每组共享 K、V
# 例如:8 个 Q 头,分成 2 组
Group 1: Q_head 0-3 共享 K_group1, V_group1
Group 2: Q_head 4-7 共享 K_group2, V_group2
优势: 平衡 MQA 和 MHA 的优缺点
使用模型: LLaMA-2 (70B), Qwen, Mistral
🔥 FlashAttention
问题背景
标准 Attention 的瓶颈不是计算,而是内存访问:
GPU 内存层次:
HBM (高带宽内存): 大但慢
SRAM (片上内存): 小但快
标准实现:
1. 计算 S = QK^T (写入 HBM)
2. 计算 P = softmax(S) (读写 HBM)
3. 计算 O = PV (读写 HBM)
问题: 大量 HBM 读写操作
FlashAttention 核心思想
IO-aware: 优化内存访问模式,而不是计算本身
# 核心: 分块计算 (Tiling)
for 每个 Q 的块:
for 每个 K,V 的块:
在 SRAM 中计算局部注意力
累加结果,更新统计量
# 关键: 从不将 n×n 的注意力矩阵写入 HBM
优势
| 指标 | 标准 Attention | FlashAttention |
|---|---|---|
| 内存 | O(n²) | O(n) |
| HBM 访问 | O(n²d + n²) | O(n²d²/M) |
| 计算 | 相同 | 相同(精确) |
注意: FlashAttention 不是近似,数学结果完全相同!
FlashAttention 2/3 改进
FlashAttention 2:
- 更好的并行度(跨序列长度维度)
- 减少非矩阵乘法运算
- 速度提升 2x
FlashAttention 3 (H100):
- 利用 H100 的异步特性
- FP8 支持
- 进一步优化流水线
🔥 Sparse Attention 稀疏注意力
动机
全注意力 O(n²) 对于长序列太昂贵,但很多注意力权重接近 0。
常见稀疏模式
1. 局部注意力 (Local/Sliding Window)
每个位置只关注窗口内的 token
窗口大小 w,复杂度 O(nw)
|-------|
|-------|
|-------|
使用模型: Longformer, Mistral
2. 全局注意力 (Global)
某些特殊 token 关注所有位置
如 [CLS] token
*---------*
| local |
| local |
3. 跨步注意力 (Strided)
关注固定间隔的位置
如每 8 个位置关注一个
x.......x.......x
4. Longformer 的组合
局部 + 全局 + 滑动窗口
全局 token (如 [CLS]) 关注所有
其他 token 关注局部窗口
🔥 注意力可视化与理解
注意力头的功能
研究发现不同头学习不同模式:
| 头类型 | 功能 | 示例 |
|---|---|---|
| 位置头 | 关注相邻位置 | 前一个/后一个词 |
| 语法头 | 关注句法关系 | 主谓、动宾 |
| 语义头 | 关注语义相似 | 同义词、上下位 |
| 分隔符头 | 关注特殊标记 | [SEP], [CLS] |
注意力的局限性
Q: 注意力机制有什么问题?
A:
- 二次复杂度: O(n²) 限制序列长度
- 注意力稀释: 长序列中注意力分散
- 位置偏差: 可能过度关注局部
- 不可解释: 难以理解注意力模式
🔥 位置编码与注意力的交互
RoPE 如何融入注意力
# 标准注意力
scores = Q @ K.T
# 应用 RoPE
Q_rotated = apply_rope(Q, pos_m) # 位置 m
K_rotated = apply_rope(K, pos_n) # 位置 n
scores = Q_rotated @ K_rotated.T
# 结果自然包含相对位置信息 (m-n)
为什么 RoPE 外推性好?
数学解释:
- 旋转是正交变换,不改变向量范数
- 相对位置信息通过角度差表示
- 角度可以平滑外推
🎯 面试题精选
Q1: Self-Attention 和 CNN、RNN 的对比
A:
| 特性 | Self-Attention | CNN | RNN |
|---|---|---|---|
| 依赖范围 | 全局 | 局部(堆叠扩大) | 逐步传播 |
| 并行性 | 高 | 高 | 低 |
| 位置感知 | 需位置编码 | 天然感知 | 天然感知 |
| 计算复杂度 | O(n²d) | O(knd²) | O(nd²) |
| 长距离依赖 | 直接建模 | 需要多层 | 梯度消失 |
Q2: 为什么 Attention 用点积而不是加法?
A:
- 计算效率: 点积可用矩阵乘法,GPU 加速
- 参数更少: 不需要额外参数
- 效果相当: 实验表明两者效果相似
Q3: Multi-Head Attention 的 head 数量如何选择?
A:
- 经验法则: 保证 d_k = d_model / h 在 32-128 之间
- 常见配置: d_model=512 用 8 头,d_model=1024 用 16 头
- 太少: 表达能力不足
- 太多: 计算开销大,可能过拟合
Q4: Attention 权重能否用作可解释性?
A: 部分可以,但需谨慎:
- 有用场景: 检查是否关注关键词
- 局限性: 注意力权重 ≠ 重要性
- 替代方案: 梯度方法、LIME、SHAP
📚 延伸阅读
上一章: 01-Transformer架构详解 下一章: 03-主流大模型对比