跳到主要内容

显存优化技术

GPU 显存是训练和部署大模型的核心瓶颈。理解显存的组成和优化方法,对于高效利用硬件资源至关重要。本章将系统介绍显存优化核心技术,帮助你在有限的显存中训练更大的模型。

显存消耗分析

显存组成部分

训练过程中,GPU 显存主要消耗在以下几个方面:

┌─────────────────────────────────────────────────────────┐
│ GPU 显存组成 │
├─────────────────────────────────────────────────────────┤
│ 模型参数 (Parameters) │
│ - FP16/BF16 存储 │
│ - 与参数量成正比 │
├─────────────────────────────────────────────────────────┤
│ 梯度 (Gradients) │
│ - 与模型参数大小相同 │
│ - 反向传播时存储 │
├─────────────────────────────────────────────────────────┤
│ 优化器状态 (Optimizer States) │
│ - AdamW: 参数副本(FP32) + 动量 + 方差 │
│ - 约为参数大小的 6 倍 │
├─────────────────────────────────────────────────────────┤
│ 激活值 (Activations) │
│ - 前向传播中间结果 │
│ - 与批次大小、序列长度相关 │
├─────────────────────────────────────────────────────────┤
│ 临时缓存 (Temporary Buffers) │
│ - CUDA 核函数临时存储 │
│ - 通信缓冲区 │
└─────────────────────────────────────────────────────────┘

显存计算公式

对于参数量为 PP 的模型,使用 AdamW 优化器和混合精度训练:

组件显存占用(FP16 训练)说明
模型参数2P2PFP16 存储
梯度2P2PFP16 存储
优化器状态12P12PFP32 参数副本 + 动量 + 方差
激活值变化与批次大小和序列长度相关

总计约 16P16P 字节,意味着:

  • 7B 模型:约 112 GB
  • 13B 模型:约 208 GB
  • 70B 模型:约 1.12 TB

实际案例:Llama-2-7B 显存分析

假设在 A100 80GB 上训练 Llama-2-7B:

组件显存占用比例
模型参数14 GB12.5%
梯度14 GB12.5%
优化器状态84 GB75%
激活值~10-30 GB变化
总计122-142 GB-

显然,单卡 A100 80GB 无法直接训练 7B 模型,需要使用显存优化技术。

混合精度训练

混合精度训练通过使用 FP16/BF16 进行计算,同时保留 FP32 精度的优化器状态,在减少显存的同时保证训练稳定性。

原理

┌─────────────────────────────────────────────────────────┐
│ 混合精度训练流程 │
├─────────────────────────────────────────────────────────┤
│ 前向传播:FP16 计算 │
│ ↓ │
│ 损失计算:FP32 精度(防止下溢) │
│ ↓ │
│ 损失缩放:放大损失值 │
│ ↓ │
│ 反向传播:FP16 计算梯度 │
│ ↓ │
│ 梯度反缩放:恢复原始梯度值 │
│ ↓ │
│ 参数更新:FP32 优化器状态 │
└─────────────────────────────────────────────────────────┘

PyTorch 自动混合精度

import torch
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler() # 梯度缩放器

for batch in dataloader:
optimizer.zero_grad()

# 前向传播使用自动混合精度
with autocast():
output = model(batch)
loss = criterion(output, target)

# 反向传播
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

BF16 vs FP16

BF16(Brain Float 16)是另一种 16 位浮点格式,相比 FP16 有更大的动态范围:

特性FP16BF16FP32
位数161632
指数位588
尾数位10723
动态范围2142^{-14}2152^{15}与 FP32 相同21262^{-126}21272^{127}
精度较高较低最高

BF16 的优势

  • 动态范围与 FP32 相同,避免数值溢出/下溢
  • 不需要损失缩放
  • 训练更稳定

使用 BF16

# PyTorch 1.10+ 支持 BF16
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(batch)
loss = criterion(output, target)

# BF16 不需要 GradScaler
loss.backward()
optimizer.step()

硬件要求:BF16 需要 Ampere 架构及以上的 GPU(A100、RTX 3090/4090、H100)。

梯度检查点

梯度检查点(Gradient Checkpointing)是一种以计算换内存的技术,通过在反向传播时重新计算中间激活值来减少显存占用。

原理

标准训练:
前向传播:计算并存储所有激活值
反向传播:读取存储的激活值计算梯度
显存占用:O(层数 × 批次大小 × 序列长度)

梯度检查点:
前向传播:只存储部分检查点的激活值
反向传播:从检查点重新计算需要的激活值
显存占用:O(√(层数) × 批次大小 × 序列长度)

PyTorch 使用方法

from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
def __init__(self, use_checkpoint=False):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attention = MultiHeadAttention()
self.mlp = MLP()

def forward(self, x):
if self.use_checkpoint and self.training:
# 使用梯度检查点
x = checkpoint(self._forward, x, use_reentrant=False)
else:
x = self._forward(x)
return x

def _forward(self, x):
x = x + self.attention(x)
x = x + self.mlp(x)
return x

PyTorch 2.0+ 简化使用

# 设置整个模型使用梯度检查点
model.gradient_checkpointing_enable()

# 或在 Hugging Face 模型中
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
torch_dtype=torch.float16,
device_map="auto"
)
model.gradient_checkpointing_enable()

显存-计算权衡

配置显存占用计算开销适用场景
无检查点100%基准显存充足
全检查点30-40%+33%显存紧张
选择性检查点50-60%+15%平衡选择

最佳实践

  • 对显存占用大的层(如 Attention)启用检查点
  • 对计算密集的层(如 LayerNorm)不启用检查点

