跳到主要内容

分布式训练

随着模型规模的不断增长,单机训练已经无法满足需求。分布式训练通过多机多卡的协同工作,实现了大规模模型的高效训练。本章将系统介绍分布式训练的核心概念、常用框架和实践技巧。

为什么需要分布式训练

模型规模的增长

近年来,模型参数量呈指数级增长:

模型参数量发布年份
BERT-Large3.4亿2018
GPT-215亿2019
GPT-31750亿2020
LLaMA-2 70B700亿2023
GPT-4约1.8万亿2023

一个 1750 亿参数的模型,仅模型参数就需要约 700GB 显存(FP32),远超单卡容量。

训练时间的需求

即使模型可以放入单卡,训练时间也可能难以接受。以 GPT-3 为例:

  • 单卡训练时间:约 355 年
  • 使用 1024 张 V100:约 1 个月

分布式训练可以将训练时间从年缩短到周甚至天。

显存需求分析

理解显存消耗对于选择正确的分布式策略至关重要:

模型显存需求(FP16 训练)

模型大小参数显存梯度优化器状态激活值总计(约)
7B14GB14GB84GB10-30GB120GB+
13B26GB26GB156GB20-50GB230GB+
70B140GB140GB840GB50-100GB1.2TB+

优化器状态包括 FP32 参数副本、动量和方差,通常占参数大小的 6 倍(AdamW)。这就是为什么即使模型参数能放入显存,训练仍需要更多内存。

分布式训练策略

数据并行

数据并行是最常用的分布式训练策略。

工作原理

  1. 每个 GPU 持有完整的模型副本
  2. 将批次数据分割到各个 GPU
  3. 各 GPU 独立进行前向和反向传播
  4. 同步梯度,更新模型参数
┌─────────────────────────────────────────────────────┐
│ 数据并行 │
│ │
│ GPU 0 GPU 1 GPU 2 GPU 3 │
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐│
│ │模型 │ │模型 │ │模型 │ │模型 ││
│ │副本 │ │副本 │ │副本 │ │副本 ││
│ └─────┘ └─────┘ └─────┘ └─────┘│
│ ↓ ↓ ↓ ↓ │
│ 数据 0 数据 1 数据 2 数据 3 │
│ ↓ ↓ ↓ ↓ │
│ 梯度 0 梯度 1 梯度 2 梯度 3 │
│ └──────────────┴──────────────┴──────────────┘ │
│ ↓ │
│ 梯度同步 │
│ ↓ │
│ 参数更新 │
└─────────────────────────────────────────────────────┘

优点

  • 实现简单,代码改动小
  • 适用于模型可以放入单卡显存的场景

缺点

  • 每个GPU需要存储完整模型,显存利用率低
  • 通信开销随 GPU 数量增加而增大

适用场景判断

模型大小GPU 显存推荐策略
< 1B24GB+数据并行
1B - 7B80GB数据并行 + ZeRO
7B - 30B80GB × 8数据并行 + ZeRO-2/3
> 30B需要多机混合并行

模型并行

当模型太大无法放入单卡时,需要使用模型并行。

流水线并行

将模型按层切分到不同 GPU:

GPU 0        GPU 1        GPU 2        GPU 3
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│层 0-5│ ──→ │层 6-11│ ──→ │层 12-17│ ──→ │层 18-23│
└─────┘ └─────┘ └─────┘ └─────┘

特点

  • 各 GPU 顺序处理,存在"气泡"时间
  • 需要精心设计 micro-batch 来减少气泡

气泡问题与解决

简单的流水线会产生大量空闲时间:

简单流水线:
GPU 0: [F0]----[F1]----[F2]----[F3]
GPU 1: [等待] [F0]----[F1]----[F2]----[F3]
GPU 2: [等待] [等待] [F0]----[F1]----[F2]
GPU 3: [等待] [等待] [等待] [F0]----

使用 micro-batch 后:
GPU 0: [F0][F1][F2][F3][B0][B1][B2][B3]
GPU 1: [F0][F1][F2][F3][B0][B1][B2][B3]
GPU 2: [F0][F1][F2][F3][B0][B1][B2][B3]
GPU 3: [F0][F1][F2][F3][B0][B1][B2][B3]

F = 前向传播, B = 反向传播

张量并行

将单个层的参数切分到多个 GPU:

矩阵乘法 Y = XW,将 W 按列切分:

GPU 0: Y₁ = XW₁
GPU 1: Y₂ = XW₂

最终结果: Y = [Y₁, Y₂](拼接)

特点

  • 通信频繁,需要高带宽互联
  • 适合单机多卡场景

