跳到主要内容

故障排查与运维实践

AI 训练系统涉及复杂的硬件、软件和网络组件,故障排查是一项必备技能。本章系统介绍常见问题的诊断方法和运维实践,帮助快速定位和解决问题。

故障分类

AI 训练系统的故障可以分为以下几类:

故障类型常见表现影响范围
硬件故障GPU 损坏、内存错误、电源故障单节点
软件故障CUDA 错误、OOM、框架 bug单任务到全局
网络故障NCCL 超时、连接断开、丢包分布式任务
存储故障I/O 错误、数据损坏、空间不足所有任务
配置错误参数错误、路径错误、权限问题单任务

GPU 相关问题

CUDA 错误

CUDA 错误是最常见的 GPU 问题,表现形式多样。

常见 CUDA 错误及解决方案

错误信息可能原因解决方案
CUDA out of memory显存不足减小批次、启用量化、使用 FSDP
CUDA error: device-side assert triggered索引越界、参数错误检查数据标签范围、调试代码
CUDA error: illegal memory access内存访问越界检查 kernel 代码、使用 cuda-memcheck
CUDA error: invalid device ordinalGPU ID 不存在检查 CUDA_VISIBLE_DEVICES
cudaErrorPeerAccessUnsupportedP2P 访问不支持检查 GPU 拓扑、禁用 P2P

诊断步骤

# 1. 检查 GPU 状态
nvidia-smi

# 2. 检查 CUDA 版本兼容性
nvcc --version
cat /usr/local/cuda/version.txt

# 3. 使用 cuda-memcheck 检测内存错误
cuda-memcheck ./my_program

# 4. 检查 GPU 是否被其他进程占用
nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv

# 5. 重置 GPU(谨慎使用)
sudo nvidia-smi --gpu-reset -i 0

代码中的错误处理

import torch

def safe_cuda_operation(func, *args, **kwargs):
"""安全的 CUDA 操作包装器"""
try:
return func(*args, **kwargs)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
# 清空缓存重试
torch.cuda.empty_cache()
try:
return func(*args, **kwargs)
except RuntimeError:
# 仍然失败,记录详细错误
print(f"CUDA OOM: {torch.cuda.memory_summary()}")
raise
else:
raise

# 使用示例
output = safe_cuda_operation(model, input_tensor)

显存不足(OOM)

OOM 是训练中最常遇到的问题。

诊断显存使用

import torch

def print_memory_stats():
"""打印详细的显存统计"""
for i in range(torch.cuda.device_count()):
allocated = torch.cuda.memory_allocated(i) / 1e9
reserved = torch.cuda.memory_reserved(i) / 1e9
max_allocated = torch.cuda.max_memory_allocated(i) / 1e9

# 获取 GPU 总显存
total = torch.cuda.get_device_properties(i).total_memory / 1e9

print(f"GPU {i}:")
print(f" 已分配: {allocated:.2f} GB ({allocated/total*100:.1f}%)")
print(f" 已预留: {reserved:.2f} GB ({reserved/total*100:.1f}%)")
print(f" 峰值: {max_allocated:.2f} GB ({max_allocated/total*100:.1f}%)")
print(f" 总量: {total:.2f} GB")

# 训练前后调用
print("训练前:")
print_memory_stats()

# ... 训练代码 ...

print("\n训练后:")
print_memory_stats()

OOM 解决方案优先级

  1. 减小批次大小:最直接的方法
  2. 启用梯度检查点:以计算换内存
  3. 使用混合精度训练:FP16/BF16
  4. 使用 FSDP/ZeRO:分片存储
  5. 模型量化:INT8/INT4
  6. CPU Offload:卸载部分数据到 CPU
# OOM 应急处理脚本
def handle_oom(model, optimizer, batch, original_batch_size):
"""OOM 时的应急处理"""
torch.cuda.empty_cache()

# 尝试减半批次大小
half_batch_size = original_batch_size // 2
if half_batch_size < 1:
raise RuntimeError("批次大小已最小,无法继续")

print(f"OOM detected, reducing batch size from {original_batch_size} to {half_batch_size}")

# 重新分割数据
half_batch = batch[:half_batch_size]
return model(half_batch)

GPU 利用率低

GPU 利用率低通常意味着存在其他瓶颈。

诊断清单

import torch
import time

def diagnose_bottleneck(model, dataloader, device):
"""诊断训练瓶颈"""
model.eval()

# 测试纯计算时间
dummy_input = torch.randn(32, 3, 224, 224).to(device)

