跳到主要内容

数据集和数据加载器

数据是深度学习的基础。PyTorch 提供了 torch.utils.data 模块,包含 DatasetDataLoader 两个核心类,帮助我们高效地加载和处理数据。理解这两个类的工作原理对于构建高效的训练流程至关重要。

为什么需要 Dataset 和 DataLoader?

在机器学习中,处理数据样本的代码往往会变得复杂且难以维护。我们希望数据加载代码与模型训练代码分离,以提高代码的可读性和模块化程度。

Dataset 负责存储数据样本和对应的标签,它是一个抽象类,需要我们继承并实现特定方法。

DataLoader 在 Dataset 外部包装一个可迭代对象,提供批量加载、数据打乱、多进程加载等功能。

原始数据 → Dataset(存储样本和标签)→ DataLoader(批量、打乱、并行)→ 模型训练

Dataset 基础

Dataset 抽象类

torch.utils.data.Dataset 是一个抽象类,所有自定义数据集都需要继承它并实现以下三个方法:

方法说明
__init__(self, ...)初始化数据集,加载元数据(如文件路径列表)
__len__(self)返回数据集的大小,即样本总数
__getitem__(self, idx)根据索引 idx 获取单个样本
from torch.utils.data import Dataset

class MyDataset(Dataset):
def __init__(self, data, labels):
"""初始化数据集"""
self.data = data
self.labels = labels

def __len__(self):
"""返回数据集大小"""
return len(self.data)

def __getitem__(self, idx):
"""获取第 idx 个样本"""
return self.data[idx], self.labels[idx]

设计原则__getitem__ 中读取数据而不是 __init__ 中全部加载。这样只有在需要时才加载单个样本,避免一次性将所有数据加载到内存中,提高内存效率。

内置数据集

PyTorch 的领域库提供了许多预置数据集,这些数据集都继承自 Dataset 类:

import torchvision
from torchvision import datasets

# MNIST 手写数字数据集
mnist_train = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 训练集
download=True, # 自动下载
transform=None # 数据变换
)

mnist_test = datasets.MNIST(
root='./data',
train=False, # 测试集
download=True
)

# 访问单个样本
image, label = mnist_train[0]
print(f"图像类型: {type(image)}") # PIL.Image.Image
print(f"图像大小: {image.size}") # (28, 28)
print(f"标签: {label}") # 5

# 数据集大小
print(f"训练集大小: {len(mnist_train)}") # 60000
print(f"测试集大小: {len(mnist_test)}") # 10000

常用内置数据集

数据集说明类别数图像大小训练集测试集
MNIST手写数字1028×2860,00010,000
Fashion-MNIST服装分类1028×2860,00010,000
CIFAR-10物体分类1032×3250,00010,000
CIFAR-100物体分类10032×3250,00010,000
ImageNet大规模图像1000可变1.2M50K
COCO目标检测/分割80可变118K5K
# CIFAR-10 数据集
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True
)

# Fashion-MNIST 数据集
fashion_mnist = datasets.FashionMNIST(
root='./data',
train=True,
download=True
)

# ImageFolder:从文件夹结构自动加载数据
# 假设目录结构为:
# root/class1/xxx.jpg
# root/class2/xxx.jpg
dataset = datasets.ImageFolder(
root='./data/train',
transform=transforms.ToTensor()
)
print(f"类别: {dataset.classes}") # ['class1', 'class2']
print(f"类别到索引: {dataset.class_to_idx}") # {'class1': 0, 'class2': 1}

自定义 Dataset

当使用自己的数据时,需要创建自定义的 Dataset 类。这是实际项目中最常见的情况。

从 NumPy 数组创建

如果数据已经加载到内存中(如 NumPy 数组),可以直接封装:

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class NumpyDataset(Dataset):
"""从 NumPy 数组创建数据集"""

def __init__(self, data, labels, transform=None):
"""
Args:
data: numpy 数组,形状 (N, ...)
labels: numpy 数组,形状 (N,)
transform: 可选的数据变换
"""
self.data = torch.from_numpy(data).float()
self.labels = torch.from_numpy(labels).long()
self.transform = transform

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
x = self.data[idx]
y = self.labels[idx]

if self.transform:
x = self.transform(x)

return x, y

# 使用示例
data = np.random.randn(1000, 20).astype(np.float32)
labels = np.random.randint(0, 10, 1000).astype(np.int64)

