跳到主要内容

存储系统

存储系统是 AI 基础设施的重要组成部分。训练数据集、模型检查点、日志等都需要高效的存储支持。

AI 存储的挑战

AI 工作负载对存储系统提出了特殊要求:

挑战说明
数据规模训练数据集可达 TB 到 PB 级别
读取模式大量小文件随机读取、大文件顺序读取
并发访问数百个 GPU 同时读取数据
检查点写入周期性写入大文件,要求高吞吐
成本控制大量数据存储成本高昂

存储层次架构

一个典型的 AI 存储架构包含多个层次:

┌─────────────────────────────────────────────────────────┐
│ 应用层 │
│ PyTorch DataLoader │ TensorFlow tf.data │
├─────────────────────────────────────────────────────────┤
│ 缓存层 │
│ Alluxio │ Redis │ 本地 SSD │
├─────────────────────────────────────────────────────────┤
│ 文件系统层 │
│ NFS │ Lustre │ GPFS │ CephFS │
├─────────────────────────────────────────────────────────┤
│ 对象存储层 │
│ S3 │ MinIO │ Ceph RGW │
├─────────────────────────────────────────────────────────┤
│ 块存储层 │
│ SAN │ 本地 NVMe │ 云盘 │
└─────────────────────────────────────────────────────────┘

并行文件系统

Lustre

Lustre 是高性能计算领域广泛使用的并行文件系统。

架构组件

  • MDS(Metadata Server):管理文件元数据
  • OSS(Object Storage Server):存储实际数据
  • Client:客户端挂载点

特点

  • 高吞吐量:可达数百 GB/s
  • POSIX 兼容:支持标准文件操作
  • 可扩展:支持数千个客户端

配置示例

# 客户端挂载
mount -t lustre mds.example.com:/fs /mnt/lustre

# 性能调优
lctl set_param lru_cache=1000000

GPFS(IBM Storage Scale)

IBM 开发的高性能并行文件系统:

特点

  • 企业级可靠性
  • 优秀的元数据性能
  • 支持文件快照和复制

BeeGFS

开源的并行文件系统,部署简单:

# 安装服务端
apt-get install beegfs-mgmtd beegfs-meta beegfs-storage

# 配置存储目标
/opt/beegfs/sbin/beegfs-setup-storage -p /data/beegfs -s 1

# 客户端挂载
/opt/beegfs/sbin/beegfs-setup-client -m mgmt.example.com
mount -t beegfs nodev /mnt/beegfs

对象存储

对象存储适合存储海量非结构化数据,具有无限扩展能力。

MinIO

MinIO 是高性能的开源对象存储,兼容 S3 API。

部署示例

apiVersion: apps/v1
kind: Deployment
metadata:
name: minio
spec:
selector:
matchLabels:
app: minio
template:
metadata:
labels:
app: minio
spec:
containers:
- name: minio
image: minio/minio:latest
args:
- server
- /data
- --console-address
- ":9001"
ports:
- containerPort: 9000
- containerPort: 9001
volumeMounts:
- name: data
mountPath: /data
volumes:
- name: data
persistentVolumeClaim:
claimName: minio-pvc

Python 客户端

from minio import Minio

client = Minio(
"minio.example.com:9000",
access_key="minioadmin",
secret_key="minioadmin",
secure=False
)

# 上传文件
client.fput_object("datasets", "train.parquet", "/local/train.parquet")

# 下载文件
client.fget_object("datasets", "train.parquet", "/local/download.parquet")

S3 兼容存储

大多数云厂商提供 S3 兼容的对象存储:

  • AWS S3
  • 阿里云 OSS
  • 腾讯云 COS
  • 华为云 OBS

数据缓存层

Alluxio

Alluxio 是分布式缓存系统,加速数据访问:

┌─────────────────────────────────────────────────────┐
│ 计算框架 │
│ Spark │ PyTorch │ TensorFlow │ Presto │
├─────────────────────────────────────────────────────┤
│ Alluxio │
│ 内存缓存 │ SSD 缓存 │
├─────────────────────────────────────────────────────┤
│ 底层存储 │
│ S3 │ HDFS │ GCS │ OSS │ NAS │
└─────────────────────────────────────────────────────┘

配置示例