torch.cuda.synchronize()
start = time.time()
for _ in range(100):
_ = model(dummy_input)
torch.cuda.synchronize()
compute_time = (time.time() - start) / 100

print(f"纯计算时间: {compute_time*1000:.2f} ms/batch")

# 测试数据加载时间
dataloader_iter = iter(dataloader)
start = time.time()
for _ in range(10):
batch = next(dataloader_iter)
load_time = (time.time() - start) / 10
print(f"数据加载时间: {load_time*1000:.2f} ms/batch")

# 判断瓶颈
if load_time > compute_time * 2:
print("⚠️ 瓶颈:数据加载")
print(" 建议:增加 num_workers、启用 pin_memory、使用更快的存储")
elif compute_time > load_time * 2:
print("✓ 正常:计算为主")
else:
print("⚠️ 可能存在通信瓶颈")

常见原因及解决方案

原因症状解决方案
数据加载瓶颈GPU 等待数据增加 num_workers、使用更快的存储
小批次计算粒度太小增大批次大小
CPU 预处理瓶颈CPU 占用高优化预处理、使用 GPU 预处理
同步操作过多频繁 GPU-CPU 同步减少同步点、使用异步操作
通信瓶颈多卡时利用率低优化网络、使用梯度累积

分布式训练问题

NCCL 错误

NCCL 是分布式训练的核心通信库,其错误通常与网络相关。

常见 NCCL 错误

错误信息可能原因解决方案
NCCL error: unhandled cuda errorCUDA 操作失败先解决 CUDA 错误
NCCL error: network call failed网络不可达检查网络连通性
NCCL error: remote detached进程异常退出检查其他进程日志
NCCL WARN CANNOT RECV FROM PEERP2P 通信失败检查 NVLink/IB 连接

NCCL 调试环境变量

# 启用详细日志
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# 仅打印错误
export NCCL_DEBUG=WARN

# 指定网络接口
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_HCA=mlx5_0

# 禁用 InfiniBand(使用以太网)
export NCCL_IB_DISABLE=1

# 增加超时时间
export NCCL_TIMEOUT=1800

# 禁用 P2P(故障排查时)
export NCCL_P2P_DISABLE=1

网络连通性检查

# 检查 TCP 连接
ping <other-node-ip>
nc -zv <other-node-ip> 29500

# 检查 InfiniBand
ibstat
ibping <other-node-guid>

# 检查 NCCL 版本
python -c "import torch; print(torch.cuda.nccl.version())"

死锁问题

分布式训练中的死锁通常由同步操作不当引起。

常见死锁场景

# ❌ 错误:条件同步会导致死锁
if rank == 0:
dist.send(tensor, dst=1)
# rank 0 等待 rank 1 接收
# 但 rank 1 没有执行 recv,死锁!

# ✅ 正确:所有进程都要参与
if rank == 0:
dist.send(tensor, dst=1)
else:
dist.recv(tensor, src=0)

# ❌ 错误:不一致的 barrier
if rank == 0:
# 一些操作
dist.barrier() # 只有 rank 0 调用
# 死锁!

# ✅ 正确:所有进程都调用 barrier
# 所有 rank 都执行
dist.barrier()

死锁检测技巧

# 1. 检查进程状态
ps aux | grep python

# 2. 查看进程卡在哪里
strace -p <pid> -e trace=network

# 3. 检查 NCCL 日志
# 设置 NCCL_DEBUG=INFO 查看通信状态

# 4. 使用超时机制
export NCCL_BLOCKING_WAIT=0 # 非阻塞模式
export NCCL_TIMEOUT=60 # 设置超时

梯度/激活异常

训练过程中的数值问题可能导致梯度爆炸或 NaN。

检测和处理 NaN

def check_nan(tensor, name="tensor"):
"""检查张量是否包含 NaN 或 Inf"""
if torch.isnan(tensor).any():
print(f"⚠️ {name} contains NaN!")
return True
if torch.isinf(tensor).any():
print(f"⚠️ {name} contains Inf!")
return True
return False

def safe_backward(loss, model):
"""安全的反向传播"""
# 检查损失
if check_nan(loss, "loss"):
return False

loss.backward()

# 检查梯度
for name, param in model.named_parameters():
if param.grad is not None:
if check_nan(param.grad, f"gradient of {name}"):
return False

return True

# 训练循环
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)

if not safe_backward(loss, model):
print("Gradient error detected, skipping batch")
optimizer.zero_grad()
continue

# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()

数值稳定性最佳实践

问题解决方案
梯度爆炸梯度裁剪、降低学习率
梯度消失使用 BF16、调整模型结构
Loss NaN检查学习率、使用损失缩放、检查数据
数值溢出使用 BF16、调整精度