dataset = NumpyDataset(data, labels)
print(f"数据集大小: {len(dataset)}")

# 访问样本
x, y = dataset[0]
print(f"样本形状: {x.shape}, 标签: {y}")

从图像文件夹创建

最常见的场景是从磁盘加载图像文件。我们只存储文件路径,在 __getitem__ 时才读取图像:

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageFolderDataset(Dataset):
"""从文件夹加载图像数据集"""

def __init__(self, root_dir, transform=None):
"""
Args:
root_dir: 数据根目录,子目录名作为类别名
transform: 可选的数据变换

目录结构:
root_dir/
class1/
img1.jpg
img2.jpg
class2/
img1.jpg
img2.jpg
"""
self.root_dir = root_dir
self.transform = transform

# 获取所有类别(子目录名)
self.classes = sorted([
d for d in os.listdir(root_dir)
if os.path.isdir(os.path.join(root_dir, d))
])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

# 收集所有图像路径和标签
self.samples = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name)
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
img_path = os.path.join(class_dir, img_name)
self.samples.append((img_path, self.class_to_idx[class_name]))

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
img_path, label = self.samples[idx]

# 读取图像
image = Image.open(img_path).convert('RGB')

# 应用变换
if self.transform:
image = self.transform(image)

return image, label

# 使用示例
from torchvision import transforms

transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])

dataset = ImageFolderDataset('./data/train', transform=transform)
print(f"类别: {dataset.classes}")
print(f"样本数: {len(dataset)}")

为什么在 __getitem__ 中读取图像?

这种方式被称为"延迟加载"(Lazy Loading)。如果一次性加载所有图像到内存,对于大型数据集(如 ImageNet 有 100 多万张图片),内存会很快耗尽。延迟加载只在需要时才读取,大大降低内存占用。

从 CSV 文件创建

当图像路径和标签存储在 CSV 文件中时:

import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CSVDataset(Dataset):
"""从 CSV 文件加载数据集"""

def __init__(self, csv_file, img_dir, transform=None):
"""
Args:
csv_file: CSV 文件路径,包含 filename 和 label 列
img_dir: 图像目录
transform: 数据变换

CSV 格式示例:
filename,label
img001.jpg,0
img002.jpg,1
"""
self.df = pd.read_csv(csv_file)
self.img_dir = img_dir
self.transform = transform

def __len__(self):
return len(self.df)

def __getitem__(self, idx):
# 获取文件名和标签
img_name = self.df.iloc[idx, 0] # filename 列
label = self.df.iloc[idx, 1] # label 列

# 构建完整路径并读取图像
img_path = os.path.join(self.img_dir, img_name)
image = Image.open(img_path).convert('RGB')

if self.transform:
image = self.transform(image)

return image, label

内存缓存优化

对于小数据集,可以将数据缓存到内存中加速访问:

class CachedDataset(Dataset):
"""带内存缓存的数据集"""

def __init__(self, samples, transform=None, cache_size=1000):
self.samples = samples
self.transform = transform
self.cache = {} # 缓存字典
self.cache_size = cache_size

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
# 检查缓存
if idx in self.cache:
data, label = self.cache[idx]
else:
# 从磁盘读取
data, label = self._load_sample(idx)

# 添加到缓存
if len(self.cache) < self.cache_size:
self.cache[idx] = (data, label)

if self.transform:
data = self.transform(data)

return data, label

def _load_sample(self, idx):
"""从磁盘加载样本"""
# 实际加载逻辑
pass

DataLoader 详解

DataLoader 是 Dataset 的包装器,提供了批量加载、数据打乱、多进程并行加载等功能。

基本使用

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(
dataset=train_dataset, # 数据集对象
batch_size=32, # 每批样本数
shuffle=True, # 是否打乱数据
num_workers=4, # 数据加载进程数
pin_memory=True, # 固定内存(加速 GPU 传输)
drop_last=False # 是否丢弃不完整的最后一批
)

# 迭代获取批次数据
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"批次 {batch_idx}:")
print(f" 图像形状: {images.shape}") # [32, C, H, W]
print(f" 标签形状: {labels.shape}") # [32]

# 训练步骤...

参数详解