张量并行实现细节

对于 Transformer 的注意力层和 MLP 层:

# 注意力层的张量并行
# 将 Q、K、V 投影矩阵切分到不同 GPU
# 每个 GPU 计算部分注意力头

# MLP 层的张量并行
# 将第一个线性层按列切分
# 将第二个线性层按行切分
# 这样中间层不需要通信

混合并行

实际的大模型训练通常结合多种并行策略:

┌─────────────────────────────────────────────────────┐
│ 混合并行 │
│ │
│ 数据并行(跨节点) │
│ ┌─────────────────────────────────────────────┐ │
│ │ 节点 0 │ │
│ │ ┌───────────────────────────────────────┐ │ │
│ │ │ 流水线并行(层间) │ │ │
│ │ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │ │
│ │ │ │ 阶段 0 │ │ 阶段 1 │ │ 阶段 2 │ │ │ │
│ │ │ │ ┌─────┐ │ │ ┌─────┐ │ │ ┌─────┐ │ │ │ │
│ │ │ │ │TP 0 │ │ │ │TP 0 │ │ │ │TP 0 │ │ │ │ │
│ │ │ │ │TP 1 │ │ │ │TP 1 │ │ │ │TP 1 │ │ │ │ │
│ │ │ │ └─────┘ │ │ └─────┘ │ │ └─────┘ │ │ │ │
│ │ │ └─────────┘ └─────────┘ └─────────┘ │ │ │
│ │ └───────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────┘

TP = 张量并行(层内)

并行策略选择指南

场景模型大小GPU 数量推荐策略
单机训练< 7B8数据并行 + ZeRO
单机大模型7B-30B8张量并行 + ZeRO
多机训练30B-70B32+2D 并行 + ZeRO
超大模型> 70B128+3D 并行 + ZeRO-3

PyTorch 分布式训练

DistributedDataParallel(DDP)

PyTorch DDP 是最常用的数据并行实现:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
dist.destroy_process_group()

def train(rank, world_size):
setup(rank, world_size)

model = MyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])

optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

for epoch in range(epochs):
for batch in dataloader:
optimizer.zero_grad()
output = ddp_model(batch)
loss = criterion(output, target)
loss.backward()
optimizer.step()

cleanup()

if __name__ == "__main__":
world_size = torch.cuda.device_count()
torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)

启动方式

torchrun --nproc_per_node=8 train.py

多机训练启动

# 节点 0(主节点)
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
--master_addr="10.0.0.1" --master_port=29500 train.py

# 节点 1
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \
--master_addr="10.0.0.1" --master_port=29500 train.py

DDP 完整示例

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import os

def main():
# 初始化分布式环境
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 创建模型并移到当前 GPU
model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])

# 创建分布式数据采样器
train_dataset = MyDataset(...)
train_sampler = DistributedSampler(
train_dataset,
num_replicas=dist.get_world_size(),
rank=dist.get_rank()
)

train_loader = DataLoader(
train_dataset,
batch_size=32,
sampler=train_sampler,
num_workers=4,
pin_memory=True
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
# 每个 epoch 开始前设置 sampler 的 epoch
train_sampler.set_epoch(epoch)

for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.cuda(), target.cuda()

optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

if batch_idx % 100 == 0 and local_rank == 0:
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

dist.destroy_process_group()

if __name__ == "__main__":
main()

Fully Sharded Data Parallel(FSDP)

FSDP 是 PyTorch 提供的高级分布式训练方案,通过分片技术大幅降低显存占用:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = MyModel()
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=torch.cuda.current_device()
)

分片策略

策略说明显存节省
FULL_SHARD分片参数、梯度、优化器状态最大
SHARD_GRAD_OP分片梯度和优化器状态中等
NO_SHARD不分片(类似 DDP)

FSDP 完整配置示例

from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
CPUOffload
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import transformers

# 混合精度配置
mp_policy = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16
)

# 自动包装策略(按 Transformer 层包装)
auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={
transformers.models.llama.modeling_llama.LlamaDecoderLayer
}
)

# 创建 FSDP 模型
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mp_policy,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device()
)

DeepSpeed

DeepSpeed 是微软开源的深度学习优化库,提供了强大的显存优化和分布式训练能力。

核心特性

ZeRO 优化:零冗余优化器,通过分片消除冗余:

  • ZeRO-1:分片优化器状态
  • ZeRO-2:分片优化器状态和梯度
  • ZeRO-3:分片优化器状态、梯度和参数

显存对比

