跳到主要内容

注意力机制与 Transformer

注意力机制和 Transformer 架构是现代自然语言处理的基石。2017 年,Google 在论文《Attention Is All You Need》中提出的 Transformer 模型彻底改变了 NLP 领域,成为 BERT、GPT 等主流预训练模型的基础架构。

为什么需要注意力机制

在 Transformer 出现之前,处理序列数据主要依赖循环神经网络(RNN)和长短期记忆网络(LSTM)。这些模型存在几个关键问题:

顺序计算瓶颈:RNN 必须按顺序处理输入,无法并行计算,训练效率低。

长距离依赖问题:虽然 LSTM 通过门控机制缓解了梯度消失问题,但对于很长的序列,信息仍然会逐渐丢失。

信息压缩:Seq2Seq 模型将整个输入序列压缩为一个固定长度的向量,信息损失严重。

注意力机制通过让模型在处理每个位置时都能"看到"序列的所有其他位置,解决了这些问题。

注意力机制的核心思想

注意力机制的核心思想是:在处理序列的某个位置时,根据其与序列其他位置的相关性,对不同位置的信息赋予不同的权重。

想象你在阅读一篇文章时,读到某个词时,你的注意力会自动聚焦到与这个词相关的其他词上。注意力机制模拟的就是这种过程。

注意力计算过程

给定查询(Query)、键(Key)和值(Value),注意力机制的计算过程如下:

  1. 计算查询与每个键的相似度,得到注意力分数
  2. 对注意力分数进行归一化(通常使用 Softmax)
  3. 用归一化的注意力分数对值进行加权求和

数学表达式为:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中 dkd_k 是键向量的维度,除以 dk\sqrt{d_k} 是为了防止点积过大导致 Softmax 梯度过小。

