跳到主要内容

存储系统

存储系统是 AI 基础设施的重要组成部分。训练数据集、模型检查点、日志等都需要高效的存储支持。

AI 存储的挑战

AI 工作负载对存储系统提出了特殊要求:

挑战说明
数据规模训练数据集可达 TB 到 PB 级别
读取模式大量小文件随机读取、大文件顺序读取
并发访问数百个 GPU 同时读取数据
检查点写入周期性写入大文件,要求高吞吐
成本控制大量数据存储成本高昂

存储层次架构

一个典型的 AI 存储架构包含多个层次:

┌─────────────────────────────────────────────────────────┐
│ 应用层 │
│ PyTorch DataLoader │ TensorFlow tf.data │
├─────────────────────────────────────────────────────────┤
│ 缓存层 │
│ Alluxio │ Redis │ 本地 SSD │
├─────────────────────────────────────────────────────────┤
│ 文件系统层 │
│ NFS │ Lustre │ GPFS │ CephFS │
├─────────────────────────────────────────────────────────┤
│ 对象存储层 │
│ S3 │ MinIO │ Ceph RGW │
├─────────────────────────────────────────────────────────┤
│ 块存储层 │
│ SAN │ 本地 NVMe │ 云盘 │
└─────────────────────────────────────────────────────────┘

并行文件系统

Lustre

Lustre 是高性能计算领域广泛使用的并行文件系统。

架构组件

  • MDS(Metadata Server):管理文件元数据
  • OSS(Object Storage Server):存储实际数据
  • Client:客户端挂载点

特点

  • 高吞吐量:可达数百 GB/s
  • POSIX 兼容:支持标准文件操作
  • 可扩展:支持数千个客户端

配置示例

# 客户端挂载
mount -t lustre mds.example.com:/fs /mnt/lustre

# 性能调优
lctl set_param lru_cache=1000000

GPFS(IBM Storage Scale)

IBM 开发的高性能并行文件系统:

特点

  • 企业级可靠性
  • 优秀的元数据性能
  • 支持文件快照和复制

BeeGFS

开源的并行文件系统,部署简单:

# 安装服务端
apt-get install beegfs-mgmtd beegfs-meta beegfs-storage

# 配置存储目标
/opt/beegfs/sbin/beegfs-setup-storage -p /data/beegfs -s 1

# 客户端挂载
/opt/beegfs/sbin/beegfs-setup-client -m mgmt.example.com
mount -t beegfs nodev /mnt/beegfs

对象存储

对象存储适合存储海量非结构化数据,具有无限扩展能力。

MinIO

MinIO 是高性能的开源对象存储,兼容 S3 API。

部署示例

apiVersion: apps/v1
kind: Deployment
metadata:
name: minio
spec:
selector:
matchLabels:
app: minio
template:
metadata:
labels:
app: minio
spec:
containers:
- name: minio
image: minio/minio:latest
args:
- server
- /data
- --console-address
- ":9001"
ports:
- containerPort: 9000
- containerPort: 9001
volumeMounts:
- name: data
mountPath: /data
volumes:
- name: data
persistentVolumeClaim:
claimName: minio-pvc

Python 客户端

from minio import Minio

client = Minio(
"minio.example.com:9000",
access_key="minioadmin",
secret_key="minioadmin",
secure=False
)

# 上传文件
client.fput_object("datasets", "train.parquet", "/local/train.parquet")

# 下载文件
client.fget_object("datasets", "train.parquet", "/local/download.parquet")

S3 兼容存储

大多数云厂商提供 S3 兼容的对象存储:

  • AWS S3
  • 阿里云 OSS
  • 腾讯云 COS
  • 华为云 OBS

数据缓存层

Alluxio

Alluxio 是分布式缓存系统,加速数据访问:

┌─────────────────────────────────────────────────────┐
│ 计算框架 │
│ Spark │ PyTorch │ TensorFlow │ Presto │
├─────────────────────────────────────────────────────┤
│ Alluxio │
│ 内存缓存 │ SSD 缓存 │
├─────────────────────────────────────────────────────┤
│ 底层存储 │
│ S3 │ HDFS │ GCS │ OSS │ NAS │
└─────────────────────────────────────────────────────┘

