存储系统
存储系统是 AI 基础设施的重要组成部分。训练数据集、模型检查点、日志等都需要高效的存储支持。
AI 存储的挑战
AI 工作负载对存储系统提出了特殊要求:
| 挑战 | 说明 |
|---|---|
| 数据规模 | 训练数据集可达 TB 到 PB 级别 |
| 读取模式 | 大量小文件随机读取、大文件顺序读取 |
| 并发访问 | 数百个 GPU 同时读取数据 |
| 检查点写入 | 周期性写入大文件,要求高吞吐 |
| 成本控制 | 大量数据存储成本高昂 |
存储层次架构
一个典型的 AI 存储架构包含多个层次:
┌─────────────────────────────────────────────────────────┐
│ 应用层 │
│ PyTorch DataLoader │ TensorFlow tf.data │
├─────────────────────────────────────────────────────────┤
│ 缓存层 │
│ Alluxio │ Redis │ 本地 SSD │
├─────────────────────────────────────────────────────────┤
│ 文件系统层 │
│ NFS │ Lustre │ GPFS │ CephFS │
├─────────────────────────────────────────────────────────┤
│ 对象存储层 │
│ S3 │ MinIO │ Ceph RGW │
├─────────────────────────────────────────────────────────┤
│ 块存储层 │
│ SAN │ 本地 NVMe │ 云盘 │
└─────────────────────────────────────────────────────────┘
并行文件系统
Lustre
Lustre 是高性能计算领域广泛使用的并行文件系统。
架构组件:
- MDS(Metadata Server):管理文件元数据
- OSS(Object Storage Server):存储实际数据
- Client:客户端挂载点
特点:
- 高吞吐量:可达数百 GB/s
- POSIX 兼容:支持标准文件操作
- 可扩展:支持数千个客户端
配置示例:
# 客户端挂载
mount -t lustre mds.example.com:/fs /mnt/lustre
# 性能调优
lctl set_param lru_cache=1000000
GPFS(IBM Storage Scale)
IBM 开发的高性能并行文件系统:
特点:
- 企业级可靠性
- 优秀的元数据性能
- 支持文件快照和复制
BeeGFS
开源的并行文件系统,部署简单:
# 安装服务端
apt-get install beegfs-mgmtd beegfs-meta beegfs-storage
# 配置存储目标
/opt/beegfs/sbin/beegfs-setup-storage -p /data/beegfs -s 1
# 客户端挂载
/opt/beegfs/sbin/beegfs-setup-client -m mgmt.example.com
mount -t beegfs nodev /mnt/beegfs
对象存储
对象存储适合存储海量非结构化数据,具有无限扩展能力。
MinIO
MinIO 是高性能的开源对象存储,兼容 S3 API。
部署示例:
apiVersion: apps/v1
kind: Deployment
metadata:
name: minio
spec:
selector:
matchLabels:
app: minio
template:
metadata:
labels:
app: minio
spec:
containers:
- name: minio
image: minio/minio:latest
args:
- server
- /data
- --console-address
- ":9001"
ports:
- containerPort: 9000
- containerPort: 9001
volumeMounts:
- name: data
mountPath: /data
volumes:
- name: data
persistentVolumeClaim:
claimName: minio-pvc
Python 客户端:
from minio import Minio
client = Minio(
"minio.example.com:9000",
access_key="minioadmin",
secret_key="minioadmin",
secure=False
)
# 上传文件
client.fput_object("datasets", "train.parquet", "/local/train.parquet")
# 下载文件
client.fget_object("datasets", "train.parquet", "/local/download.parquet")
S3 兼容存储
大多数云厂商提供 S3 兼容的对象存储:
- AWS S3
- 阿里云 OSS
- 腾讯云 COS
- 华为云 OBS
数据缓存层
Alluxio
Alluxio 是分布式缓存系统,加速数据访问:
┌─────────────────────────────────────────────────────┐
│ 计算框架 │
│ Spark │ PyTorch │ TensorFlow │ Presto │
├─────────────────────────────────────────────────────┤
│ Alluxio │
│ 内存缓存 │ SSD 缓存 │
├─────────────────────────────────────────────────────┤
│ 底层存储 │
│ S3 │ HDFS │ GCS │ OSS │ NAS │
└─────────────────────────────────────────────────────┘
配置示例:
# alluxio-site.properties
alluxio.master.hostname=master
alluxio.underfs.address=s3://my-bucket
alluxio.worker.memory.size=64GB
alluxio.worker.tieredstore.levels=2
alluxio.worker.tieredstore.level0.alias=MEM
alluxio.worker.tieredstore.level0.path=/mnt/ramdisk
alluxio.worker.tieredstore.level1.alias=SSD
alluxio.worker.tieredstore.level1.path=/mnt/ssd
PyTorch 集成:
import torch
from torch.utils.data import DataLoader
# 通过 Alluxio 访问数据
dataset = torch.load("alluxio://master:19998/models/model.pt")
数据加载优化
数据加载是训练流程中容易被忽视但影响巨大的环节。一个配置不当的 DataLoader 可能导致 GPU 利用率低下,即使模型本身已经高度优化。
PyTorch DataLoader 配置详解
from torch.utils.data import DataLoader, Dataset
class FastDataset(Dataset):
def __init__(self, data_path):
self.data = self._load_data(data_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=8, # 多进程加载
pin_memory=True, # 锁页内存,加速 GPU 传输
prefetch_factor=2, # 预取批次
persistent_workers=True # 保持 worker 进程
)
关键参数详解:
| 参数 | 说明 | 推荐值 |
|---|---|---|
num_workers | 数据加载进程数 | CPU 核心数 / 2 到 CPU 核心数 |
pin_memory | 使用锁页内存 | True(GPU 训练时) |
prefetch_factor | 每个 worker 预取的批次 | 2-4 |
persistent_workers | 保持 worker 进程 | True(避免启动开销) |
drop_last | 丢弃不完整批次 | True(分布式训练时) |
num_workers 调优指南:
num_workers 的最佳值取决于多个因素:CPU 核心数、I/O 瓶颈程度、数据预处理复杂度。一个常用的经验公式:
num_workers = min(CPU核心数, 4 × GPU数)
如果数据预处理简单(如仅需解码图像),可以设置较低的 num_workers(如 4)。如果预处理复杂(如需要数据增强),则需要更高的 num_workers。
诊断数据加载瓶颈:
import time
# 测试数据加载速度
start = time.time()
for i, batch in enumerate(dataloader):
if i >= 100:
break
elapsed = time.time() - start
print(f"100 批次加载时间: {elapsed:.2f}s")
print(f"每批次平均时间: {elapsed/100*1000:.2f}ms")
# 对比 GPU 计算时间
# 如果加载时间 >> 计算时间,说明存在 I/O 瓶颈
数据格式选择
不同数据格式对加载性能有显著影响:
| 格式 | 读取速度 | 压缩率 | 随机访问 | 适用场景 |
|---|---|---|---|---|
| TFRecord | 快 | 中等 | 差 | TensorFlow 生态 |
| Parquet | 中等 | 高 | 好 | 结构化数据、表格数据 |
| HDF5 | 快 | 低 | 好 | 科学计算、大型数组 |
| WebDataset | 快 | 中等 | 差 | 大规模数据、流式处理 |
| NPZ/NPY | 最快 | 无 | 好 | NumPy 数组、小数据集 |
| LMDB | 快 | 无 | 极好 | 需要快速随机访问 |
选择建议:
- 图像数据:WebDataset(大规模)或 TFRecord(TensorFlow)
- 文本数据:Parquet(结构化)或 WebDataset(大规模)
- 表格数据:Parquet
- 科学数据:HDF5 或 NPY
- 需要随机访问:LMDB 或 Parquet
WebDataset 示例
WebDataset 使用 tar 文件格式,适合流式读取:
import webdataset as wds
dataset = wds.WebDataset(
"s3://bucket/dataset-{000000..000999}.tar"
).shuffle(1000).decode("pil").to_tuple("jpg;png", "json")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
检查点存储
训练检查点需要高吞吐量存储支持。一个完善的检查点系统不仅能防止训练中断导致的数据丢失,还能支持断点续训、实验复现和模型版本管理。
检查点策略
基本保存和加载:
import torch
import os
def save_checkpoint(model, optimizer, epoch, loss, path):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
# 先写入临时文件,再原子重命名
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)
def load_checkpoint(model, optimizer, path):
if os.path.exists(path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
return 0
完整的检查点内容:
为了支持完整的断点续训,检查点应包含:
checkpoint = {
# 必需内容
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
# 推荐:支持精确恢复
'scheduler_state_dict': scheduler.state_dict(), # 学习率调度器
'random_state': torch.get_rng_state(), # PyTorch 随机状态
'cuda_random_state': torch.cuda.get_rng_state(), # CUDA 随机状态
'numpy_random_state': np.random.get_state(), # NumPy 随机状态
# 可选:训练监控
'loss': loss,
'best_loss': best_loss,
'config': config, # 训练配置
}
保留策略:
合理的检查点保留策略可以在存储空间和恢复能力之间取得平衡:
def save_with_rotation(checkpoint, save_dir, epoch, keep_last=3, keep_best=True):
"""保存检查点并清理旧文件"""
# 保存当前检查点
path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)
# 更新 latest 链接
latest_path = os.path.join(save_dir, 'checkpoint_latest.pt')
if os.path.exists(latest_path):
os.remove(latest_path)
os.symlink(path, latest_path)
# 清理旧检查点
checkpoints = sorted(glob.glob(os.path.join(save_dir, 'checkpoint_epoch_*.pt')))
if len(checkpoints) > keep_last:
for old_ckpt in checkpoints[:-keep_last]:
# 如果是最优检查点,不删除
if 'best' not in old_ckpt:
os.remove(old_ckpt)
异步保存
使用后台线程保存检查点,避免阻塞训练:
import threading
import queue
class AsyncCheckpointSaver:
def __init__(self):
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()
def _worker(self):
while True:
model, path = self.queue.get()
torch.save(model.state_dict(), path)
def save(self, model, path):
self.queue.put((model, path))
saver = AsyncCheckpointSaver()
saver.save(model, "checkpoint.pt")
更完善的异步保存器:
class CheckpointManager:
def __init__(self, save_dir, max_to_keep=5):
self.save_dir = save_dir
self.max_to_keep = max_to_keep
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()
def _worker(self):
while True:
checkpoint, path = self.queue.get()
try:
# 先写临时文件
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)
self._cleanup()
except Exception as e:
print(f"保存检查点失败: {e}")
def _cleanup(self):
"""保留最近的 N 个检查点"""
checkpoints = sorted(glob.glob(os.path.join(self.save_dir, 'checkpoint_*.pt')))
while len(checkpoints) > self.max_to_keep:
os.remove(checkpoints.pop(0))
def save(self, model, optimizer, epoch, **kwargs):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
**kwargs
}
path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pt')
self.queue.put((checkpoint, path))
def wait(self):
"""等待所有保存任务完成"""
self.queue.join()
分布式训练中的检查点
在分布式训练中,只需要主进程保存检查点:
import torch.distributed as dist
def save_checkpoint_ddp(model, optimizer, epoch, save_dir):
# 只在主进程保存
if dist.get_rank() == 0:
# 对于 DDP 模型,需要使用 module 属性
state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
checkpoint = {
'epoch': epoch,
'model_state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
}
path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)
# 同步所有进程
dist.barrier()
FSDP 检查点保存:
FSDP 的检查点保存有特殊要求,因为模型参数是分片的:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
# 设置状态字典类型
FSDP.set_state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(rank0_only=True)
)
# 保存完整状态字典(只在 rank 0 保存)
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, "checkpoint.pt")
存储性能优化
I/O 优化策略
- 数据预取:提前加载数据到内存或 SSD
- 数据分片:按节点分片,减少跨节点读取
- 压缩传输:减少网络传输量
- 缓存热数据:将频繁访问的数据缓存
性能监控
# 监控 I/O 性能
iostat -x 1
# 监控 NFS 性能
nfsstat -c
# 监控网络吞吐
iftop
常见问题
| 问题 | 症状 | 解决方案 |
|---|---|---|
| I/O 瓶颈 | GPU 利用率低,等待数据 | 增加缓存、优化数据格式 |
| 存储满 | 写入失败 | 清理数据、扩容 |
| 网络拥塞 | 读取延迟高 | 增加带宽、使用本地缓存 |
| 元数据瓶颈 | 小文件操作慢 | 合并小文件、使用对象存储 |
数据生命周期管理
数据分层
热数据(频繁访问)
↓ 30 天后
温数据(偶尔访问)
↓ 90 天后
冷数据(很少访问)
自动化策略
import boto3
s3 = boto3.client('s3')
def move_to_cold_storage(bucket, prefix, days=90):
"""将旧数据移动到冷存储"""
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
for obj in response.get('Contents', []):
age = (datetime.now() - obj['LastModified']).days
if age > days:
s3.copy_object(
Bucket='cold-storage-bucket',
Key=obj['Key'],
CopySource={'Bucket': bucket, 'Key': obj['Key']}
)
s3.delete_object(Bucket=bucket, Key=obj['Key'])
小结
存储系统是 AI 基础设施的关键组件。选择合适的存储架构、优化数据访问模式、实施有效的数据管理策略,可以显著提升训练效率并降低成本。在实际应用中,通常需要组合使用多种存储技术,构建分层的存储架构。