参数类型默认值说明
datasetDataset必需数据集对象
batch_sizeint1每批样本数量
shuffleboolFalse是否每个 epoch 打乱数据
samplerSamplerNone自定义采样器,与 shuffle 互斥
batch_samplerSamplerNone自定义批次采样器
num_workersint0数据加载子进程数,0 表示主进程加载
collate_fncallableNone自定义批次整理函数
pin_memoryboolFalse将数据固定在内存,加速 GPU 传输
drop_lastboolFalse是否丢弃最后一个不完整的批次
timeoutfloat0子进程获取数据的超时时间
worker_init_fncallableNone子进程初始化函数
prefetch_factorint2每个工作进程预取的批次数量
persistent_workersboolFalse是否在工作进程间保持持久化

batch_size

批次大小影响训练速度和模型性能:

# 小批次:梯度估计更精确,但训练慢
loader = DataLoader(dataset, batch_size=8)

# 大批次:训练快,但可能降低泛化能力
loader = DataLoader(dataset, batch_size=256)

# 常见选择:根据 GPU 显存调整
# MNIST/CIFAR: 64-128
# ImageNet: 32-256
# 大模型: 可能需要更小

shuffle

训练时通常打乱数据以减少过拟合,验证和测试时不打乱:

# 训练集:打乱
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 验证集/测试集:不打乱
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

num_workers

多进程并行加载数据可以显著提高数据加载速度:

# 单进程加载(调试时使用)
loader = DataLoader(dataset, batch_size=32, num_workers=0)

# 多进程加载
# 推荐:CPU 核心数 - 1 或 CPU 核心数
loader = DataLoader(dataset, batch_size=32, num_workers=4)

注意事项

  • Windows 上建议设置 num_workers=0,多进程可能有问题
  • 过多的工作进程可能导致内存不足
  • 使用多进程时,数据加载代码需要能正确序列化

pin_memory

当使用 GPU 训练时,pin_memory=True 可以加速数据从 CPU 到 GPU 的传输:

# GPU 训练时推荐启用
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

# 数据自动从 pinned memory 传输到 GPU
for data, target in loader:
data = data.to(device, non_blocking=True) # 非阻塞传输
target = target.to(device, non_blocking=True)

Pinned memory(锁页内存)是 CPU 内存中不会被交换到磁盘的部分,GPU 可以直接通过 DMA(直接内存访问)从中读取数据,速度更快。

DataLoader 迭代方式

# 方式 1:for 循环(推荐)
for batch_idx, (data, target) in enumerate(train_loader):
# 处理批次数据
pass

# 方式 2:iter + next
data_iter = iter(train_loader)
first_batch = next(data_iter)
data, target = first_batch

# 方式 3:使用 range
num_batches = len(train_loader)
for i in range(num_batches):
batch = next(iter(train_loader)) # 不推荐,每次都创建新迭代器

获取 DataLoader 属性

# 批次数量
num_batches = len(train_loader)

# 每批样本数
batch_size = train_loader.batch_size

# 数据集大小
dataset_size = len(train_loader.dataset)

# 实际批次数量(考虑 drop_last)
if train_loader.drop_last:
actual_batches = dataset_size // batch_size
else:
actual_batches = (dataset_size + batch_size - 1) // batch_size

Sampler 采样器

Sampler 控制数据采样的顺序,DataLoader 的 shuffle 参数实际上是使用 RandomSampler 实现的。

内置采样器

from torch.utils.data import Sampler, SequentialSampler, RandomSampler

# SequentialSampler:顺序采样
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# RandomSampler:随机采样
sampler = RandomSampler(dataset, replacement=False, num_samples=None)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# 等价于 shuffle=True
loader = DataLoader(dataset, batch_size=32, shuffle=True)

WeightedRandomSampler:处理类别不平衡

当各类别样本数量差异较大时,使用 WeightedRandomSampler 可以让每个类别的采样概率更均衡:

from torch.utils.data import WeightedRandomSampler
import numpy as np

# 假设有 1000 个样本,类别 0 有 900 个,类别 1 有 100 个
labels = np.array([0] * 900 + [1] * 100)

# 计算每个类别的权重(与样本数量成反比)
class_counts = np.bincount(labels)
class_weights = 1.0 / class_counts # [0.00111, 0.01]

# 计算每个样本的权重
sample_weights = class_weights[labels] # 每个样本对应的权重

# 创建采样器
sampler = WeightedRandomSampler(
weights=torch.from_numpy(sample_weights).float(),
num_samples=len(labels), # 每个 epoch 采样的样本数
replacement=True # 是否允许重复采样
)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)

工作原理

