存储系统
存储系统是 AI 基础设施的重要组成部分。训练数据集、模型检查点、日志等都需要高效的存储支持。
AI 存储的挑战
AI 工作负载对存储系统提出了特殊要求:
| 挑战 | 说明 |
|---|---|
| 数据规模 | 训练数据集可达 TB 到 PB 级别 |
| 读取模式 | 大量小文件随机读取、大文件顺序读取 |
| 并发访问 | 数百个 GPU 同时读取数据 |
| 检查点写入 | 周期性写入大文件,要求高吞吐 |
| 成本控制 | 大量数据存储成本高昂 |
存储层次架构
一个典型的 AI 存储架构包含多个层次:
┌─────────────────────────────────────────────────────────┐
│ 应用层 │
│ PyTorch DataLoader │ TensorFlow tf.data │
├─────────────────────────────────────────────────────────┤
│ 缓存层 │
│ Alluxio │ Redis │ 本地 SSD │
├─────────────────────────────────────────────────────────┤
│ 文件系统层 │
│ NFS │ Lustre │ GPFS │ CephFS │
├─────────────────────────────────────────────────────────┤
│ 对象存储层 │
│ S3 │ MinIO │ Ceph RGW │
├─────────────────────────────────────────────────────────┤
│ 块存储层 │
│ SAN │ 本地 NVMe │ 云盘 │
└─────────────────────────────────────────────────────────┘
并行文件系统
Lustre
Lustre 是高性能计算领域广泛使用的并行文件系统。
架构组件:
- MDS(Metadata Server):管理文件元数据
- OSS(Object Storage Server):存储实际数据
- Client:客户端挂载点
特点:
- 高吞吐量:可达数百 GB/s
- POSIX 兼容:支持标准文件操作
- 可扩展:支持数千个客户端
配置示例:
# 客户端挂载
mount -t lustre mds.example.com:/fs /mnt/lustre
# 性能调优
lctl set_param lru_cache=1000000
GPFS(IBM Storage Scale)
IBM 开发的高性能并行文件系统:
特点:
- 企业级可靠性
- 优秀的元数据性能
- 支持文件快照和复制
BeeGFS
开源的并行文件系统,部署简单:
# 安装服务端
apt-get install beegfs-mgmtd beegfs-meta beegfs-storage
# 配置存储目标
/opt/beegfs/sbin/beegfs-setup-storage -p /data/beegfs -s 1
# 客户端挂载
/opt/beegfs/sbin/beegfs-setup-client -m mgmt.example.com
mount -t beegfs nodev /mnt/beegfs
对象存储
对象存储适合存储海量非结构化数据,具有无限扩展能力。
MinIO
MinIO 是高性能的开源对象存储,兼容 S3 API。
部署示例:
apiVersion: apps/v1
kind: Deployment
metadata:
name: minio
spec:
selector:
matchLabels:
app: minio
template:
metadata:
labels:
app: minio
spec:
containers:
- name: minio
image: minio/minio:latest
args:
- server
- /data
- --console-address
- ":9001"
ports:
- containerPort: 9000
- containerPort: 9001
volumeMounts:
- name: data
mountPath: /data
volumes:
- name: data
persistentVolumeClaim:
claimName: minio-pvc
Python 客户端:
from minio import Minio
client = Minio(
"minio.example.com:9000",
access_key="minioadmin",
secret_key="minioadmin",
secure=False
)
# 上传文件
client.fput_object("datasets", "train.parquet", "/local/train.parquet")
# 下载文件
client.fget_object("datasets", "train.parquet", "/local/download.parquet")
S3 兼容存储
大多数云厂商提供 S3 兼容的对象存储:
- AWS S3
- 阿里云 OSS
- 腾讯云 COS
- 华为云 OBS
数据缓存层
Alluxio
Alluxio 是分布式缓存系统,加速数据访问:
┌─────────────────────────────────────────────────────┐
│ 计算框架 │
│ Spark │ PyTorch │ TensorFlow │ Presto │
├─────────────────────────────────────────────────────┤
│ Alluxio │
│ 内存缓存 │ SSD 缓存 │
├─────────────────────────────────────────────────────┤
│ 底层存储 │
│ S3 │ HDFS │ GCS │ OSS │ NAS │
└─────────────────────────────────────────────────────┘
配置示例:
# alluxio-site.properties
alluxio.master.hostname=master
alluxio.underfs.address=s3://my-bucket
alluxio.worker.memory.size=64GB
alluxio.worker.tieredstore.levels=2
alluxio.worker.tieredstore.level0.alias=MEM
alluxio.worker.tieredstore.level0.path=/mnt/ramdisk
alluxio.worker.tieredstore.level1.alias=SSD
alluxio.worker.tieredstore.level1.path=/mnt/ssd
PyTorch 集成:
import torch
from torch.utils.data import DataLoader
# 通过 Alluxio 访问数据
dataset = torch.load("alluxio://master:19998/models/model.pt")
数据加载优化
PyTorch DataLoader
from torch.utils.data import DataLoader, Dataset
class FastDataset(Dataset):
def __init__(self, data_path):
self.data = self._load_data(data_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=8, # 多进程加载
pin_memory=True, # 锁页内存,加速 GPU 传输
prefetch_factor=2, # 预取批次
persistent_workers=True # 保持 worker 进程
)
数据格式选择
| 格式 | 读取速度 | 压缩率 | 适用场景 |
|---|---|---|---|
| TFRecord | 快 | 中等 | TensorFlow |
| Parquet | 中等 | 高 | 结构化数据 |
| HDF5 | 快 | 低 | 科学计算 |
| WebDataset | 快 | 中等 | 大规模数据 |
WebDataset 示例
WebDataset 使用 tar 文件格式,适合流式读取:
import webdataset as wds
dataset = wds.WebDataset(
"s3://bucket/dataset-{000000..000999}.tar"
).shuffle(1000).decode("pil").to_tuple("jpg;png", "json")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
检查点存储
训练检查点需要高吞吐量存储支持。
检查点策略
import torch
import os
def save_checkpoint(model, optimizer, epoch, loss, path):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
# 先写入临时文件,再原子重命名
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)
def load_checkpoint(model, optimizer, path):
if os.path.exists(path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
return 0
异步保存
使用后台线程保存检查点,避免阻塞训练:
import threading
import queue
class AsyncCheckpointSaver:
def __init__(self):
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()
def _worker(self):
while True:
model, path = self.queue.get()
torch.save(model.state_dict(), path)
def save(self, model, path):
self.queue.put((model, path))
saver = AsyncCheckpointSaver()
saver.save(model, "checkpoint.pt")
存储性能优化
I/O 优化策略
- 数据预取:提前加载数据到内存或 SSD
- 数据分片:按节点分片,减少跨节点读取
- 压缩传输:减少网络传输量
- 缓存热数据:将频繁访问的数据缓存
性能监控
# 监控 I/O 性能
iostat -x 1
# 监控 NFS 性能
nfsstat -c
# 监控网络吞吐
iftop
常见问题
| 问题 | 症状 | 解决方案 |
|---|---|---|
| I/O 瓶颈 | GPU 利用率低,等待数据 | 增加缓存、优化数据格式 |
| 存储满 | 写入失败 | 清理数据、扩容 |
| 网络拥塞 | 读取延迟高 | 增加带宽、使用本地缓存 |
| 元数据瓶颈 | 小文件操作慢 | 合并小文件、使用对象存储 |
数据生命周期管理
数据分层
热数据(频繁访问)
↓ 30 天后
温数据(偶尔访问)
↓ 90 天后
冷数据(很少访问)
自动化策略
import boto3
s3 = boto3.client('s3')
def move_to_cold_storage(bucket, prefix, days=90):
"""将旧数据移动到冷存储"""
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
for obj in response.get('Contents', []):
age = (datetime.now() - obj['LastModified']).days
if age > days:
s3.copy_object(
Bucket='cold-storage-bucket',
Key=obj['Key'],
CopySource={'Bucket': bucket, 'Key': obj['Key']}
)
s3.delete_object(Bucket=bucket, Key=obj['Key'])
小结
存储系统是 AI 基础设施的关键组件。选择合适的存储架构、优化数据访问模式、实施有效的数据管理策略,可以显著提升训练效率并降低成本。在实际应用中,通常需要组合使用多种存储技术,构建分层的存储架构。