注意力机制实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力"""

def __init__(self, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)

def forward(self, query, key, value, mask=None):
"""
query: (batch_size, num_heads, seq_len_q, d_k)
key: (batch_size, num_heads, seq_len_k, d_k)
value: (batch_size, num_heads, seq_len_v, d_v)
mask: 可选的掩码
"""
d_k = query.size(-1)

# 计算注意力分数:(batch_size, num_heads, seq_len_q, seq_len_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

# 应用掩码(将不需要的位置设为负无穷)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Softmax 归一化
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)

# 加权求和
output = torch.matmul(attention_weights, value)

return output, attention_weights

# 使用示例
attention = ScaledDotProductAttention()

# 模拟输入
batch_size = 2
num_heads = 8
seq_len = 10
d_k = 64

query = torch.randn(batch_size, num_heads, seq_len, d_k)
key = torch.randn(batch_size, num_heads, seq_len, d_k)
value = torch.randn(batch_size, num_heads, seq_len, d_k)

output, weights = attention(query, key, value)
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")

多头注意力

单头注意力只能学习一种相关性模式,多头注意力通过并行运行多个注意力函数,让模型能够同时关注不同位置的不同表示子空间。

多头注意力原理

多头注意力将查询、键、值分别投影到多个子空间,在每个子空间独立计算注意力,然后将结果拼接并投影回原始维度。

class MultiHeadAttention(nn.Module):
"""多头注意力机制"""

def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"

self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度

# 线性投影层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

self.attention = ScaledDotProductAttention(dropout)

def forward(self, query, key, value, mask=None):
batch_size = query.size(0)

# 线性投影并分头
# (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

# 计算注意力
output, attention_weights = self.attention(Q, K, V, mask)

# 合并多头
# (batch_size, num_heads, seq_len, d_k) -> (batch_size, seq_len, d_model)
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

# 最终投影
output = self.W_o(output)

return output, attention_weights

# 使用示例
d_model = 512
num_heads = 8

mha = MultiHeadAttention(d_model, num_heads)

# 模拟输入
seq_len = 20
x = torch.randn(batch_size, seq_len, d_model)

output, weights = mha(x, x, x) # 自注意力:Q=K=V
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")

自注意力

自注意力(Self-Attention)是注意力机制的一种特殊形式,其中查询、键、值都来自同一个输入序列。这使得序列中的每个位置都能与其他所有位置进行交互。

自注意力的直观理解

考虑句子 "The animal didn't cross the street because it was too tired"。当模型处理 "it" 这个词时,自注意力机制会让模型关注到 "animal",因为 "it" 在这里指代 "animal"。

def self_attention_example():
"""自注意力示例:展示单词之间的关联"""
import torch.nn.functional as F

# 模拟词嵌入(简化版)
words = ["The", "animal", "didn't", "cross", "the", "street", "because", "it", "was", "tired"]
vocab_size = len(words)
d_model = 8

# 随机初始化词向量
torch.manual_seed(42)
embeddings = torch.randn(vocab_size, d_model)

# 计算自注意力权重
scores = torch.matmul(embeddings, embeddings.T) / math.sqrt(d_model)
attention_weights = F.softmax(scores, dim=-1)

# 查看 "it"(索引 7)对其他词的注意力
it_attention = attention_weights[7]

print("单词 'it' 对其他词的注意力权重:")
for i, (word, weight) in enumerate(zip(words, it_attention)):
bar = "█" * int(weight * 100)
print(f" {word:<10} {weight:.4f} {bar}")

self_attention_example()

Transformer 架构

Transformer 由编码器和解码器两部分组成,每部分都由多个相同的层堆叠而成。

整体架构

输入 → 词嵌入 + 位置编码 
→ [编码器层 × N]
→ [解码器层 × N]
→ 线性层 + Softmax
→ 输出概率

编码器

编码器由多个相同的层组成,每层包含两个子层:

  1. 多头自注意力层:序列内部的自注意力计算
  2. 前馈神经网络层:对每个位置独立应用的全连接网络

每个子层都使用残差连接和层归一化。

class PositionwiseFeedForward(nn.Module):
"""前馈神经网络"""

def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
# 两层全连接,中间有 ReLU 激活
return self.fc2(self.dropout(F.relu(self.fc1(x))))

class EncoderLayer(nn.Module):
"""编码器层"""

def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

def forward(self, x, mask=None):
# 自注意力子层 + 残差连接 + 层归一化
attn_output, _ = self.self_attention(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))

# 前馈网络子层 + 残差连接 + 层归一化
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout2(ff_output))

return x

解码器

解码器也由多个相同的层组成,每层包含三个子层:

  1. 带掩码的多头自注意力层:防止解码器看到未来的信息
  2. 编码器-解码器注意力层:让解码器关注编码器的输出
  3. 前馈神经网络层
class DecoderLayer(nn.Module):
"""解码器层"""

def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)

self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
# 带掩码的自注意力
attn_output, _ = self.self_attention(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout1(attn_output))

# 编码器-解码器注意力
attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout2(attn_output))

# 前馈网络
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout3(ff_output))

return x

位置编码

由于 Transformer 没有循环结构,无法自然地捕捉位置信息,因此需要显式添加位置编码。

class PositionalEncoding(nn.Module):
"""位置编码"""

def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)

# 计算位置编码
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

# 使用正弦和余弦函数
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)

pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度

pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)

def forward(self, x):
# x: (batch_size, seq_len, d_model)
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)

# 使用示例
pos_encoding = PositionalEncoding(d_model=512, max_len=100)

# 模拟输入
x = torch.randn(2, 20, 512)
output = pos_encoding(x)
print(f"位置编码后形状: {output.shape}")

位置编码使用正弦和余弦函数的好处是:模型可以学习到相对位置信息,因为对于任意固定偏移量 kkPEpos+kPE_{pos+k} 可以表示为 PEposPE_{pos} 的线性函数。

完整的 Transformer 模型

class Transformer(nn.Module):
"""完整的 Transformer 模型"""

def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, dropout=0.1):
super().__init__()

# 词嵌入
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

# 位置编码
self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

# 编码器层
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_encoder_layers)
])

# 解码器层
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_decoder_layers)
])

# 输出层
self.fc_out = nn.Linear(d_model, tgt_vocab_size)

self.d_model = d_model

def make_src_mask(self, src):
"""创建源序列掩码"""
# 将填充位置设为 0,其他位置设为 1
return (src != 0).unsqueeze(1).unsqueeze(2)

def make_tgt_mask(self, tgt):
"""创建目标序列掩码(防止看到未来)"""
batch_size, seq_len = tgt.size()

# 填充掩码
pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)

# 因果掩码(下三角矩阵)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

# 组合掩码
return pad_mask & causal_mask.bool()

def encode(self, src, src_mask):
"""编码"""
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.pos_encoding(x)

for layer in self.encoder_layers:
x = layer(x, src_mask)

return x

def decode(self, tgt, encoder_output, src_mask, tgt_mask):
"""解码"""
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.pos_encoding(x)

for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)

return x

def forward(self, src, tgt):
"""前向传播"""
src_mask = self.make_src_mask(src)
tgt_mask = self.make_tgt_mask(tgt)

encoder_output = self.encode(src, src_mask)
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)

output = self.fc_out(decoder_output)
return output

# 创建模型
src_vocab_size = 10000
tgt_vocab_size = 8000

model = Transformer(src_vocab_size, tgt_vocab_size)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

Transformer 的关键创新

并行计算

与 RNN 不同,Transformer 可以并行处理整个序列,大大提高了训练效率。这使得在大规模数据上训练大模型成为可能。

长距离依赖

自注意力机制让序列中任意两个位置之间都可以直接交互,距离为 O(1)O(1),而 RNN 中距离为 O(n)O(n)。这使得模型能够更好地捕捉长距离依赖。

残差连接和层归一化

残差连接解决了深层网络的梯度消失问题,层归一化加速了训练收敛。

# 残差连接的效果示意
def residual_effect_demo():
"""展示残差连接的作用"""
import torch

x = torch.randn(1, 10)

# 无残差连接:梯度可能消失
y1 = torch.relu(torch.relu(torch.relu(x)))

# 有残差连接:梯度可以直接传播
y2 = x + torch.relu(x)
y2 = y2 + torch.relu(y2)
y2 = y2 + torch.relu(y2)

print("无残差连接的输出范围:", y1.min().item(), y1.max().item())
print("有残差连接的输出范围:", y2.min().item(), y2.max().item())

residual_effect_demo()

Transformer 变体

Encoder-only 模型(BERT 系列)

只使用编码器部分,适合理解类任务,如文本分类、命名实体识别、问答等。BERT 是典型代表。

class BERTEncoder(nn.Module):
"""BERT 风格的编码器"""

def __init__(self, vocab_size, d_model=768, num_heads=12, num_layers=12, d_ff=3072):
super().__init__()

self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model)

self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])

self.norm = nn.LayerNorm(d_model)

def forward(self, x, mask=None):
x = self.embedding(x)
x = self.pos_encoding(x)

for layer in self.layers:
x = layer(x, mask)

return self.norm(x)

Decoder-only 模型(GPT 系列)

只使用解码器部分,适合生成类任务,如文本生成、对话等。GPT 是典型代表。

class GPTDecoder(nn.Module):
"""GPT 风格的解码器"""

def __init__(self, vocab_size, d_model=768, num_heads=12, num_layers=12, d_ff=3072):
super().__init__()

self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model)

self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)
])

self.norm = nn.LayerNorm(d_model)
self.fc_out = nn.Linear(d_model, vocab_size)

def forward(self, x, mask=None):
x = self.embedding(x)
x = self.pos_encoding(x)

# 创建因果掩码
seq_len = x.size(1)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

for layer in self.layers:
# GPT 只有自注意力,不需要 cross_attention
x = layer.self_attn_layer(x, causal_mask)

return self.fc_out(self.norm(x))

Encoder-Decoder 模型(T5、BART)

完整使用编码器和解码器,适合序列到序列任务,如翻译、摘要等。

注意力机制的发展

稀疏注意力

标准注意力的复杂度是 O(n2)O(n^2),对于长序列计算成本很高。稀疏注意力通过限制每个位置只关注部分位置来降低复杂度。

线性注意力

通过核函数近似,将注意力复杂度从 O(n2)O(n^2) 降到 O(n)O(n)

Flash Attention

通过优化 GPU 内存访问模式,大幅提高注意力计算的效率。

实际应用示例

文本分类

from transformers import BertModel, BertTokenizer
import torch

class TextClassifier(nn.Module):
"""基于 BERT 的文本分类器"""

def __init__(self, model_name="bert-base-chinese", num_classes=2):
super().__init__()
self.bert = BertModel.from_pretrained(model_name)
self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
# 使用 [CLS] token 的表示进行分类
cls_output = outputs.last_hidden_state[:, 0, :]
logits = self.classifier(cls_output)
return logits

# 使用示例
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = TextClassifier(num_classes=2)

text = "这部电影非常好看"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
logits = model(inputs["input_ids"], inputs["attention_mask"])
prediction = torch.argmax(logits, dim=1).item()

print(f"预测类别: {'正面' if prediction == 1 else '负面'}")

文本生成

from transformers import GPT2LMHeadModel, GPT2Tokenizer

class TextGenerator:
"""基于 GPT 的文本生成器"""

def __init__(self, model_name="gpt2"):
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
self.model = GPT2LMHeadModel.from_pretrained(model_name)

def generate(self, prompt, max_length=100, temperature=0.7):
inputs = self.tokenizer.encode(prompt, return_tensors="pt")

outputs = self.model.generate(
inputs,
max_length=max_length,
temperature=temperature,
do_sample=True,
top_k=50,
top_p=0.95
)

return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# 使用示例
# generator = TextGenerator()
# text = generator.generate("Once upon a time")
# print(text)

总结

注意力机制和 Transformer 架构是现代 NLP 的核心基础,本章介绍了:

  • 注意力机制:让模型动态关注输入的不同部分
  • 多头注意力:并行学习多种相关性模式
  • 自注意力:序列内部的交互机制
  • Transformer 架构:编码器和解码器的完整结构
  • 位置编码:为模型提供位置信息
  • Transformer 变体:BERT、GPT 等模型的设计原理

理解 Transformer 架构对于深入学习现代 NLP 技术至关重要。BERT、GPT、T5 等主流预训练模型都基于 Transformer,掌握其原理有助于更好地使用和改进这些模型。