PyTorch Study Notes (10): Attention Mechanisms
2025-12-28 · 12 min read
#PyTorch #Deep Learning #Attention #Transformer
Attention is a core technique of modern deep learning, used throughout models such as the Transformer, BERT, and GPT.
Core Idea
Teach the model to focus on what matters by dynamically assigning different weights to different parts of the input.
Main variants:
- Additive Attention (Bahdanau Attention)
- Dot-Product Attention (Luong Attention)
- Scaled Dot-Product Attention (Transformer)
- Self-Attention
- Multi-Head Attention
Scaled Dot-Product Attention
This is the core building block of the Transformer:
```text
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
```
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, d_k):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, heads, seq_q, d_k)
            key:   (batch, heads, seq_k, d_k)
            value: (batch, heads, seq_v, d_v)
            mask:  optional mask; positions where mask == 0 are ignored
        Returns:
            output:    (batch, heads, seq_q, d_v)
            attention: (batch, heads, seq_q, seq_k)
        """
        # Attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale
        # Apply the mask, if given
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Normalize with softmax
        attention = F.softmax(scores, dim=-1)
        # Weighted sum over the values
        output = torch.matmul(attention, value)
        return output, attention
```
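Since PyTorch 2.0 the same computation is also available as a fused built-in, `torch.nn.functional.scaled_dot_product_attention`. A quick sanity check (the shapes here are arbitrary) confirms the manual formula agrees with it:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)  # (batch, heads, seq_q, d_k)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# Manual scaled dot-product attention
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
manual = F.softmax(scores, dim=-1) @ v

# Fused built-in (requires PyTorch >= 2.0)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))
```

The built-in also accepts an `attn_mask` argument and an `is_causal` flag, and can dispatch to memory-efficient kernels.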
Why scale the scores?
- When d_k is large, the dot products grow large in magnitude
- softmax then saturates toward one-hot outputs, and the gradients vanish
- Dividing by sqrt(d_k) keeps the scores near unit variance and stabilizes training
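The effect is easy to verify numerically: for random vectors with unit-variance components, the dot product has variance roughly d_k, so dividing by sqrt(d_k) restores unit variance. A small NumPy check (the sample sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512
# 10,000 dot products of independent unit-variance vectors
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))
dots = (q * k).sum(axis=1)

print(dots.var())                     # ~512, grows linearly with d_k
print((dots / np.sqrt(d_k)).var())    # ~1.0 after scaling
```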
Self-Attention
Every element of a sequence attends to every other element of the same sequence: the query, key, and value all come from the same input.
```python
class SelfAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, embed_dim, heads=8):
        super().__init__()
        assert embed_dim % heads == 0, "embed_dim must be divisible by heads"
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.attention = ScaledDotProductAttention(self.head_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, embed_dim)
        Returns:
            out:       (batch, seq_len, embed_dim)
            attention: (batch, heads, seq_len, seq_len)
        """
        batch_size = x.size(0)
        # Linear projections
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        # Split into heads
        q = q.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2)
        # Attention per head
        out, attention = self.attention(q, k, v, mask)
        # Merge the heads back
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        # Final linear projection
        out = self.out_linear(out)
        return out, attention
```
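A common use of the `mask` argument is a causal (autoregressive) mask that blocks attention to future positions. A minimal, self-contained sketch of how such a mask zeroes out the upper triangle of the attention weights (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(1, 1, seq_len, seq_len)  # (batch, heads, seq_q, seq_k)

# Lower-triangular causal mask: position i may attend only to positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = F.softmax(scores, dim=-1)

# Weights above the diagonal are exactly zero; each row still sums to 1
print(attention[0, 0])
```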
Multi-Head Attention
The input is split across several "heads" that each compute attention independently, letting the model attend to different representation subspaces at different positions simultaneously.
```python
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # Split into heads
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)
        # Weighted sum over the values
        output = torch.matmul(attention, v)
        # Merge the heads back
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        # Output projection
        output = self.out_proj(output)
        return output, attention
```
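PyTorch also ships a built-in `nn.MultiheadAttention` with the same query/key/value interface; with `batch_first=True` its shapes match the (batch, seq_len, embed_dim) convention used above. A quick shape check (the dimensions are arbitrary):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)  # (batch, seq_len, embed_dim)

# Self-attention: query, key, and value are all the same tensor
out, weights = mha(x, x, x)
print(out.shape)      # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 10, 10]); weights averaged over heads
```

By default the returned weights are averaged across heads; pass `average_attn_weights=False` to get per-head weights.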
Positional Encoding
Self-attention carries no notion of sequence order, so positional information must be injected explicitly:
```python
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, embed_dim, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the encoding table
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float()
                             * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        pe = pe.unsqueeze(0)                          # (1, max_len, embed_dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
```
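The sinusoidal table has a couple of properties worth sanity-checking: every value lies in [-1, 1], and position 0 encodes as alternating 0/1 (sin(0) and cos(0)). Recomputing the same table in NumPy (with small illustrative sizes):

```python
import numpy as np

max_len, embed_dim = 100, 16
position = np.arange(max_len)[:, None]
div_term = np.exp(np.arange(0, embed_dim, 2) * (-np.log(10000.0) / embed_dim))

pe = np.zeros((max_len, embed_dim))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)

print(np.abs(pe).max() <= 1.0)  # True: values are bounded
print(pe[0])                    # alternating 0, 1, 0, 1, ... at position 0
```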
Attention Variants Compared
| Variant | Computation | Notes |
|---|---|---|
| Additive | v^T * tanh(W_h * h + W_s * s) | Flexible but more expensive |
| Dot-Product | h^T * s | Simple and efficient |
| Scaled Dot-Product | Q @ K^T / sqrt(d_k) | Core of the Transformer |
| Self-Attention | Q/K/V from the same input | Captures long-range dependencies |
| Multi-Head | Several heads in parallel | Captures multiple kinds of dependency |
Shape Walkthrough
```text
Input:              (batch, seq_len, embed_dim)
        ↓
Q/K/V projections:  (batch, seq_len, embed_dim)
        ↓
Split into heads:   (batch, num_heads, seq_len, head_dim)
        ↓
Attention weights:  (batch, num_heads, seq_len, seq_len)
        ↓
Apply to values:    (batch, num_heads, seq_len, head_dim)
        ↓
Merge heads:        (batch, seq_len, embed_dim)
        ↓
Output projection:  (batch, seq_len, embed_dim)
```
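The walkthrough above can be verified step by step with raw reshapes (no learned weights; for brevity one tensor stands in for Q, K, and V, and the sizes are arbitrary):

```python
import torch

batch, seq_len, embed_dim, num_heads = 2, 10, 64, 8
head_dim = embed_dim // num_heads

x = torch.randn(batch, seq_len, embed_dim)  # input (and stand-in for Q/K/V)

# Split into heads
q = x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
assert q.shape == (batch, num_heads, seq_len, head_dim)

# Attention weights
weights = torch.softmax(q @ q.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
assert weights.shape == (batch, num_heads, seq_len, seq_len)

# Apply to values
out = weights @ q
assert out.shape == (batch, num_heads, seq_len, head_dim)

# Merge heads
merged = out.transpose(1, 2).contiguous().view(batch, seq_len, embed_dim)
assert merged.shape == (batch, seq_len, embed_dim)
```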
Choosing a Variant
| Scenario | Recommendation |
|---|---|
| Limited compute | Dot-Product |
| Need flexibility | Additive |
| Best performance | Multi-Head Scaled Dot-Product |
| Long sequences | Sparse Attention |
Summary
Attention is a core technique across modern NLP and CV:
- Core formula: softmax(QK^T / sqrt(d_k)) @ V
- Self-attention: Q, K, and V come from the same input
- Multi-head attention: several attention heads in parallel capture different dependencies
- Positional encoding: injects the order information that attention lacks