每个样本被采样的概率与其权重成正比。如果类别 0 有 900 个样本,类别 1 有 100 个样本:

  • 类别 0 的采样权重:1/900 ≈ 0.00111
  • 类别 1 的采样权重:1/100 = 0.01

在 replacement=True 时,每个 epoch 中类别 0 和类别 1 的样本数量会趋于相等。

DistributedSampler:分布式训练

多 GPU 分布式训练时,需要确保每个 GPU 获取不同的数据子集:

from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist

# 初始化分布式环境
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank()
world_size = dist.get_world_size()

# 创建分布式采样器
sampler = DistributedSampler(
dataset,
num_replicas=world_size, # 总 GPU 数
rank=local_rank, # 当前 GPU 编号
shuffle=True
)

train_loader = DataLoader(
dataset,
batch_size=32,
sampler=sampler,
num_workers=4,
pin_memory=True
)

# 每个 epoch 开始时需要设置 sampler 的 epoch
for epoch in range(num_epochs):
sampler.set_epoch(epoch) # 确保每个 epoch 的打乱不同
for batch in train_loader:
# 训练...
pass

自定义采样器

可以创建自定义采样器实现特定的采样策略:

from torch.utils.data import Sampler

class SubsetSampler(Sampler):
"""只采样指定索引的样本"""

def __init__(self, indices):
self.indices = indices

def __iter__(self):
return iter(self.indices)

def __len__(self):
return len(self.indices)

class BalancedBatchSampler(Sampler):
"""每个批次包含相同数量的正负样本"""

def __init__(self, labels, batch_size):
self.labels = labels
self.batch_size = batch_size

# 分别获取正负样本索引
self.pos_indices = np.where(labels == 1)[0].tolist()
self.neg_indices = np.where(labels == 0)[0].tolist()

self.num_batches = min(len(self.pos_indices), len(self.neg_indices)) // (batch_size // 2)

def __iter__(self):
# 打乱正负样本索引
np.random.shuffle(self.pos_indices)
np.random.shuffle(self.neg_indices)

