显存优化技术
GPU 显存是训练和部署大模型的核心瓶颈。理解显存的组成和优化方法,对于高效利用硬件资源至关重要。本章将系统介绍显存优化核心技术,帮助你在有限的显存中训练更大的模型。
显存消耗分析
显存组成部分
训练过程中,GPU 显存主要消耗在以下几个方面:
┌─────────────────────────────────────────────────────────┐
│ GPU 显存组成 │
├─────────────────────────────────────────────────────────┤
│ 模型参数 (Parameters) │
│ - FP16/BF16 存储 │
│ - 与参数量成正比 │
├─────────────────────────────────────────────────────────┤
│ 梯度 (Gradients) │
│ - 与模型参数大小相同 │
│ - 反向传播时存储 │
├─────────────────────────────────────────────────────────┤
│ 优化器状态 (Optimizer States) │
│ - AdamW: 参数副本(FP32) + 动量 + 方差 │
│ - 约为参数大小的 6 倍 │
├─────────────────────────────────────────────────────────┤
│ 激活值 (Activations) │
│ - 前向传播中间结果 │
│ - 与批次大小、序列长度相关 │
├─────────────────────────────────────────────────────────┤
│ 临时缓存 (Temporary Buffers) │
│ - CUDA 核函数临时存储 │
│ - 通信缓冲区 │
└─────────────────────────────────────────────────────────┘
显存计算公式
对于参数量为 的模型,使用 AdamW 优化器和混合精度训练:
| 组件 | 显存占用(FP16 训练) | 说明 |
|---|---|---|
| 模型参数 | FP16 存储 | |
| 梯度 | FP16 存储 | |
| 优化器状态 | FP32 参数副本 + 动量 + 方差 | |
| 激活值 | 变化 | 与批次大小和序列长度相关 |
总计约 字节,意味着:
- 7B 模型:约 112 GB
- 13B 模型:约 208 GB
- 70B 模型:约 1.12 TB
实际案例:Llama-2-7B 显存分析
假设在 A100 80GB 上训练 Llama-2-7B:
| 组件 | 显存占用 | 比例 |
|---|---|---|
| 模型参数 | 14 GB | 12.5% |
| 梯度 | 14 GB | 12.5% |
| 优化器状态 | 84 GB | 75% |
| 激活值 | ~10-30 GB | 变化 |
| 总计 | 122-142 GB | - |
显然,单卡 A100 80GB 无法直接训练 7B 模型,需要使用显存优化技术。
混合精度训练
混合精度训练通过使用 FP16/BF16 进行计算,同时保留 FP32 精度的优化器状态,在减少显存的同时保证训练稳定性。
原理
┌─────────────────────────────────────────────────────────┐
│ 混合精度训练流程 │
├─────────────────────────────────────────────────────────┤
│ 前向传播:FP16 计算 │
│ ↓ │
│ 损失计算:FP32 精度(防止下溢) │
│ ↓ │
│ 损失缩放:放大损失值 │
│ ↓ │
│ 反向传播:FP16 计算梯度 │
│ ↓ │
│ 梯度反缩放:恢复原始梯度值 │
│ ↓ │
│ 参数更新:FP32 优化器状态 │
└─────────────────────────────────────────────────────────┘
PyTorch 自动混合精度
import torch
from torch.cuda.amp import autocast, GradScaler
model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler() # 梯度缩放器
for batch in dataloader:
optimizer.zero_grad()
# 前向传播使用自动混合精度
with autocast():
output = model(batch)
loss = criterion(output, target)
# 反向传播
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
BF16 vs FP16
BF16(Brain Float 16)是另一种 16 位浮点格式,相比 FP16 有更大的动态范围:
| 特性 | FP16 | BF16 | FP32 |
|---|---|---|---|
| 位数 | 16 | 16 | 32 |
| 指数位 | 5 | 8 | 8 |
| 尾数位 | 10 | 7 | 23 |
| 动态范围 | 到 | 与 FP32 相同 | 到 |
| 精度 | 较高 | 较低 | 最高 |
BF16 的优势:
- 动态范围与 FP32 相同,避免数值溢出/下溢
- 不需要损失缩放
- 训练更稳定
使用 BF16:
# PyTorch 1.10+ 支持 BF16
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(batch)
loss = criterion(output, target)
# BF16 不需要 GradScaler
loss.backward()
optimizer.step()
硬件要求:BF16 需要 Ampere 架构及以上的 GPU(A100、RTX 3090/4090、H100)。
梯度检查点
梯度检查点(Gradient Checkpointing)是一种以计算换内存的技术,通过在反向传播时重新计算中间激活值来减少显存占用。
原理
标准训练:
前向传播:计算并存储所有激活值
反向传播:读取存储的激活值计算梯度
显存占用:O(层数 × 批次大小 × 序列长度)
梯度检查点:
前向传播:只存储部分检查点的激活值
反向传播:从检查点重新计算需要的激活值
显存占用:O(√(层数) × 批次大小 × 序列长度)
PyTorch 使用方法
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def __init__(self, use_checkpoint=False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attention = MultiHeadAttention()
self.mlp = MLP()
def forward(self, x):
if self.use_checkpoint and self.training:
# 使用梯度检查点
x = checkpoint(self._forward, x, use_reentrant=False)
else:
x = self._forward(x)
return x
def _forward(self, x):
x = x + self.attention(x)
x = x + self.mlp(x)
return x
PyTorch 2.0+ 简化使用
# 设置整个模型使用梯度检查点
model.gradient_checkpointing_enable()
# 或在 Hugging Face 模型中
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto"
)
model.gradient_checkpointing_enable()
显存-计算权衡
| 配置 | 显存占用 | 计算开销 | 适用场景 |
|---|---|---|---|
| 无检查点 | 100% | 基准 | 显存充足 |
| 全检查点 | 30-40% | +33% | 显存紧张 |
| 选择性检查点 | 50-60% | +15% | 平衡选择 |
最佳实践:
- 对显存占用大的层(如 Attention)启用检查点
- 对计算密集的层(如 LayerNorm)不启用检查点
ZeRO 优化
ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 提供的显存优化技术,通过分片消除数据并行中的冗余。
三阶段优化
┌─────────────────────────────────────────────────────────┐
│ ZeRO 优化阶段 │
├─────────────────────────────────────────────────────────┤
│ ZeRO-1:分片优化器状态 │
│ - 每个 GPU 只存储 1/N 的优化器状态 │
│ - 显存节省:4× │
├─────────────────────────────────────────────────────────┤
│ ZeRO-2:分片优化器状态 + 梯度 │
│ - 每个 GPU 只存储对应的梯度 │
│ - 显存节省:8× │
├─────────────────────────────────────────────────────────┤
│ ZeRO-3:分片优化器状态 + 梯度 + 参数 │
│ - 每个 GPU 只存储部分模型参数 │
│ - 显存节省:与 GPU 数量成正比 │
└─────────────────────────────────────────────────────────┘
显存节省效果
以 7B 模型为例,不同 GPU 数量下的显存需求:
| 配置 | 单卡显存 | 8 卡总显存 |
|---|---|---|
| 标准数据并行 | 112 GB | 896 GB |
| ZeRO-1 | 28 GB | 224 GB |
| ZeRO-2 | 16 GB | 128 GB |
| ZeRO-3 | 2 GB | 16 GB |
DeepSpeed 配置
ZeRO-2 配置(推荐用于中等规模模型):
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"reduce_scatter": true,
"contiguous_gradients": true,
"overlap_comm": true
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 4,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}
ZeRO-3 配置(超大模型):
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
},
"offload_param": {
"device": "cpu"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1e6,
"prefetch_bucket_size": 1e6
}
}
ZeRO-Infinity
ZeRO-Infinity 是 ZeRO-3 的扩展,支持将模型状态卸载到 CPU 内存甚至 NVMe SSD:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "nvme",
"nvme_path": "/local_nvme"
},
"offload_param": {
"device": "nvme",
"nvme_path": "/local_nvme"
}
}
}
ZeRO-Infinity 理论上可以训练无限大小的模型,只要总存储空间足够。
PyTorch FSDP
PyTorch FSDP(Fully Sharded Data Parallel)是 PyTorch 原生的分片数据并行实现,功能类似 ZeRO-3。
基本使用
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
# 初始化分布式环境
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# 创建模型
model = MyModel()
# 包装为 FSDP
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 类似 ZeRO-3
device_id=torch.cuda.current_device()
)
分片策略
| 策略 | 说明 | 类似 |
|---|---|---|
FULL_SHARD | 分片参数、梯度、优化器状态 | ZeRO-3 |
SHARD_GRAD_OP | 分片梯度和优化器状态 | ZeRO-2 |
NO_SHARD | 不分片 | 标准 DDP |
HYBRID_SHARD | 节点内分片 + 节点间复制 | 混合并行 |
混合精度配置
from torch.distributed.fsdp import MixedPrecision
mp_policy = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16
)
model = FSDP(
model,
mixed_precision=mp_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD
)
自动包装策略
对于 Transformer 模型,推荐按层包装以平衡通信开销:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import transformers
auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={
transformers.models.llama.modeling_llama.LlamaDecoderLayer
}
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD
)
激活值优化
除了上述技术,还有一些针对激活值的优化方法。
序列并行
对于长序列训练,激活值随序列长度线性增长。序列并行将序列维度切分到多个 GPU:
# Ring Attention 实现序列并行
# 将长序列分割到多个 GPU 上处理
# 每个 GPU 只存储部分序列的 KV Cache
选择性激活重计算
只对显存占用大的激活值进行重计算:
from torch.utils.checkpoint import checkpoint
class TransformerBlock(nn.Module):
def forward(self, x):
# Attention 激活值大,使用检查点
x = x + checkpoint(self.attention, x, use_reentrant=False)
# MLP 激活值相对小,不使用检查点
x = x + self.mlp(x)
return x
显存监控与调试
监控工具
import torch
# 当前显存使用
print(f"已分配: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"已预留: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
print(f"峰值: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
# 重置峰值统计
torch.cuda.reset_peak_memory_stats()
显存分析
import torch.cuda.memory as memory
# 详细显存分析
snapshot = torch.cuda.memory._record_memory_history()
# ... 运行模型 ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
常见显存问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| OOM at init | 模型太大 | 使用 FSDP/ZeRO-3 |
| OOM at forward | 激活值太大 | 梯度检查点、减小批次 |
| OOM at backward | 梯度累积 | 梯度累积、减小批次 |
| 显存碎片化 | 频繁分配释放 | 清空缓存、预分配 |
显存优化清单
□ 检查模型是否可放入单卡
├─ 是 → 考虑数据并行
└─ 否 → 考虑模型并行或 ZeRO
□ 启用混合精度训练
├─ A100/H100 → 使用 BF16
└─ 其他 GPU → 使用 FP16 + GradScaler
□ 评估激活值大小
├─ 批次大/序列长 → 启用梯度检查点
└─ 批次小/序列短 → 可能不需要
□ 优化数据加载
├─ 启用 pin_memory
├─ 调整 num_workers
└─ 考虑预取
□ 监控显存使用
├─ 使用 nvidia-smi
├─ 使用 torch.cuda.memory
└─ 分析内存快照
小结
显存优化是大模型训练和部署的核心技能。本章介绍了:
- 显存分析:理解显存组成和计算方法
- 混合精度训练:FP16/BF16 减少显存占用
- 梯度检查点:以计算换内存
- ZeRO 优化:分片消除冗余
- PyTorch FSDP:原生分片数据并行
选择合适的优化策略组合,可以在有限的显存中训练更大的模型。实际应用中,建议从混合精度训练开始,逐步添加梯度检查点和 ZeRO 优化。
参考资料
官方文档
- DeepSpeed ZeRO 文档 - ZeRO 配置详解
- PyTorch FSDP 教程 - FSDP 使用指南
- PyTorch AMP 文档 - 自动混合精度
技术论文
- ZeRO 论文 - DeepSpeed 显存优化
- Flash Attention 论文 - 高效注意力计算
- PyTorch FSDP 论文 - 分片数据并行