配置7B 模型显存
DDP112 GB
ZeRO-128 GB
ZeRO-216 GB
ZeRO-38 GB

使用示例

配置文件 ds_config.json

{
"train_batch_size": 128,
"gradient_accumulation_steps": 1,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [0.9, 0.999],
"eps": 1e-8
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"allgather_partitions": true,
"reduce_scatter": true
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}

训练代码:

import deepspeed

model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config="ds_config.json"
)

for batch in dataloader:
outputs = model_engine(batch)
loss = outputs.loss
model_engine.backward(loss)
model_engine.step()

启动方式

deepspeed --num_gpus=8 train.py --deepspeed_config ds_config.json

ZeRO 配置详解

ZeRO-1 配置(适合显存充足场景)

{
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
}
}

ZeRO-2 配置(推荐)

{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"reduce_scatter": true,
"contiguous_gradients": true
}
}

ZeRO-3 配置(超大模型)

{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu"
},
"offload_param": {
"device": "cpu"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 1e6,
"prefetch_bucket_size": 1e6
}
}

ZeRO-Offload vs ZeRO-Infinity

特性ZeRO-OffloadZeRO-Infinity
CPU 内存优化器状态参数 + 优化器状态
NVMe 支持
适用模型< 30B任意大小
硬件要求大内存 CPUNVMe SSD

Megatron-LM

Megatron-LM 是 NVIDIA 开发的大模型训练框架,专注于高效的张量并行和流水线并行。

核心技术

张量并行:将 Transformer 层的计算分散到多个 GPU

from megatron import get_args
from megatron.model import TransformerBlock

# Megatron 自动处理张量并行
model = TransformerBlock(...)

流水线并行:使用 1F1B 调度减少气泡

时间步:  0  1  2  3  4  5  6  7  8  9
GPU 0: F0 F1 F2 F3 F4 B0 B1 B2 B3 B4
GPU 1: F0 F1 F2 F3 F4 B0 B1 B2 B3
GPU 2: F0 F1 F2 F3 F4 B0 B1 B2
GPU 3: F0 F1 F2 F3 F4 B0 B1

F = 前向传播, B = 反向传播

通信优化

梯度累积

减少通信频率,提高计算通信比:

accumulation_steps = 4

for i, batch in enumerate(dataloader):
output = model(batch)
loss = output.loss / accumulation_steps
loss.backward()

if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

梯度压缩

减少通信数据量:

  • 量化:将 FP32 梯度压缩为 FP16 或 INT8
  • 稀疏化:只传输重要的梯度

重叠计算和通信

在反向传播时同步梯度:

# DDP 自动实现计算通信重叠
model = DDP(model, device_ids=[rank])

显存优化技术

混合精度训练

使用 FP16/BF16 进行计算,减少显存占用:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
optimizer.zero_grad()

with autocast():
output = model(batch)
loss = criterion(output, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

BF16 vs FP16

特性FP16BF16
动态范围有限与 FP32 相同
精度较高较低
数值稳定性可能溢出更稳定
硬件支持广泛A100+

推荐使用 BF16:如果硬件支持(Ampere 架构及以后),BF16 是更好的选择,因为它避免了 FP16 的数值溢出问题。

梯度检查点

以计算换内存:

from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
def forward(self, x):
# 使用检查点保存内存
x = checkpoint(self.layer1, x, use_reentrant=False)
x = checkpoint(self.layer2, x, use_reentrant=False)
return x

梯度检查点的显存-计算权衡

模型层数无检查点显存有检查点显存计算开销
1216GB8GB+33%
2432GB12GB+33%
4864GB16GB+33%

激活重计算

DeepSpeed 提供的激活重计算功能:

{
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true
}
}

Flash Attention

Flash Attention 是一种高效的注意力计算方法,通过减少 HBM 访问次数显著提升速度:

# PyTorch 2.0+ 内置支持
import torch.nn.functional as F

# 使用 scaled_dot_product_attention(自动选择最优实现)
output = F.scaled_dot_product_attention(query, key, value)

# 或使用 flash-attn 库
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)

Flash Attention 的优势

特性传统注意力Flash Attention
内存复杂度O(n²)O(n)
HBM 访问多次一次
速度基准2-4× 提升
序列长度受限可扩展

框架选择指南

面对 FSDP 和 DeepSpeed,选择合适的框架是实际开发中的常见问题。

FSDP vs DeepSpeed 对比

特性PyTorch FSDPDeepSpeed ZeRO
显存优化ZeRO-3 级别ZeRO-1/2/3 可选
CPU Offload支持支持更完善
易用性PyTorch 原生,更简单需要配置文件
调试更容易较复杂
生态系统PyTorch 官方支持微软支持,社区大
最新优化持续更新功能更丰富