for i in range(self.num_batches):
batch = []
# 添加正样本
batch.extend(self.pos_indices[i*self.batch_size//2:(i+1)*self.batch_size//2])
# 添加负样本
batch.extend(self.neg_indices[i*self.batch_size//2:(i+1)*self.batch_size//2])
np.random.shuffle(batch) # 打乱批次内顺序
yield batch

def __len__(self):
return self.num_batches

collate_fn 自定义批处理

当数据集中的样本长度不一致(如变长序列)或需要特殊的批处理逻辑时,可以使用 collate_fn 自定义批处理方式。

默认 collate 行为

默认的 collate_fn 会将样本堆叠成一个批次:

# 假设数据集返回的样本形状一致
# Dataset 返回:(C, H, W) 和标量标签
# DataLoader 返回:(batch_size, C, H, W) 和 (batch_size,) 标签

# 默认行为相当于:
def default_collate(batch):
# batch 是一个列表:[(data1, label1), (data2, label2), ...]
data, labels = zip(*batch)
return torch.stack(data, dim=0), torch.stack(labels, dim=0)

处理变长序列

对于 RNN/LSTM 处理的变长文本序列,需要进行填充(Padding):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_variable_length(batch):
"""
处理变长序列的自定义 collate_fn

Args:
batch: [(seq1, label1), (seq2, label2), ...]
每个序列长度可能不同

Returns:
padded_seqs: (batch_size, max_seq_len, feature_dim)
lengths: (batch_size,) 每个序列的实际长度
labels: (batch_size,)
"""
# 分离序列和标签
sequences, labels = zip(*batch)

# 获取每个序列的长度
lengths = torch.tensor([len(seq) for seq in sequences])

# 填充序列到相同长度
# pad_sequence 需要 (seq_len, feature_dim) 形状
padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)

# 堆叠标签
labels = torch.tensor(labels)

return padded_seqs, lengths, labels

# 使用示例
class TextDataset(Dataset):
def __init__(self, texts, labels):
self.texts = texts # 列表,每个元素是 tensor
self.labels = labels

def __len__(self):
return len(self.texts)

def __getitem__(self, idx):
return self.texts[idx], self.labels[idx]

# 变长文本数据
texts = [
torch.randn(5, 100), # 序列长度 5
torch.randn(8, 100), # 序列长度 8
torch.randn(3, 100), # 序列长度 3
]
labels = [0, 1, 0]

dataset = TextDataset(texts, labels)
loader = DataLoader(dataset, batch_size=3, collate_fn=collate_fn_variable_length)

padded_seqs, lengths, labels = next(iter(loader))
print(f"填充后形状: {padded_seqs.shape}") # (3, 8, 100)
print(f"序列长度: {lengths}") # tensor([5, 8, 3])

使用 pack_padded_sequence

对于 RNN/LSTM,使用 pack_padded_sequence 可以更高效地处理变长序列:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def collate_fn_packed(batch):
"""返回 packed sequence 的 collate_fn"""
sequences, labels = zip(*batch)
lengths = torch.tensor([len(seq) for seq in sequences])
labels = torch.tensor(labels)

# 先填充
padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)

# 按长度降序排序(pack_padded_sequence 要求)
lengths, sorted_idx = lengths.sort(descending=True)
padded_seqs = padded_seqs[sorted_idx]
labels = labels[sorted_idx]

# 打包
packed_seqs = pack_padded_sequence(padded_seqs, lengths, batch_first=True)

return packed_seqs, labels

# 在模型中使用
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)

def forward(self, packed_seqs):
# LSTM 可以直接处理 packed sequence
packed_output, (hidden, cell) = self.rnn(packed_seqs)
return packed_output, hidden

处理字典类型的样本

当 Dataset 返回字典而非元组时:

def collate_fn_dict(batch):
"""
处理字典类型的样本

Args:
batch: [{'image': img1, 'label': label1, 'mask': mask1}, ...]
"""
return {
'image': torch.stack([item['image'] for item in batch]),
'label': torch.tensor([item['label'] for item in batch]),
'mask': torch.stack([item['mask'] for item in batch]),
}

class DictDataset(Dataset):
def __getitem__(self, idx):
return {
'image': self.images[idx],
'label': self.labels[idx],
'mask': self.masks[idx],
}

数据变换 (Transforms)

数据变换是预处理数据的关键步骤,包括数据增强、归一化、格式转换等。

torchvision.transforms

torchvision.transforms 提供了丰富的图像变换:

from torchvision import transforms

# 组合多个变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整大小
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转 ±15 度
transforms.ColorJitter( # 颜色抖动
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
transforms.ToTensor(), # 转为张量,归一化到 [0, 1]
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406], # ImageNet 均值
std=[0.229, 0.224, 0.225] # ImageNet 标准差
),
transforms.RandomErasing(p=0.5), # 随机擦除
])

# 验证/测试时通常不做数据增强
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])

常用变换分类

几何变换

# 调整大小
transforms.Resize(256) # 短边调整为 256
transforms.Resize((224, 224)) # 固定大小

# 裁剪
transforms.CenterCrop(224) # 中心裁剪
transforms.RandomCrop(224) # 随机裁剪
transforms.RandomResizedCrop(224) # 随机裁剪并缩放
transforms.FiveCrop(224) # 五裁剪(四角 + 中心)
transforms.TenCrop(224) # 十裁剪(五裁剪 + 翻转)

# 翻转
transforms.RandomHorizontalFlip(p=0.5) # 随机水平翻转
transforms.RandomVerticalFlip(p=0.5) # 随机垂直翻转

# 旋转和仿射
transforms.RandomRotation(degrees=30) # 随机旋转
transforms.RandomAffine( # 随机仿射变换
degrees=10, # 旋转角度
translate=(0.1, 0.1), # 平移
scale=(0.9, 1.1), # 缩放
shear=10 # 剪切
)
transforms.RandomPerspective(distortion_scale=0.5) # 随机透视变换

颜色变换

# 亮度、对比度、饱和度、色调
transforms.ColorJitter(
brightness=0.2, # 亮度变化范围 [max(0, 1-brightness), 1+brightness]
contrast=0.2, # 对比度
saturation=0.2, # 饱和度
hue=0.1 # 色调变化范围 [-hue, hue]
)

# 灰度化
transforms.Grayscale(num_output_channels=1) # 转为单通道灰度
transforms.Grayscale(num_output_channels=3) # 保持三通道但为灰度
transforms.RandomGrayscale(p=0.1) # 随机灰度化

格式转换

# PIL Image 或 ndarray 转 Tensor
# 将像素值从 [0, 255] 归一化到 [0, 1]
# 将 HWC 格式转为 CHW 格式
transforms.ToTensor()