配置示例

# alluxio-site.properties
alluxio.master.hostname=master
alluxio.underfs.address=s3://my-bucket
alluxio.worker.memory.size=64GB
alluxio.worker.tieredstore.levels=2
alluxio.worker.tieredstore.level0.alias=MEM
alluxio.worker.tieredstore.level0.path=/mnt/ramdisk
alluxio.worker.tieredstore.level1.alias=SSD
alluxio.worker.tieredstore.level1.path=/mnt/ssd

PyTorch 集成

import torch
from torch.utils.data import DataLoader

# 通过 Alluxio 访问数据
dataset = torch.load("alluxio://master:19998/models/model.pt")

数据加载优化

PyTorch DataLoader

from torch.utils.data import DataLoader, Dataset

class FastDataset(Dataset):
def __init__(self, data_path):
self.data = self._load_data(data_path)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return self.data[idx]

dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=8, # 多进程加载
pin_memory=True, # 锁页内存,加速 GPU 传输
prefetch_factor=2, # 预取批次
persistent_workers=True # 保持 worker 进程
)

数据格式选择

格式读取速度压缩率适用场景
TFRecord中等TensorFlow
Parquet中等结构化数据
HDF5科学计算
WebDataset中等大规模数据

WebDataset 示例

WebDataset 使用 tar 文件格式,适合流式读取:

import webdataset as wds

dataset = wds.WebDataset(
"s3://bucket/dataset-{000000..000999}.tar"
).shuffle(1000).decode("pil").to_tuple("jpg;png", "json")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

检查点存储

训练检查点需要高吞吐量存储支持。

检查点策略

import torch
import os

def save_checkpoint(model, optimizer, epoch, loss, path):
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}

# 先写入临时文件,再原子重命名
temp_path = path + '.tmp'
torch.save(checkpoint, temp_path)
os.rename(temp_path, path)

def load_checkpoint(model, optimizer, path):
if os.path.exists(path):
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return checkpoint['epoch']
return 0

异步保存

使用后台线程保存检查点,避免阻塞训练:

import threading
import queue

class AsyncCheckpointSaver:
def __init__(self):
self.queue = queue.Queue()
self.thread = threading.Thread(target=self._worker, daemon=True)
self.thread.start()

def _worker(self):
while True:
model, path = self.queue.get()
torch.save(model.state_dict(), path)

def save(self, model, path):
self.queue.put((model, path))

saver = AsyncCheckpointSaver()
saver.save(model, "checkpoint.pt")

存储性能优化

I/O 优化策略

  1. 数据预取:提前加载数据到内存或 SSD
  2. 数据分片:按节点分片,减少跨节点读取
  3. 压缩传输:减少网络传输量
  4. 缓存热数据:将频繁访问的数据缓存

性能监控

# 监控 I/O 性能
iostat -x 1

# 监控 NFS 性能
nfsstat -c

# 监控网络吞吐
iftop

常见问题

问题症状解决方案
I/O 瓶颈GPU 利用率低,等待数据增加缓存、优化数据格式
存储满写入失败清理数据、扩容
网络拥塞读取延迟高增加带宽、使用本地缓存
元数据瓶颈小文件操作慢合并小文件、使用对象存储

数据生命周期管理

数据分层

热数据(频繁访问)
↓ 30 天后
温数据(偶尔访问)
↓ 90 天后
冷数据(很少访问)

自动化策略

import boto3

s3 = boto3.client('s3')

def move_to_cold_storage(bucket, prefix, days=90):
"""将旧数据移动到冷存储"""
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)

for obj in response.get('Contents', []):
age = (datetime.now() - obj['LastModified']).days
if age > days:
s3.copy_object(
Bucket='cold-storage-bucket',
Key=obj['Key'],
CopySource={'Bucket': bucket, 'Key': obj['Key']}
)
s3.delete_object(Bucket=bucket, Key=obj['Key'])

小结

存储系统是 AI 基础设施的关键组件。选择合适的存储架构、优化数据访问模式、实施有效的数据管理策略,可以显著提升训练效率并降低成本。在实际应用中,通常需要组合使用多种存储技术,构建分层的存储架构。

参考资料