ZeRO 优化

ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 提供的显存优化技术,通过分片消除数据并行中的冗余。

三阶段优化

┌─────────────────────────────────────────────────────────┐
│ ZeRO 优化阶段 │
├─────────────────────────────────────────────────────────┤
│ ZeRO-1:分片优化器状态 │
│ - 每个 GPU 只存储 1/N 的优化器状态 │
│ - 显存节省:4× │
├─────────────────────────────────────────────────────────┤
│ ZeRO-2:分片优化器状态 + 梯度 │
│ - 每个 GPU 只存储对应的梯度 │
│ - 显存节省:8× │
├─────────────────────────────────────────────────────────┤
│ ZeRO-3:分片优化器状态 + 梯度 + 参数 │
│ - 每个 GPU 只存储部分模型参数 │
│ - 显存节省:与 GPU 数量成正比 │
└─────────────────────────────────────────────────────────┘

显存节省效果

以 7B 模型为例,不同 GPU 数量下的显存需求:

配置单卡显存8 卡总显存
标准数据并行112 GB896 GB
ZeRO-128 GB224 GB
ZeRO-216 GB128 GB
ZeRO-32 GB16 GB

DeepSpeed 配置

ZeRO-2 配置(推荐用于中等规模模型)

{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"reduce_scatter": true,
"contiguous_gradients": true,
"overlap_comm": true
},
"gradient_accumulation_steps": 1,
"train_micro_batch_size_per_gpu": 4,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}

ZeRO-3 配置(超大模型)

{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
},
"offload_param": {
"device": "cpu"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1e6,
"prefetch_bucket_size": 1e6
}
}

ZeRO-Infinity

ZeRO-Infinity 是 ZeRO-3 的扩展,支持将模型状态卸载到 CPU 内存甚至 NVMe SSD:

{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "nvme",
"nvme_path": "/local_nvme"
},
"offload_param": {
"device": "nvme",
"nvme_path": "/local_nvme"
}
}
}

ZeRO-Infinity 理论上可以训练无限大小的模型,只要总存储空间足够。

PyTorch FSDP

PyTorch FSDP(Fully Sharded Data Parallel)是 PyTorch 原生的分片数据并行实现,功能类似 ZeRO-3。

基本使用

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# 初始化分布式环境
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 创建模型
model = MyModel()

# 包装为 FSDP
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 类似 ZeRO-3
device_id=torch.cuda.current_device()
)

分片策略

策略说明类似
FULL_SHARD分片参数、梯度、优化器状态ZeRO-3
SHARD_GRAD_OP分片梯度和优化器状态ZeRO-2
NO_SHARD不分片标准 DDP
HYBRID_SHARD节点内分片 + 节点间复制混合并行

混合精度配置

from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16
)

model = FSDP(
model,
mixed_precision=mp_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD
)

自动包装策略

对于 Transformer 模型,推荐按层包装以平衡通信开销:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import transformers

auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={
transformers.models.llama.modeling_llama.LlamaDecoderLayer
}
)

model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD
)

激活值优化

除了上述技术,还有一些针对激活值的优化方法。

序列并行

对于长序列训练,激活值随序列长度线性增长。序列并行将序列维度切分到多个 GPU:

# Ring Attention 实现序列并行
# 将长序列分割到多个 GPU 上处理
# 每个 GPU 只存储部分序列的 KV Cache

选择性激活重计算

只对显存占用大的激活值进行重计算:

from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
def forward(self, x):
# Attention 激活值大,使用检查点
x = x + checkpoint(self.attention, x, use_reentrant=False)
# MLP 激活值相对小,不使用检查点
x = x + self.mlp(x)
return x

显存监控与调试

监控工具

import torch

# 当前显存使用
print(f"已分配: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"已预留: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
print(f"峰值: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# 重置峰值统计
torch.cuda.reset_peak_memory_stats()

显存分析

import torch.cuda.memory as memory

# 详细显存分析
snapshot = torch.cuda.memory._record_memory_history()
# ... 运行模型 ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

常见显存问题

问题原因解决方案
OOM at init模型太大使用 FSDP/ZeRO-3
OOM at forward激活值太大梯度检查点、减小批次
OOM at backward梯度累积梯度累积、减小批次
显存碎片化频繁分配释放清空缓存、预分配

显存优化清单

□ 检查模型是否可放入单卡
├─ 是 → 考虑数据并行
└─ 否 → 考虑模型并行或 ZeRO

□ 启用混合精度训练
├─ A100/H100 → 使用 BF16
└─ 其他 GPU → 使用 FP16 + GradScaler

□ 评估激活值大小
├─ 批次大/序列长 → 启用梯度检查点
└─ 批次小/序列短 → 可能不需要

□ 优化数据加载
├─ 启用 pin_memory
├─ 调整 num_workers
└─ 考虑预取

□ 监控显存使用
├─ 使用 nvidia-smi
├─ 使用 torch.cuda.memory
└─ 分析内存快照

小结

显存优化是大模型训练和部署的核心技能。本章介绍了:

  1. 显存分析:理解显存组成和计算方法
  2. 混合精度训练:FP16/BF16 减少显存占用
  3. 梯度检查点:以计算换内存
  4. ZeRO 优化:分片消除冗余
  5. PyTorch FSDP:原生分片数据并行

选择合适的优化策略组合,可以在有限的显存中训练更大的模型。实际应用中,建议从混合精度训练开始,逐步添加梯度检查点和 ZeRO 优化。

参考资料

官方文档

技术论文