跳到主要内容

数据集和数据加载器

本章将介绍 PyTorch 的数据加载和处理工具,包括 Dataset、DataLoader 和数据变换(Transforms)。

Dataset 基础

PyTorch 提供了 torch.utils.data.Dataset 抽象类来表示数据集。

内置数据集

PyTorch 内置了常用的数据集,可以直接使用:

import torchvision
from torchvision import datasets

# MNIST 手写数字数据集
mnist_train = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 训练集
download=True, # 自动下载
transform=None # 数据变换
)

mnist_test = datasets.MNIST(
root='./data',
train=False,
download=True
)

# 访问数据
image, label = mnist_train[0]
print(f"图像形状: {image.size()}") # PIL.Image.Size (28, 28)
print(f"标签: {label}") # 5

# CIFAR-10 数据集
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True
)

# Fashion-MNIST
fashion_mnist = datasets.FashionMNIST(
root='./data',
train=True,
download=True
)

常用内置数据集

数据集说明类别数图像大小
MNIST手写数字1028×28
Fashion-MNIST服装分类1028×28
CIFAR-10物体分类1032×32
CIFAR-100物体分类10032×32
ImageNet大规模图像1000可变
COCO目标检测80可变

自定义 Dataset

当使用自己的数据时,需要创建自定义的 Dataset 类。

基本结构

from torch.utils.data import Dataset

class MyDataset(Dataset):
def __init__(self, data, labels, transform=None):
"""
初始化数据集
Args:
data: 数据(如文件路径列表或numpy数组)
labels: 标签
transform: 可选的数据变换
"""
self.data = data
self.labels = labels
self.transform = transform

def __len__(self):
"""返回数据集大小"""
return len(self.data)

def __getitem__(self, idx):
"""
获取单个样本
Args:
idx: 索引
Returns:
(image, label): 图像和标签元组
"""
# 获取数据
image = self.data[idx]
label = self.labels[idx]

# 应用变换
if self.transform:
image = self.transform(image)

return image, label

完整示例:从文件夹加载图像

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageFolderDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

# 收集所有图像路径和标签
self.samples = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(class_dir, img_name)
self.samples.append((img_path, self.class_to_idx[class_name]))

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB')

if self.transform:
image = self.transform(image)

return image, label

DataLoader

DataLoader 提供了批量加载、多进程加载、数据打乱等功能。

基本使用

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(
dataset=train_dataset, # 数据集
batch_size=32, # 批量大小
shuffle=True, # 是否打乱
num_workers=4, # 工作进程数(并行加载)
pin_memory=True, # 固定内存(加速GPU传输)
drop_last=True # 是否丢弃最后一个不完整的batch
)

# 迭代数据
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"Batch {batch_idx}:")
print(f" 图像形状: {images.shape}") # [batch_size, channels, height, width]
print(f" 标签形状: {labels.shape}") # [batch_size]
print(f" 标签值: {labels[:5]}") # 前5个标签

# 训练步骤...
break # 只演示第一个batch

参数详解

参数说明建议
batch_size每个batch的样本数根据GPU内存调整
shuffle是否打乱数据训练时True,验证时False
num_workers数据加载进程数CPU核心数-1,Windows建议0
pin_memory固定内存加速传输GPU训练时设为True
drop_last丢弃不完整batchTrue保证batch大小一致
collate_fn自定义批处理函数处理不同长度数据

多 GPU 加载

# 使用 DistributedSampler 实现多 GPU 数据加载
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=True
)

train_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=train_sampler,
num_workers=num_workers,
pin_memory=True
)

数据变换 (Transforms)

数据变换是预处理图像数据的关键步骤。

常用变换

from torchvision import transforms

# 组合多个变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转
transforms.ColorJitter(brightness=0.2), # 颜色抖动
transforms.ToTensor(), # 转为张量(归一化到[0,1])
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406], # ImageNet均值
std=[0.229, 0.224, 0.225] # ImageNet标准差
)
])

val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

变换类型分类

几何变换

# 调整大小
transforms.Resize((224, 224)) # 固定大小
transforms.Resize(256) # 短边256

# 裁剪
transforms.CenterCrop(224) # 中心裁剪
transforms.RandomCrop(224) # 随机裁剪
transforms.RandomResizedCrop(224) # 随机裁剪并缩放

# 翻转
transforms.RandomHorizontalFlip(p=0.5) # 随机水平翻转
transforms.RandomVerticalFlip(p=0.5) # 随机垂直翻转

# 旋转
transforms.RandomRotation(degrees) # 随机旋转
transforms.RandomAffine(degrees=15) # 仿射变换

颜色变换

# 亮度、对比度、饱和度、色调
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

# 灰度化
transforms.Grayscale(num_output_channels=3)
transforms.RandomGrayscale(p=0.1)

# 转换为张量
transforms.ToTensor()

# 标准化
transforms.Normalize(mean, std)

# 随机擦除
transforms.RandomErasing(p=0.5)

完整数据加载示例

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=train_transform
)

val_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=val_transform
)

# 创建 DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True
)

val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True
)

# 训练循环中使用
def train_epoch(model, loader, criterion, optimizer, device):
model.train()
total_loss = 0
correct = 0
total = 0

for images, labels in loader:
images, labels = images.to(device), labels.to(device)

outputs = model(images)
loss = criterion(outputs, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

return total_loss / len(loader), 100. * correct / total

高效数据加载技巧

预加载数据到内存

class PreloadedDataset(Dataset):
def __init__(self, data, transform=None):
self.data = data
self.transform = transform

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
image = self.data[idx]

if self.transform:
image = self.transform(image)

return image

使用缓存加速

from torch.utils.data import DataLoader

# 使用 prefetch_factor 预加载更多数据
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4,
prefetch_factor=2, # 每个worker预加载2个batch
persistent_workers=True # 保持worker进程
)

数据不平衡处理

# 方式1:WeightedRandomSampler
from torch.utils.data import WeightedRandomSampler

# 计算类别权重
class_counts = [5000, 1000] # 类别0:5000, 类别1:1000
weights = [1.0 / c for c in class_counts]
samples_weights = [weights[label] for label in train_dataset.labels]

sampler = WeightedRandomSampler(
weights=samples_weights,
num_samples=len(samples_weights),
replacement=True
)

loader = DataLoader(dataset, sampler=sampler)

# 方式2:交叉熵损失加权
class_weights = torch.tensor([1.0, 5.0]).to(device) # 类别1权重更高
criterion = nn.CrossEntropyLoss(weight=class_weights)

常见问题

问题 1:DataLoader 迭代慢

解决方案

  • 增加 num_workers(但不要超过 CPU 核心数)
  • 使用 pin_memory=True 加速 GPU 传输
  • 减少图像大小
  • 使用 SSD 存储数据

问题 2:GPU 显存不足

解决方案

  • 减小 batch_size
  • 使用 gradient_checkpointing
  • 及时释放不需要的数据

问题 3:Windows 多进程问题

解决方案

# Windows 上将 num_workers 设为 0
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=0 # Windows
)

下一步

现在我们已经学会了如何加载和预处理数据。接下来让我们学习如何训练模型,包括完整的训练循环、验证和模型保存。