# alluxio-site.properties
alluxio.master.hostname=master
alluxio.underfs.address=s3://my-bucket
alluxio.worker.memory.size=64GB
alluxio.worker.tieredstore.levels=2
alluxio.worker.tieredstore.level0.alias=MEM
alluxio.worker.tieredstore.level0.path=/mnt/ramdisk
alluxio.worker.tieredstore.level1.alias=SSD
alluxio.worker.tieredstore.level1.path=/mnt/ssd

PyTorch 集成

import torch
from torch.utils.data import DataLoader

# 通过 Alluxio 访问数据
dataset = torch.load("alluxio://master:19998/models/model.pt")

数据加载优化

数据加载是训练流程中容易被忽视但影响巨大的环节。一个配置不当的 DataLoader 可能导致 GPU 利用率低下,即使模型本身已经高度优化。

PyTorch DataLoader 配置详解

from torch.utils.data import DataLoader, Dataset

class FastDataset(Dataset):
def __init__(self, data_path):
self.data = self._load_data(data_path)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx]

dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=8, # 多进程加载
pin_memory=True, # 锁页内存,加速 GPU 传输
prefetch_factor=2, # 预取批次
persistent_workers=True # 保持 worker 进程
)

关键参数详解

参数说明推荐值
num_workers数据加载进程数CPU 核心数 / 2 到 CPU 核心数
pin_memory使用锁页内存True(GPU 训练时)
prefetch_factor每个 worker 预取的批次2-4
persistent_workers保持 worker 进程True(避免启动开销)
drop_last丢弃不完整批次True(分布式训练时)

num_workers 调优指南

num_workers 的最佳值取决于多个因素:CPU 核心数、I/O 瓶颈程度、数据预处理复杂度。一个常用的经验公式:

num_workers = min(CPU核心数, 4 × GPU数)

如果数据预处理简单(如仅需解码图像),可以设置较低的 num_workers(如 4)。如果预处理复杂(如需要数据增强),则需要更高的 num_workers

诊断数据加载瓶颈

import time

# 测试数据加载速度
start = time.time()
for i, batch in enumerate(dataloader):
if i >= 100:
break
elapsed = time.time() - start
print(f"100 批次加载时间: {elapsed:.2f}s")
print(f"每批次平均时间: {elapsed/100*1000:.2f}ms")

# 对比 GPU 计算时间
# 如果加载时间 >> 计算时间,说明存在 I/O 瓶颈

数据格式选择

不同数据格式对加载性能有显著影响:

格式读取速度压缩率随机访问适用场景
TFRecord中等TensorFlow 生态
Parquet中等结构化数据、表格数据
HDF5科学计算、大型数组
WebDataset中等大规模数据、流式处理
NPZ/NPY最快NumPy 数组、小数据集
LMDB极好需要快速随机访问

选择建议

  • 图像数据:WebDataset(大规模)或 TFRecord(TensorFlow)
  • 文本数据:Parquet(结构化)或 WebDataset(大规模)
  • 表格数据:Parquet
  • 科学数据:HDF5 或 NPY
  • 需要随机访问:LMDB 或 Parquet

WebDataset 示例

WebDataset 使用 tar 文件格式,适合流式读取:

import webdataset as wds

dataset = wds.WebDataset(
"s3://bucket/dataset-{000000..000999}.tar"
).shuffle(1000).decode("pil").to_tuple("jpg;png", "json")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

检查点存储

训练检查点需要高吞吐量存储支持。一个完善的检查点系统不仅能防止训练中断导致的数据丢失,还能支持断点续训、实验复现和模型版本管理。

检查点策略

基本保存和加载

import torch
import os

def save_checkpoint(model, optimizer, epoch, loss, path):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}

# 先写入临时文件,再原子重命名
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)

def load_checkpoint(model, optimizer, path):
if os.path.exists(path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
return 0

完整的检查点内容

为了支持完整的断点续训,检查点应包含:

checkpoint = {
# 必需内容
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),

# 推荐:支持精确恢复
'scheduler_state_dict': scheduler.state_dict(), # 学习率调度器
'random_state': torch.get_rng_state(), # PyTorch 随机状态
'cuda_random_state': torch.cuda.get_rng_state(), # CUDA 随机状态
'numpy_random_state': np.random.get_state(), # NumPy 随机状态

# 可选:训练监控
'loss': loss,
'best_loss': best_loss,
'config': config, # 训练配置
}

保留策略

合理的检查点保留策略可以在存储空间和恢复能力之间取得平衡:

def save_with_rotation(checkpoint, save_dir, epoch, keep_last=3, keep_best=True):
"""保存检查点并清理旧文件"""
# 保存当前检查点
path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)

# 更新 latest 链接
latest_path = os.path.join(save_dir, 'checkpoint_latest.pt')
if os.path.exists(latest_path):
os.remove(latest_path)
os.symlink(path, latest_path)

# 清理旧检查点
checkpoints = sorted(glob.glob(os.path.join(save_dir, 'checkpoint_epoch_*.pt')))
if len(checkpoints) > keep_last:
for old_ckpt in checkpoints[:-keep_last]:
# 如果是最优检查点,不删除
if 'best' not in old_ckpt:
os.remove(old_ckpt)

异步保存

使用后台线程保存检查点,避免阻塞训练:

import threading
import queue

class AsyncCheckpointSaver:
def __init__(self):
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()

def _worker(self):
while True:
model, path = self.queue.get()
torch.save(model.state_dict(), path)

def save(self, model, path):
self.queue.put((model, path))

saver = AsyncCheckpointSaver()
saver.save(model, "checkpoint.pt")

更完善的异步保存器

class CheckpointManager:
def __init__(self, save_dir, max_to_keep=5):
self.save_dir = save_dir
self.max_to_keep = max_to_keep
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()

def _worker(self):
while True:
checkpoint, path = self.queue.get()
try:
# 先写临时文件
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)
self._cleanup()
except Exception as e:
print(f"保存检查点失败: {e}")

def _cleanup(self):
"""保留最近的 N 个检查点"""
checkpoints = sorted(glob.glob(os.path.join(self.save_dir, 'checkpoint_*.pt')))
while len(checkpoints) > self.max_to_keep:
os.remove(checkpoints.pop(0))

def save(self, model, optimizer, epoch, **kwargs):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
**kwargs
}
path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pt')
self.queue.put((checkpoint, path))

def wait(self):
"""等待所有保存任务完成"""
self.queue.join()

分布式训练中的检查点

在分布式训练中,只需要主进程保存检查点:

import torch.distributed as dist

def save_checkpoint_ddp(model, optimizer, epoch, save_dir):
# 只在主进程保存
if dist.get_rank() == 0:
# 对于 DDP 模型,需要使用 module 属性
state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()

checkpoint = {
'epoch': epoch,
'model_state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
}

path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt')
torch.save(checkpoint, path)

# 同步所有进程
dist.barrier()

FSDP 检查点保存

FSDP 的检查点保存有特殊要求,因为模型参数是分片的:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# 设置状态字典类型
FSDP.set_state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(rank0_only=True)
)

# 保存完整状态字典(只在 rank 0 保存)
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, "checkpoint.pt")

存储性能优化

I/O 优化策略

  1. 数据预取:提前加载数据到内存或 SSD
  2. 数据分片:按节点分片,减少跨节点读取
  3. 压缩传输:减少网络传输量
  4. 缓存热数据:将频繁访问的数据缓存

性能监控

# 监控 I/O 性能
iostat -x 1

# 监控 NFS 性能
nfsstat -c

# 监控网络吞吐
iftop

常见问题

问题症状解决方案
I/O 瓶颈GPU 利用率低,等待数据增加缓存、优化数据格式
存储满写入失败清理数据、扩容
网络拥塞读取延迟高增加带宽、使用本地缓存
元数据瓶颈小文件操作慢合并小文件、使用对象存储

数据生命周期管理

数据分层

热数据(频繁访问)
↓ 30 天后
温数据(偶尔访问)
↓ 90 天后
冷数据(很少访问)

自动化策略

import boto3

s3 = boto3.client('s3')

def move_to_cold_storage(bucket, prefix, days=90):
"""将旧数据移动到冷存储"""
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
age = (datetime.now() - obj['LastModified']).days
if age > days:
s3.copy_object(
Bucket='cold-storage-bucket',
Key=obj['Key'],
CopySource={'Bucket': bucket, 'Key': obj['Key']}
)
s3.delete_object(Bucket=bucket, Key=obj['Key'])

小结

存储系统是 AI 基础设施的关键组件。选择合适的存储架构、优化数据访问模式、实施有效的数据管理策略,可以显著提升训练效率并降低成本。在实际应用中,通常需要组合使用多种存储技术,构建分层的存储架构。

参考资料