存储相关问题

I/O 瓶颈

I/O 瓶颈会导致 GPU 利用率低下。

诊断 I/O 性能

# 监控磁盘 I/O
iostat -x 1

# 监控 NFS 统计
nfsstat -c

# 测试存储吞吐
dd if=/dev/zero of=/data/test bs=1M count=1000 conv=fdatasync

# 使用 fio 详细测试
fio --name=randread --ioengine=libaio --iodepth=16 \
--rw=randread --bs=4k --direct=1 --size=1G \
--numjobs=4 --runtime=60 --group_reporting

优化 DataLoader 性能

from torch.utils.data import DataLoader

# 优化配置
dataloader = DataLoader(
dataset,
batch_size=32,
num_workers=8, # 根据 CPU 核心数调整
pin_memory=True, # GPU 训练时启用
prefetch_factor=4, # 增加预取
persistent_workers=True, # 保持 worker 进程
drop_last=True, # 分布式训练时建议启用
)

# 监控数据加载时间
import time

for i, batch in enumerate(dataloader):
if i == 0:
continue # 跳过第一次(包含初始化时间)

start = time.time()
# ... GPU 计算 ...
compute_time = time.time() - start

if i % 100 == 0:
print(f"Batch {i}: compute={compute_time*1000:.1f}ms")

数据损坏

训练数据损坏会导致各种奇怪的错误。

检测和处理数据损坏

import hashlib
import json

def verify_file_integrity(filepath, expected_md5=None):
"""验证文件完整性"""
with open(filepath, 'rb') as f:
file_hash = hashlib.md5(f.read()).hexdigest()

if expected_md5 and file_hash != expected_md5:
raise ValueError(f"File corrupted: {filepath}")

return file_hash

def validate_dataset(dataset_path):
"""验证数据集完整性"""
import os
from pathlib import Path

errors = []
for file_path in Path(dataset_path).rglob('*.json'):
try:
with open(file_path, 'r') as f:
data = json.load(f)
# 检查必要字段
if 'text' not in data:
errors.append(f"{file_path}: missing 'text' field")
except json.JSONDecodeError as e:
errors.append(f"{file_path}: JSON decode error - {e}")
except Exception as e:
errors.append(f"{file_path}: {e}")

if errors:
print(f"Found {len(errors)} errors:")
for error in errors[:10]:
print(f" - {error}")

return len(errors) == 0

运维实践

检查点管理

良好的检查点管理可以在故障时快速恢复。

检查点最佳实践

import os
import torch
import glob
from datetime import datetime

class CheckpointManager:
"""检查点管理器"""

def __init__(self, save_dir, max_to_keep=5, save_best=True):
self.save_dir = save_dir
self.max_to_keep = max_to_keep
self.save_best = save_best
self.best_loss = float('inf')
os.makedirs(save_dir, exist_ok=True)

def save(self, model, optimizer, scheduler, epoch, loss, metrics=None):
"""保存检查点"""
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'loss': loss,
'best_loss': self.best_loss,
'timestamp': datetime.now().isoformat(),
}

# 添加额外指标
if metrics:
checkpoint['metrics'] = metrics

# 保存当前检查点
path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)
print(f"Checkpoint saved: {path}")

# 更新 latest 链接
self._update_latest(path)

# 保存最优检查点
if self.save_best and loss < self.best_loss:
self.best_loss = loss
best_path = os.path.join(self.save_dir, 'checkpoint_best.pt')
torch.save(checkpoint, best_path)
print(f"Best checkpoint updated: loss={loss:.4f}")

# 清理旧检查点
self._cleanup()

def _update_latest(self, path):
"""更新 latest 软链接"""
latest_path = os.path.join(self.save_dir, 'checkpoint_latest.pt')
if os.path.exists(latest_path):
os.remove(latest_path)
os.symlink(os.path.basename(path), latest_path)

def _cleanup(self):
"""清理旧检查点"""
checkpoints = sorted(glob.glob(
os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')
))

while len(checkpoints) > self.max_to_keep:
old_ckpt = checkpoints.pop(0)
os.remove(old_ckpt)
print(f"Removed old checkpoint: {old_ckpt}")

def load(self, model, optimizer=None, scheduler=None, checkpoint_path=None):
"""加载检查点"""
if checkpoint_path is None:
checkpoint_path = os.path.join(self.save_dir, 'checkpoint_latest.pt')

if not os.path.exists(checkpoint_path):
print(f"No checkpoint found at {checkpoint_path}")
return 0

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])