性能对比

根据实际测试数据,两个框架在不同场景下的表现:

场景FSDPDeepSpeed ZeRO-3推荐
单机 8 卡 7B 模型更快(约 10%)较慢FSDP
多机训练良好良好均可
显存紧张(CPU Offload)支持更成熟DeepSpeed
快速原型开发更简单需要配置FSDP
生产环境长期训练良好功能更全DeepSpeed

选择建议

选择 FSDP 的情况

  • 使用 PyTorch 2.0+ 且希望保持技术栈一致
  • 需要快速开发和调试
  • 模型规模在 7B-30B 范围
  • 单机多卡训练场景
  • 希望代码更简洁,避免配置文件

选择 DeepSpeed 的情况

  • 需要使用 ZeRO-Infinity(NVMe Offload)
  • 训练超大模型(70B+)
  • 需要更灵活的显存管理策略
  • 已经在使用 DeepSpeed 生态系统
  • 需要与 Megatron-LM 结合使用

迁移建议

从 DeepSpeed 迁移到 FSDP 的代码改动通常不大:

# DeepSpeed 方式
import deepspeed
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config="ds_config.json"
)

# FSDP 方式
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

如果项目中已经在使用 Hugging Face Transformers,可以通过 accelerate 库简化两者的切换:

from accelerate import Accelerator
accelerator = Accelerator()
model, optimizer = accelerator.prepare(model, optimizer)

实践建议

选择并行策略

场景推荐策略
模型可放入单卡数据并行(DDP)
模型较大但单机可放FSDP 或 ZeRO-2
超大模型多机训练3D 并行 + ZeRO-3

配置模板

单机 8 卡训练 7B 模型

# 使用 DeepSpeed ZeRO-2
deepspeed --num_gpus=8 train.py \
--deepspeed_config ds_config_z2.json \
--model_name meta-llama/Llama-2-7b-hf \
--batch_size 4 \
--learning_rate 1e-5

多机训练 70B 模型

# 使用 DeepSpeed ZeRO-3 + 流水线并行
deepspeed --num_nodes=4 --num_gpus=8 \
--hostfile hostfile train.py \
--deepspeed_config ds_config_z3.json \
--model_name meta-llama/Llama-2-70b-hf \
--pipeline_parallel_size 4

监控和调试

  • 监控 GPU 利用率:使用 nvidia-smi 或 Prometheus
  • 分析性能瓶颈:使用 PyTorch Profiler
  • 检查梯度同步:确保所有 GPU 梯度一致

性能分析示例

import torch.profiler as profiler

with profiler.profile(
activities=[
profiler.ProfilerActivity.CPU,
profiler.ProfilerActivity.CUDA,
],
on_trace_ready=profiler.tensorboard_trace_handler('./logs'),
record_shapes=True,
profile_memory=True
) as p:
model(input)

# 查看 TensorBoard
# tensorboard --logdir=./logs

常见问题

死锁:确保所有进程执行相同的通信操作

# 错误:可能导致死锁
if rank == 0:
dist.send(tensor, dst=1)

# 正确:所有进程都参与
if rank == 0:
dist.send(tensor, dst=1)
else:
dist.recv(tensor, src=0)

梯度爆炸/消失:检查学习率和梯度裁剪

# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

显存不足:尝试 FSDP、ZeRO 或梯度检查点

# 检查显存使用
print(f"显存已分配: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"显存已预留: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

性能调优清单

检查项说明预期效果
批次大小尽可能大以充分利用 GPU吞吐量提升
梯度累积小显存时增大有效批次显存节省
混合精度使用 FP16/BF16显存减半
梯度检查点以计算换内存显存节省 30-50%
DataLoader多 worker + pin_memoryI/O 加速
通信重叠计算和通信并行隐藏通信延迟

小结

分布式训练是大模型时代的必备技能。本章介绍了:

  1. 并行策略:数据并行、模型并行、混合并行的原理和选择
  2. PyTorch 工具:DDP 和 FSDP 的使用方法
  3. DeepSpeed:ZeRO 优化的配置和使用
  4. 显存优化:混合精度、梯度检查点、Flash Attention
  5. 实践建议:配置模板、监控调试、常见问题

选择合适的并行策略、优化通信效率、合理管理显存,是构建高效训练系统的关键。在实际应用中,建议优先使用成熟的框架(如 DeepSpeed、Megatron-LM),它们已经解决了大部分工程难题。

参考资料

官方文档

技术论文