# Tensor 转 PIL Image
transforms.ToPILImage(mode='RGB')

# 标准化
# output = (input - mean) / std
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# 反标准化(用于可视化)
class UnNormalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, tensor):
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
return tensor

unnorm = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

数据增强

# 随机擦除
transforms.RandomErasing(
p=0.5, # 擦除概率
scale=(0.02, 0.33), # 擦除区域面积占比范围
ratio=(0.3, 3.3), # 擦除区域长宽比范围
value=0 # 擦除填充值
)

# AutoAugment:自动增强策略
transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET)

# RandAugment:随机增强
transforms.RandAugment(num_ops=2, magnitude=9)

# TrivialAugmentWide:简单增强
transforms.TrivialAugmentWide()

# MixUp 和 CutMix 通常在训练循环中实现

自定义 Transform

创建自定义变换只需要实现 __call__ 方法:

class GaussianNoise:
"""添加高斯噪声"""

def __init__(self, mean=0.0, std=0.1):
self.mean = mean
self.std = std

def __call__(self, tensor):
noise = torch.randn(tensor.size()) * self.std + self.mean
return tensor + noise

class Cutout:
"""Cutout 数据增强"""

def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length

def __call__(self, img):
# img: Tensor of shape (C, H, W)
h, w = img.size(1), img.size(2)

mask = torch.ones((h, w), dtype=torch.float32)

for _ in range(self.n_holes):
y = torch.randint(h, (1,)).item()
x = torch.randint(w, (1,)).item()

y1 = max(0, y - self.length // 2)
y2 = min(h, y + self.length // 2)
x1 = max(0, x - self.length // 2)
x2 = min(w, x + self.length // 2)

mask[y1:y2, x1:x2] = 0.

mask = mask.expand_as(img)
return img * mask

# 使用自定义变换
transform = transforms.Compose([
transforms.ToTensor(),
GaussianNoise(std=0.05),
Cutout(n_holes=1, length=16),
])

在线增强 vs 离线增强

在线增强(Online Augmentation):在训练时实时进行增强,每次 epoch 图像都不同。

# 在线增强(推荐)
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
transform=transforms.Compose([
transforms.RandomHorizontalFlip(), # 每次访问都随机
transforms.ToTensor(),
])
)

离线增强(Offline Augmentation):提前生成增强后的图像并保存,适合数据量很少的情况。

# 离线增强
import os
from PIL import Image

original_image = Image.open('original.jpg')
augmented_images = []

for i in range(10):
# 应用增强
aug_img = transforms.RandomRotation(30)(original_image)
aug_img.save(f'augmented_{i}.jpg')

完整示例:训练数据加载流程

下面是一个完整的图像分类训练数据加载示例:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. 定义数据变换
# 训练集:数据增强
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并缩放
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # 颜色抖动
transforms.ToTensor(),
transforms.Normalize( # ImageNet 归一化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.5), # 随机擦除
])

# 验证集:不做数据增强
val_transform = transforms.Compose([
transforms.Resize(256), # 先放大
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])

# 2. 加载数据集
train_dataset = datasets.ImageFolder(
root='./data/train',
transform=train_transform
)

val_dataset = datasets.ImageFolder(
root='./data/val',
transform=val_transform
)

# 打印数据集信息
print(f"训练集大小: {len(train_dataset)}")
print(f"验证集大小: {len(val_dataset)}")
print(f"类别: {train_dataset.classes}")

# 3. 创建 DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True,
)

val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True,
)

# 4. 获取设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 5. 训练循环中使用
def train_epoch(model, loader, criterion, optimizer, device):
model.train()
total_loss = 0.0
correct = 0
total = 0

for images, labels in loader:
# 移动到设备
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)

# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)

# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 统计
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

return total_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion, device):
model.eval()
total_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
for images, labels in loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)

outputs = model(images)
loss = criterion(outputs, labels)

total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

return total_loss / len(loader), 100. * correct / total

高级数据加载技巧

处理类别不平衡

除了 WeightedRandomSampler,还可以在损失函数中设置类别权重:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 获取所有标签
labels = [train_dataset.targets[i] for i in range(len(train_dataset))]

# 计算类别权重
class_weights = compute_class_weight(
'balanced',
classes=np.unique(labels),
y=labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32)

# 使用加权交叉熵损失
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

数据预取和缓存