if optimizer and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

self.best_loss = checkpoint.get('best_loss', float('inf'))

print(f"Checkpoint loaded: epoch={checkpoint['epoch']}, loss={checkpoint['loss']:.4f}")
return checkpoint['epoch']

# 使用示例
manager = CheckpointManager('checkpoints', max_to_keep=3)

# 恢复训练
start_epoch = manager.load(model, optimizer, scheduler)

for epoch in range(start_epoch, num_epochs):
loss = train_epoch(model, optimizer, dataloader)
manager.save(model, optimizer, scheduler, epoch, loss)

自动恢复机制

对于生产环境,需要自动恢复机制。

import signal
import sys

class GracefulKiller:
"""优雅退出处理器"""
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)

def exit_gracefully(self, signum, frame):
print(f"\nReceived signal {signum}, saving checkpoint...")
self.kill_now = True

# 训练循环
killer = GracefulKiller()
checkpoint_manager = CheckpointManager('checkpoints')

for epoch in range(num_epochs):
if killer.kill_now:
# 保存检查点后退出
checkpoint_manager.save(model, optimizer, scheduler, epoch, current_loss)
sys.exit(0)

# 正常训练
loss = train_epoch(model, optimizer, dataloader)

日志收集

集中收集日志便于分析。

import logging
import socket
from datetime import datetime

def setup_training_logger(log_dir, rank=0):
"""设置训练日志"""
os.makedirs(log_dir, exist_ok=True)

# 创建 logger
logger = logging.getLogger('training')
logger.setLevel(logging.INFO)

# 文件处理器
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
hostname = socket.gethostname()
log_file = os.path.join(log_dir, f'train_{hostname}_rank{rank}_{timestamp}.log')

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(
'%(asctime)s - %(levelname)s - %(message)s'
))

# 控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

logger.addHandler(file_handler)
logger.addHandler(console_handler)

return logger

# 使用
logger = setup_training_logger('logs', rank=local_rank)
logger.info(f"Training started on {socket.gethostname()}")
logger.info(f"Model: {model.__class__.__name__}")
logger.info(f"GPU: {torch.cuda.get_device_name()}")

健康检查

定期健康检查可以在问题发生前发现隐患。

def health_check():
"""系统健康检查"""
issues = []

# 1. 检查 GPU 状态
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
temp = torch.cuda.temperature(i)

if temp > 85:
issues.append(f"GPU {i} temperature too high: {temp}°C")

# 检查 ECC 错误
# (需要 nvidia-smi 命令)

# 2. 检查磁盘空间
import shutil
total, used, free = shutil.disk_usage('/')
free_percent = free / total * 100
if free_percent < 10:
issues.append(f"Disk space low: {free_percent:.1f}% free")

# 3. 检查内存
import psutil
mem = psutil.virtual_memory()
if mem.percent > 90:
issues.append(f"Memory usage high: {mem.percent}%")

# 4. 检查网络
# (分布式训练时)

if issues:
print("⚠️ Health check failed:")
for issue in issues:
print(f" - {issue}")
return False

print("✓ Health check passed")
return True

# 定期检查
import threading

def periodic_health_check(interval=300):
"""定期健康检查"""
def check():
while True:
health_check()
time.sleep(interval)

thread = threading.Thread(target=check, daemon=True)
thread.start()

故障排查流程

当问题发生时,建议按以下流程排查:

1. 收集信息
├─ 查看错误日志
├─ 查看 GPU 状态
├─ 查看网络状态
└─ 查看存储状态

2. 定位问题
├─ 确定影响范围(单卡/单机/多机)
├─ 确定问题类型(硬件/软件/网络)
└─ 复现问题

3. 解决问题
├─ 查阅文档和已知问题
├─ 尝试临时解决方案
├─ 验证修复效果
└─ 记录解决方案

4. 预防措施
├─ 添加监控和告警
├─ 更新文档
└─ 改进测试

小结

故障排查是 AI 训练运维的核心技能。本章介绍了:

  1. GPU 问题:CUDA 错误、OOM、利用率低的诊断和解决
  2. 分布式训练问题:NCCL 错误、死锁、数值异常的处理
  3. 存储问题:I/O 瓶颈和数据损坏的诊断
  4. 运维实践:检查点管理、自动恢复、日志收集、健康检查
  5. 排查流程:系统化的故障排查方法

建立完善的监控体系、养成良好的运维习惯、掌握系统化的排查方法,可以显著提升训练系统的稳定性和运维效率。

参考资料

官方文档

工具文档