Transformer 架构原理
理解 Transformer 架构是深入学习预训练模型的基础。本章将详细讲解 Transformer 的核心原理,包括注意力机制、位置编码、多头注意力等关键概念。
Transformer 的诞生背景
在 Transformer 出现之前,序列建模任务主要依赖循环神经网络(RNN)和长短期记忆网络(LSTM)。这些模型存在两个根本性的问题:
顺序计算的瓶颈:RNN 必须按顺序处理输入,无法充分利用并行计算能力。对于长度为 的序列,完成一次前向传播需要 个时间步,且任意位置的信息需要 步才能传播到序列另一端。
长距离依赖的困难:虽然 LSTM 通过门控机制缓解了梯度消失问题,但在处理很长的序列时,信息仍然会逐渐衰减。模型难以学习到序列开头和结尾之间的关系。
2017 年,Google 团队在论文《Attention Is All You Need》中提出了 Transformer 架构。其核心创新在于:完全抛弃了循环和卷积结构,仅依赖注意力机制来建模序列中任意两个位置之间的依赖关系。这使得:
- 每个位置可以同时与其他所有位置直接交互,路径长度降为
- 整个序列可以并行处理,训练速度大幅提升
整体架构
Transformer 采用编码器-解码器(Encoder-Decoder)结构,原始论文中的架构如下:
┌─────────────────────────────────────────────────────────────────┐
│ Transformer 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入嵌入 编码器 解码器 │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌───────┐ ┌─────────┐ ┌─────────┐ │
│ │ Token │ │ Self │ │ Masked │ │
│ │Embedding│────────▶│Attention│────────▶│Self-Attn│ │
│ └───────┘ │ × N层 │ │ × N层 │ │
│ │ └─────────┘ └─────────┘ │
│ ▼ │ │ │
│ ┌───────┐ ▼ ▼ │
│ │位置编码│ ┌─────────┐ ┌─────────┐ │
│ └───────┘ │ FFN │ │ Cross │ │
│ │ × N层 │ │Attention│ │
│ └─────────┘ └─────────┘ │
│ │ │ │
│ ▼ ▼ │
│ 编码器输出 ┌─────────┐ │
│ │ FFN │ │
│ │ × N层 │ │
│ └─────────┘ │
│ │ │
│ ▼ │
│ 线性层 + Softmax │
│ │ │
│ ▼ │
│ 输出概率 │
│ │
└─────────────────────────────────────────────────────────────────┘
编码器接收输入序列,将其转换为连续表示;解码器基于编码器的输出,自回归地生成目标序列。
自注意力机制(Self-Attention)
自注意力是 Transformer 最核心的组件。它允许序列中的每个位置直接关注序列中的所有其他位置,从而捕捉长距离依赖。
注意力的本质
注意力机制可以理解为一种"查询-键-值"(Query-Key-Value)检索系统:
- Query(查询):当前需要关注其他位置的信息
- Key(键):用于与查询匹配,决定关注程度
- Value(值):实际要提取的内容
这个思想来源于信息检索系统。当你在搜索引擎输入查询时,系统将你的查询与文档库中的键(关键词)匹配,返回最相关的值(文档内容)。注意力机制本质上是在学习一种软性的检索,输出是所有值的加权和,权重由查询和键的相似度决定。
缩放点积注意力(Scaled Dot-Product Attention)
Transformer 使用缩放点积注意力,其计算公式为:
其中:
- 是查询矩阵
- 是键矩阵
- 是值矩阵
- 是键的维度
- 是序列长度
计算过程分解如下:
第一步:计算注意力分数
这会得到一个 的矩阵,每个元素 表示位置 对位置 的原始关注度(未归一化的相似度)。
第二步:缩放
为什么需要缩放?当 较大时,点积结果的方差也会很大。假设 和 的元素是均值为 0、方差为 1 的独立随机变量,则点积的均值为 0、方差为 。过大的方差会导致 softmax 的输入分布极端化,进入梯度极小的饱和区。除以 可以将方差归一化到 1,使梯度更稳定。
第三步:归一化
对每一行进行 softmax,将原始分数转换为概率分布。每个位置对所有位置的注意力权重之和为 1。
第四步:加权求和
用注意力权重对值进行加权,得到每个位置的输出表示。
代码实现
import torch
import torch.nn as nn
import math
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力"""
def __init__(self, d_k):
super().__init__()
self.scale = math.sqrt(d_k)
def forward(self, Q, K, V, mask=None):
"""
参数:
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
V: [batch_size, num_heads, seq_len, d_v]
mask: 可选的掩码,用于屏蔽某些位置
"""
# 计算注意力分数
# [batch_size, num_heads, seq_len, seq_len]
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# 应用掩码(如需要)
if mask is not None:
# 将被掩码的位置的分数设为负无穷
scores = scores.masked_fill(mask == 0, float('-inf'))
# softmax 归一化
attention_weights = torch.softmax(scores, dim=-1)
# 加权求和
# [batch_size, num_heads, seq_len, d_v]
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 使用示例
d_k = 64
batch_size = 2
seq_len = 10
num_heads = 8
attention = ScaledDotProductAttention(d_k)
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)
output, weights = attention(Q, K, V)
print(f"输出形状: {output.shape}") # [2, 8, 10, 64]
print(f"注意力权重形状: {weights.shape}") # [2, 8, 10, 10]
自注意力的直观理解
自注意力让每个位置的 token 能够"看到"整个序列,并学习如何从其他 token 获取信息。考虑以下句子:
"The animal didn't cross the street because it was too tired"
当处理 "it" 这个词时,自注意力机制需要确定 "it" 指代的是什么。通过计算 "it" 与其他词的注意力权重,模型可以发现 "it" 与 "animal" 的关联最强,从而正确理解指代关系。
不同的注意力头可能会学习到不同类型的依赖关系:有的关注语法结构,有的关注语义关联,有的关注长距离依赖。
多头注意力(Multi-Head Attention)
单头注意力将所有信息压缩到单一的表示空间中,可能丢失信息。多头注意力通过并行运行多个独立的注意力头,让模型从多个表示子空间同时学习不同类型的依赖关系。
计算公式
其中每个头的计算为:
投影矩阵 、、 是每个头独有的参数, 是输出的投影矩阵。
原论文中,,使用 个注意力头,每个头的维度 。这样设计使得多头注意力的总计算量与单头注意力相近。
代码实现
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
"""多头注意力机制"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Q, K, V 的线性投影
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
# 输出投影
self.W_o = nn.Linear(d_model, d_model)
self.scale = math.sqrt(self.d_k)
def forward(self, query, key, value, mask=None):
"""
参数:
query: [batch_size, seq_len, d_model]
key: [batch_size, seq_len, d_model]
value: [batch_size, seq_len, d_model]
mask: 可选的掩码
"""
batch_size = query.size(0)
# 线性投影并分割为多头
# [batch_size, seq_len, d_model] -> [batch_size, seq_len, num_heads, d_k]
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力分数
# [batch_size, num_heads, seq_len, seq_len]
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# softmax 归一化
attention_weights = torch.softmax(scores, dim=-1)
# 加权求和
# [batch_size, num_heads, seq_len, d_k]
context = torch.matmul(attention_weights, V)
# 拼接多头
# [batch_size, num_heads, seq_len, d_k] -> [batch_size, seq_len, d_model]
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 输出投影
output = self.W_o(context)
return output, attention_weights
# 使用示例
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(batch_size, seq_len, d_model)
# 自注意力:Q, K, V 都是同一个输入
output, weights = mha(x, x, x)
print(f"输出形状: {output.shape}") # [2, 10, 512]
多头注意力的优势
多头注意力让每个头专注于不同类型的依赖关系。例如:
- 头 1:可能学习关注相邻位置的局部依赖(如 "new" 和 "york")
- 头 2:可能学习关注长距离的语义依赖(如 "it" 和指代的实体)
- 头 3:可能学习关注句法结构(如主语和谓语的关系)
这些不同视角的表示最后通过输出投影融合,形成更丰富的语义表示。
位置编码(Positional Encoding)
Transformer 完全基于注意力机制,本身不包含任何循环或卷积结构。这意味着模型对输入序列的顺序是不敏感的——如果你打乱输入 token 的顺序,除了位置编码外,模型的输出不会发生本质变化。
为了让模型感知序列中 token 的位置信息,Transformer 在输入嵌入中加入位置编码。原论文使用正弦和余弦函数生成固定的位置编码。
正弦位置编码公式
对于位置 和维度 :
其中:
- 是 token 在序列中的位置(从 0 开始)
- 是嵌入维度的索引
- 是模型的嵌入维度
这个设计的精妙之处在于:对于任意固定偏移量 , 可以表示为 的线性函数。这意味着模型可以通过学习相对位置关系来泛化到训练时未见过的序列长度。
代码实现
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
"""正弦位置编码"""
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算分母项
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# 偶数维度使用 sin,奇数维度使用 cos
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# 添加 batch 维度: [1, max_len, d_model]
pe = pe.unsqueeze(0)
# 注册为缓冲区(不参与梯度更新)
self.register_buffer('pe', pe)
def forward(self, x):
"""
参数:
x: [batch_size, seq_len, d_model]
"""
# 将位置编码加到输入嵌入上
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# 使用示例
d_model = 512
max_len = 100
pos_encoder = PositionalEncoding(d_model, max_len)
# 假设输入嵌入
batch_size = 2
seq_len = 20
x = torch.randn(batch_size, seq_len, d_model)
output = pos_encoder(x)
print(f"输出形状: {output.shape}") # [2, 20, 512]
其他位置编码方式
除了正弦位置编码,还有多种其他方案:
可学习位置编码:将位置编码作为可训练参数。BERT、GPT 等模型采用这种方式。优点是可以根据任务学习最优的位置表示,缺点是无法泛化到训练时未见过的序列长度。
旋转位置编码(RoPE):LLaMA、Mistral 等现代大语言模型采用。通过将位置信息编码为旋转矩阵,使模型能够自然地学习相对位置关系,且具有良好的外推能力。
相对位置编码:直接编码两个位置之间的相对距离,而非绝对位置。这种方法对序列长度的泛化能力更强。
前馈神经网络(Feed-Forward Network)
每个 Transformer 层除了注意力子层外,还包含一个前馈神经网络(FFN)。这是一个位置独立的两层全连接网络,对每个位置独立应用。
计算公式
其中 ,。原论文中 ,是 的 4 倍。
FFN 的作用是为模型增加非线性变换能力。注意力机制本质上是线性操作(加权求和),FFN 引入了必要的非线性,使模型能够学习更复杂的函数。
代码实现
import torch.nn as nn
class PositionWiseFFN(nn.Module):
"""位置独立的前馈神经网络"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
def forward(self, x):
"""
参数:
x: [batch_size, seq_len, d_model]
"""
# [batch_size, seq_len, d_ff]
x = self.fc1(x)
x = self.activation(x)
x = self.dropout(x)
# [batch_size, seq_len, d_model]
x = self.fc2(x)
return x
# 使用示例
d_model = 512
d_ff = 2048
ffn = PositionWiseFFN(d_model, d_ff)
x = torch.randn(2, 10, d_model)
output = ffn(x)
print(f"输出形状: {output.shape}") # [2, 10, 512]
现代变体中,激活函数常用 GELU 或 Swish 替代 ReLU,它们在平滑性和梯度传播上有一定优势。
层归一化与残差连接
Transformer 的每个子层(注意力和 FFN)都采用残差连接后接层归一化的结构:
残差连接的作用
残差连接让梯度可以直接流过恒等映射,缓解了深层网络的梯度消失问题。在 Transformer 中,每个子层的输入可以直接"跳过"子层到达输出,这有助于训练稳定性。
层归一化的作用
层归一化对每个样本的特征维度进行归一化,使输出的均值为 0、方差为 1。与批归一化不同,层归一化不依赖于 batch 统计量,因此对 batch 大小不敏感,更适合处理变长序列。
import torch.nn as nn
class LayerNorm(nn.Module):
"""层归一化"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
# 在最后一个维度上计算均值和方差
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
# 归一化
x_norm = (x - mean) / (std + self.eps)
# 缩放和平移
return self.gamma * x_norm + self.beta
Pre-LN vs Post-LN
原论文使用的是 Post-LN(归一化在残差连接之后),但后续研究发现 Pre-LN(归一化在子层之前)对深层网络的训练更稳定:
Post-LN: Output = LayerNorm(x + Sublayer(x))
Pre-LN: Output = x + Sublayer(LayerNorm(x))
现代大语言模型(如 GPT-2/3、LLaMA)普遍采用 Pre-LN 结构。
编码器结构
编码器由 个相同的层堆叠而成(原论文 )。每层包含两个子层:
- 多头自注意力层
- 前馈神经网络层
每个子层都使用残差连接和层归一化。
class EncoderLayer(nn.Module):
"""Transformer 编码器层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 多头自注意力
self.self_attn = MultiHeadAttention(d_model, num_heads)
# 前馈网络
self.ffn = PositionWiseFFN(d_model, d_ff, dropout)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力子层 + 残差连接 + 层归一化
attn_output, _ = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# FFN 子层 + 残差连接 + 层归一化
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
解码器结构
解码器也由 个相同的层堆叠而成。每层包含三个子层:
- 掩码自注意力层:使用因果掩码确保每个位置只能关注它之前的位置
- 编码器-解码器注意力层:Query 来自解码器,Key 和 Value 来自编码器输出
- 前馈神经网络层
因果掩码是解码器的关键设计。在生成任务中,解码器需要自回归地逐个生成 token,不能"偷看"未来的信息。
class DecoderLayer(nn.Module):
"""Transformer 解码器层"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
# 掩码自注意力
self.self_attn = MultiHeadAttention(d_model, num_heads)
# 编码器-解码器注意力
self.cross_attn = MultiHeadAttention(d_model, num_heads)
# 前馈网络
self.ffn = PositionWiseFFN(d_model, d_ff, dropout)
# 层归一化
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
# 掩码自注意力
attn_output, _ = self.self_attn(x, x, x, self_mask)
x = self.norm1(x + self.dropout(attn_output))
# 编码器-解码器注意力
attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, cross_mask)
x = self.norm2(x + self.dropout(attn_output))
# FFN
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_output))
return x
def generate_causal_mask(seq_len):
"""生成因果掩码"""
# 上三角矩阵(不包括对角线)为 0,其余为 1
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
三种注意力应用
Transformer 中,多头注意力以三种不同的方式使用:
1. 编码器自注意力
在编码器中,Q、K、V 都来自同一个输入序列。每个位置可以关注输入序列的所有位置,实现双向上下文理解。BERT 等仅编码器模型就是基于这种设计。
2. 解码器掩码自注意力
在解码器中,自注意力被因果掩码限制:每个位置只能关注它之前的位置。这保证了解码的自回归性质——生成第 个 token 时,模型不能看到第 及之后的 token。GPT 等仅解码器模型基于这种设计。
3. 编码器-解码器注意力
这是连接编码器和解码器的桥梁。Query 来自解码器,Key 和 Value 来自编码器的输出。这允许解码器的每个位置关注输入序列的所有位置,实现跨序列的信息流动。这在翻译、摘要等序列到序列任务中至关重要。
架构变体
基于原始 Transformer 的编码器-解码器结构,衍生出了多种架构变体:
Encoder-only(仅编码器)
代表模型:BERT、RoBERTa、ALBERT
特点:
- 只有编码器部分
- 双向注意力,可以看到整个序列
- 适合理解任务:分类、标注、抽取式问答
Decoder-only(仅解码器)
代表模型:GPT 系列、LLaMA、Mistral
特点:
- 只有解码器部分(移除编码器-解码器注意力)
- 单向因果注意力,自回归生成
- 适合生成任务:文本生成、对话、代码补全
Encoder-Decoder(编码器-解码器)
代表模型:T5、BART、Marian
特点:
- 完整的 Transformer 结构
- 编码器处理输入,解码器自回归生成输出
- 适合序列转换任务:翻译、摘要、文本改写
计算复杂度分析
比较不同层类型的计算复杂度:
| 层类型 | 每层复杂度 | 顺序操作数 | 最大路径长度 |
|---|---|---|---|
| 自注意力 | |||
| 循环层 | |||
| 卷积层 |
其中 是序列长度, 是表示维度, 是卷积核大小。
自注意力的优势在于:
- 序列操作数最少,完全可并行
- 任意两个位置之间的路径长度为常数
自注意力的劣势在于:
- 对序列长度的复杂度是平方级的,处理长序列时内存和计算开销大
这也是后续涌现出众多长序列优化方法(如稀疏注意力、线性注意力、FlashAttention)的原因。
小结
Transformer 架构的核心要点:
- 自注意力机制:让序列中每个位置直接与所有其他位置交互,解决长距离依赖问题
- 多头注意力:从多个表示子空间学习不同类型的依赖关系
- 位置编码:为模型提供位置信息,弥补注意力机制对顺序不敏感的特性
- 残差连接与层归一化:稳定深层网络的训练
- 编码器-解码器结构:编码器提取输入表示,解码器自回归生成输出
理解这些核心概念后,学习 BERT、GPT、LLaMA 等具体模型就变得水到渠成。这些模型本质上都是 Transformer 架构的不同变体,在注意力机制、位置编码、归一化位置等细节上各有创新。
参考资源
- Attention Is All You Need - 原始 Transformer 论文
- The Illustrated Transformer - Transformer 可视化讲解
- Annotated Transformer - 带详细注释的 PyTorch 实现