使用 prefetch_factorpersistent_workers 优化数据加载:

loader = DataLoader(
dataset,
batch_size=64,
num_workers=4,
prefetch_factor=2, # 每个工作进程预取 2 个批次
persistent_workers=True, # 保持工作进程,避免重复启动
)

IterableDataset:流式数据

对于超大数据集无法全部加载到内存时,可以使用 IterableDataset

from torch.utils.data import IterableDataset

class StreamDataset(IterableDataset):
"""流式数据集,适用于超大数据集"""

def __init__(self, file_path):
self.file_path = file_path

def __iter__(self):
# 打开文件并逐行读取
with open(self.file_path, 'r') as f:
for line in f:
# 处理每一行
data = self._process_line(line)
yield data

def _process_line(self, line):
# 解析数据
pass

# 使用方式相同
loader = DataLoader(stream_dataset, batch_size=32)

组合多个数据集

from torch.utils.data import ConcatDataset

# 合并多个数据集
dataset1 = datasets.CIFAR10(root='./data', train=True, transform=transform)
dataset2 = datasets.CIFAR100(root='./data', train=True, transform=transform)

combined_dataset = ConcatDataset([dataset1, dataset2])
print(f"合并后大小: {len(combined_dataset)}") # 100000 + 50000

loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

数据集划分

from torch.utils.data import random_split

# 划分训练集和验证集
full_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(
full_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42) # 固定随机种子
)

print(f"训练集大小: {len(train_dataset)}") # 40000
print(f"验证集大小: {len(val_dataset)}") # 10000

常见问题与解决方案

问题 1:DataLoader 迭代慢

原因:数据加载成为瓶颈

解决方案

# 1. 增加 num_workers
loader = DataLoader(dataset, num_workers=8)

# 2. 启用 pin_memory
loader = DataLoader(dataset, pin_memory=True)

# 3. 启用持久化工作进程
loader = DataLoader(dataset, persistent_workers=True)

# 4. 增加预取
loader = DataLoader(dataset, prefetch_factor=4)

# 5. 使用更快的存储(SSD)
# 6. 减少数据变换复杂度

问题 2:GPU 显存不足

原因:batch_size 太大

解决方案

# 1. 减小 batch_size
loader = DataLoader(dataset, batch_size=16)

# 2. 使用梯度累积
accumulation_steps = 4
for i, batch in enumerate(loader):
loss = model(batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

# 3. 及时清理缓存
torch.cuda.empty_cache()

问题 3:Windows 多进程问题

原因:Windows 使用 spawn 而非 fork 启动子进程

解决方案

# Windows 上设置 num_workers=0 或使用 if __name__ == '__main__': 保护
if __name__ == '__main__':
loader = DataLoader(dataset, num_workers=4)
for batch in loader:
pass

# 或者
loader = DataLoader(dataset, num_workers=0) # 单进程

问题 4:数据损坏或读取失败

解决方案

class SafeDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.corrupted = set()

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
if idx in self.corrupted:
# 返回一个替代样本
return self.dataset[0]

try:
return self.dataset[idx]
except Exception as e:
print(f"样本 {idx} 损坏: {e}")
self.corrupted.add(idx)
return self.dataset[0] # 返回替代样本

问题 5:内存不足

原因:数据集太大

解决方案

# 1. 使用延迟加载(在 __getitem__ 中读取)
# 2. 使用 IterableDataset 流式处理
# 3. 使用内存映射文件
import numpy as np

class MMapDataset(Dataset):
def __init__(self, mmap_path):
# 内存映射,不加载到内存
self.data = np.load(mmap_path, mmap_mode='r')

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
return torch.from_numpy(self.data[idx].copy())

小结

本章我们详细学习了 PyTorch 数据加载和处理的核心知识:

  1. Dataset:抽象数据集类,需要实现 __init____len____getitem__ 方法
  2. DataLoader:提供批量加载、打乱数据、多进程加载等功能
  3. Sampler:控制数据采样顺序,包括处理类别不平衡的 WeightedRandomSampler
  4. collate_fn:自定义批处理逻辑,处理变长序列等特殊情况
  5. Transforms:数据变换和数据增强,包括几何变换、颜色变换、格式转换等
  6. 高级技巧:内存缓存、流式处理、数据集划分、处理类别不平衡等

理解数据加载的原理和技巧,是构建高效训练流程的基础。下一章我们将学习完整的模型训练流程。

参考资料