数据集和数据加载器
本章将介绍 PyTorch 的数据加载和处理工具,包括 Dataset、DataLoader 和数据变换(Transforms)。
Dataset 基础
PyTorch 提供了 torch.utils.data.Dataset 抽象类来表示数据集。
内置数据集
PyTorch 内置了常用的数据集,可以直接使用:
import torchvision
from torchvision import datasets
# MNIST 手写数字数据集
mnist_train = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 训练集
download=True, # 自动下载
transform=None # 数据变换
)
mnist_test = datasets.MNIST(
root='./data',
train=False,
download=True
)
# 访问数据
image, label = mnist_train[0]
print(f"图像形状: {image.size()}") # PIL.Image.Size (28, 28)
print(f"标签: {label}") # 5
# CIFAR-10 数据集
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True
)
# Fashion-MNIST
fashion_mnist = datasets.FashionMNIST(
root='./data',
train=True,
download=True
)
常用内置数据集
| 数据集 | 说明 | 类别数 | 图像大小 |
|---|---|---|---|
| MNIST | 手写数字 | 10 | 28×28 |
| Fashion-MNIST | 服装分类 | 10 | 28×28 |
| CIFAR-10 | 物体分类 | 10 | 32×32 |
| CIFAR-100 | 物体分类 | 100 | 32×32 |
| ImageNet | 大规模图像 | 1000 | 可变 |
| COCO | 目标检测 | 80 | 可变 |
自定义 Dataset
当使用自己的数据时,需要创建自定义的 Dataset 类。
基本结构
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels, transform=None):
"""
初始化数据集
Args:
data: 数据(如文件路径列表或numpy数组)
labels: 标签
transform: 可选的数据变换
"""
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
"""返回数据集大小"""
return len(self.data)
def __getitem__(self, idx):
"""
获取单个样本
Args:
idx: 索引
Returns:
(image, label): 图像和标签元组
"""
# 获取数据
image = self.data[idx]
label = self.labels[idx]
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
完整示例:从文件夹加载图像
import os
from PIL import Image
from torch.utils.data import Dataset
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
# 收集所有图像路径和标签
self.samples = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(class_dir, img_name)
self.samples.append((img_path, self.class_to_idx[class_name]))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
DataLoader
DataLoader 提供了批量加载、多进程加载、数据打乱等功能。
基本使用
from torch.utils.data import DataLoader
# 创建 DataLoader
train_loader = DataLoader(
dataset=train_dataset, # 数据集
batch_size=32, # 批量大小
shuffle=True, # 是否打乱
num_workers=4, # 工作进程数(并行加载)
pin_memory=True, # 固定内存(加速GPU传输)
drop_last=True # 是否丢弃最后一个不完整的batch
)
# 迭代数据
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"Batch {batch_idx}:")
print(f" 图像形状: {images.shape}") # [batch_size, channels, height, width]
print(f" 标签形状: {labels.shape}") # [batch_size]
print(f" 标签值: {labels[:5]}") # 前5个标签
# 训练步骤...
break # 只演示第一个batch
参数详解
| 参数 | 说明 | 建议 |
|---|---|---|
batch_size | 每个batch的样本数 | 根据GPU内存调整 |
shuffle | 是否打乱数据 | 训练时True,验证时False |
num_workers | 数据加载进程数 | CPU核心数-1,Windows建议0 |
pin_memory | 固定内存加速传输 | GPU训练时设为True |
drop_last | 丢弃不完整batch | True保证batch大小一致 |
collate_fn | 自定义批处理函数 | 处理不同长度数据 |
多 GPU 加载
# 使用 DistributedSampler 实现多 GPU 数据加载
from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)
train_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=train_sampler,
num_workers=num_workers,
pin_memory=True
)
数据变换 (Transforms)
数据变换是预处理图像数据的关键步骤。
常用变换
from torchvision import transforms
# 组合多个变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转
transforms.ColorJitter(brightness=0.2), # 颜色抖动
transforms.ToTensor(), # 转为张量(归一化到[0,1])
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406], # ImageNet均值
std=[0.229, 0.224, 0.225] # ImageNet标准差
)
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
变换类型分类
几何变换
# 调整大小
transforms.Resize((224, 224)) # 固定大小
transforms.Resize(256) # 短边256
# 裁剪
transforms.CenterCrop(224) # 中心裁剪
transforms.RandomCrop(224) # 随机裁剪
transforms.RandomResizedCrop(224) # 随机裁剪并缩放
# 翻转
transforms.RandomHorizontalFlip(p=0.5) # 随机水平翻转
transforms.RandomVerticalFlip(p=0.5) # 随机垂直翻转
# 旋转
transforms.RandomRotation(degrees) # 随机旋转
transforms.RandomAffine(degrees=15) # 仿射变换
颜色变换
# 亮度、对比度、饱和度、色调
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
# 灰度化
transforms.Grayscale(num_output_channels=3)
transforms.RandomGrayscale(p=0.1)
# 转换为张量
transforms.ToTensor()
# 标准化
transforms.Normalize(mean, std)
# 随机擦除
transforms.RandomErasing(p=0.5)
完整数据加载示例
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)
val_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=val_transform
)
# 创建 DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True
)
# 训练循环中使用
def train_epoch(model, loader, criterion, optimizer, device):
model.train()
total_loss = 0
correct = 0
total = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return total_loss / len(loader), 100. * correct / total
高效数据加载技巧
预加载数据到内存
class PreloadedDataset(Dataset):
def __init__(self, data, transform=None):
self.data = data
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image = self.data[idx]
if self.transform:
image = self.transform(image)
return image
使用缓存加速
from torch.utils.data import DataLoader
# 使用 prefetch_factor 预加载更多数据
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4,
prefetch_factor=2, # 每个worker预加载2个batch
persistent_workers=True # 保持worker进程
)
数据不平衡处理
# 方式1:WeightedRandomSampler
from torch.utils.data import WeightedRandomSampler
# 计算类别权重
class_counts = [5000, 1000] # 类别0:5000, 类别1:1000
weights = [1.0 / c for c in class_counts]
samples_weights = [weights[label] for label in train_dataset.labels]
sampler = WeightedRandomSampler(
weights=samples_weights,
num_samples=len(samples_weights),
replacement=True
)
loader = DataLoader(dataset, sampler=sampler)
# 方式2:交叉熵损失加权
class_weights = torch.tensor([1.0, 5.0]).to(device) # 类别1权重更高
criterion = nn.CrossEntropyLoss(weight=class_weights)
常见问题
问题 1:DataLoader 迭代慢
解决方案:
- 增加
num_workers(但不要超过 CPU 核心数) - 使用
pin_memory=True加速 GPU 传输 - 减少图像大小
- 使用 SSD 存储数据
问题 2:GPU 显存不足
解决方案:
- 减小
batch_size - 使用
gradient_checkpointing - 及时释放不需要的数据
问题 3:Windows 多进程问题
解决方案:
# Windows 上将 num_workers 设为 0
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=0 # Windows
)
下一步
现在我们已经学会了如何加载和预处理数据。接下来让我们学习如何训练模型,包括完整的训练循环、验证和模型保存。