数据集和数据加载器
数据是深度学习的基础。PyTorch 提供了 torch.utils.data 模块,包含 Dataset 和 DataLoader 两个核心类,帮助我们高效地加载和处理数据。理解这两个类的工作原理对于构建高效的训练流程至关重要。
为什么需要 Dataset 和 DataLoader?
在机器学习中,处理数据样本的代码往往会变得复杂且难以维护。我们希望数据加载代码与模型训练代码分离,以提高代码的可读性和模块化程度。
Dataset 负责存储数据样本和对应的标签,它是一个抽象类,需要我们继承并实现特定方法。
DataLoader 在 Dataset 外部包装一个可迭代对象,提供批量加载、数据打乱、多进程加载等功能。
原始数据 → Dataset(存储样本和标签)→ DataLoader(批量、打乱、并行)→ 模型训练
Dataset 基础
Dataset 抽象类
torch.utils.data.Dataset 是一个抽象类,所有自定义数据集都需要继承它并实现以下三个方法:
| 方法 | 说明 |
|---|---|
__init__(self, ...) | 初始化数据集,加载元数据(如文件路径列表) |
__len__(self) | 返回数据集的大小,即样本总数 |
__getitem__(self, idx) | 根据索引 idx 获取单个样本 |
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
"""初始化数据集"""
self.data = data
self.labels = labels
def __len__(self):
"""返回数据集大小"""
return len(self.data)
def __getitem__(self, idx):
"""获取第 idx 个样本"""
return self.data[idx], self.labels[idx]
设计原则:__getitem__ 中读取数据而不是 __init__ 中全部加载。这样只有在需要时才加载单个样本,避免一次性将所有数据加载到内存中,提高内存效率。
内置数据集
PyTorch 的领域库提供了许多预置数据集,这些数据集都继承自 Dataset 类:
import torchvision
from torchvision import datasets
# MNIST 手写数字数据集
mnist_train = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 训练集
download=True, # 自动下载
transform=None # 数据变换
)
mnist_test = datasets.MNIST(
root='./data',
train=False, # 测试集
download=True
)
# 访问单个样本
image, label = mnist_train[0]
print(f"图像类型: {type(image)}") # PIL.Image.Image
print(f"图像大小: {image.size}") # (28, 28)
print(f"标签: {label}") # 5
# 数据集大小
print(f"训练集大小: {len(mnist_train)}") # 60000
print(f"测试集大小: {len(mnist_test)}") # 10000
常用内置数据集
| 数据集 | 说明 | 类别数 | 图像大小 | 训练集 | 测试集 |
|---|---|---|---|---|---|
| MNIST | 手写数字 | 10 | 28×28 | 60,000 | 10,000 |
| Fashion-MNIST | 服装分类 | 10 | 28×28 | 60,000 | 10,000 |
| CIFAR-10 | 物体分类 | 10 | 32×32 | 50,000 | 10,000 |
| CIFAR-100 | 物体分类 | 100 | 32×32 | 50,000 | 10,000 |
| ImageNet | 大规模图像 | 1000 | 可变 | 1.2M | 50K |
| COCO | 目标检测/分割 | 80 | 可变 | 118K | 5K |
# CIFAR-10 数据集
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True
)
# Fashion-MNIST 数据集
fashion_mnist = datasets.FashionMNIST(
root='./data',
train=True,
download=True
)
# ImageFolder:从文件夹结构自动加载数据
# 假设目录结构为:
# root/class1/xxx.jpg
# root/class2/xxx.jpg
dataset = datasets.ImageFolder(
root='./data/train',
transform=transforms.ToTensor()
)
print(f"类别: {dataset.classes}") # ['class1', 'class2']
print(f"类别到索引: {dataset.class_to_idx}") # {'class1': 0, 'class2': 1}
自定义 Dataset
当使用自己的数据时,需要创建自定义的 Dataset 类。这是实际项目中最常见的情况。
从 NumPy 数组创建
如果数据已经加载到内存中(如 NumPy 数组),可以直接封装:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class NumpyDataset(Dataset):
"""从 NumPy 数组创建数据集"""
def __init__(self, data, labels, transform=None):
"""
Args:
data: numpy 数组,形状 (N, ...)
labels: numpy 数组,形状 (N,)
transform: 可选的数据变换
"""
self.data = torch.from_numpy(data).float()
self.labels = torch.from_numpy(labels).long()
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
x = self.data[idx]
y = self.labels[idx]
if self.transform:
x = self.transform(x)
return x, y
# 使用示例
data = np.random.randn(1000, 20).astype(np.float32)
labels = np.random.randint(0, 10, 1000).astype(np.int64)
dataset = NumpyDataset(data, labels)
print(f"数据集大小: {len(dataset)}")
# 访问样本
x, y = dataset[0]
print(f"样本形状: {x.shape}, 标签: {y}")
从图像文件夹创建
最常见的场景是从磁盘加载图像文件。我们只存储文件路径,在 __getitem__ 时才读取图像:
import os
from PIL import Image
from torch.utils.data import Dataset
class ImageFolderDataset(Dataset):
"""从文件夹加载图像数据集"""
def __init__(self, root_dir, transform=None):
"""
Args:
root_dir: 数据根目录,子目录名作为类别名
transform: 可选的数据变换
目录结构:
root_dir/
class1/
img1.jpg
img2.jpg
class2/
img1.jpg
img2.jpg
"""
self.root_dir = root_dir
self.transform = transform
# 获取所有类别(子目录名)
self.classes = sorted([
d for d in os.listdir(root_dir)
if os.path.isdir(os.path.join(root_dir, d))
])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
# 收集所有图像路径和标签
self.samples = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name)
for img_name in os.listdir(class_dir):
if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
img_path = os.path.join(class_dir, img_name)
self.samples.append((img_path, self.class_to_idx[class_name]))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
# 读取图像
image = Image.open(img_path).convert('RGB')
# 应用变换
if self.transform:
image = self.transform(image)
return image, label
# 使用示例
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
dataset = ImageFolderDataset('./data/train', transform=transform)
print(f"类别: {dataset.classes}")
print(f"样本数: {len(dataset)}")
为什么在 __getitem__ 中读取图像?
这种方式被称为"延迟加载"(Lazy Loading)。如果一次性加载所有图像到内存,对于大型数据集(如 ImageNet 有 100 多万张图片),内存会很快耗尽。延迟加载只在需要时才读取,大大降低内存占用。
从 CSV 文件创建
当图像路径和标签存储在 CSV 文件中时:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
class CSVDataset(Dataset):
"""从 CSV 文件加载数据集"""
def __init__(self, csv_file, img_dir, transform=None):
"""
Args:
csv_file: CSV 文件路径,包含 filename 和 label 列
img_dir: 图像目录
transform: 数据变换
CSV 格式示例:
filename,label
img001.jpg,0
img002.jpg,1
"""
self.df = pd.read_csv(csv_file)
self.img_dir = img_dir
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# 获取文件名和标签
img_name = self.df.iloc[idx, 0] # filename 列
label = self.df.iloc[idx, 1] # label 列
# 构建完整路径并读取图像
img_path = os.path.join(self.img_dir, img_name)
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
内存缓存优化
对于小数据集,可以将数据缓存到内存中加速访问:
class CachedDataset(Dataset):
"""带内存缓存的数据集"""
def __init__(self, samples, transform=None, cache_size=1000):
self.samples = samples
self.transform = transform
self.cache = {} # 缓存字典
self.cache_size = cache_size
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
# 检查缓存
if idx in self.cache:
data, label = self.cache[idx]
else:
# 从磁盘读取
data, label = self._load_sample(idx)
# 添加到缓存
if len(self.cache) < self.cache_size:
self.cache[idx] = (data, label)
if self.transform:
data = self.transform(data)
return data, label
def _load_sample(self, idx):
"""从磁盘加载样本"""
# 实际加载逻辑
pass
DataLoader 详解
DataLoader 是 Dataset 的包装器,提供了批量加载、数据打乱、多进程并行加载等功能。
基本使用
from torch.utils.data import DataLoader
# 创建 DataLoader
train_loader = DataLoader(
dataset=train_dataset, # 数据集对象
batch_size=32, # 每批样本数
shuffle=True, # 是否打乱数据
num_workers=4, # 数据加载进程数
pin_memory=True, # 固定内存(加速 GPU 传输)
drop_last=False # 是否丢弃不完整的最后一批
)
# 迭代获取批次数据
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"批次 {batch_idx}:")
print(f" 图像形状: {images.shape}") # [32, C, H, W]
print(f" 标签形状: {labels.shape}") # [32]
# 训练步骤...
参数详解
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
dataset | Dataset | 必需 | 数据集对象 |
batch_size | int | 1 | 每批样本数量 |
shuffle | bool | False | 是否每个 epoch 打乱数据 |
sampler | Sampler | None | 自定义采样器,与 shuffle 互斥 |
batch_sampler | Sampler | None | 自定义批次采样器 |
num_workers | int | 0 | 数据加载子进程数,0 表示主进程加载 |
collate_fn | callable | None | 自定义批次整理函数 |
pin_memory | bool | False | 将数据固定在内存,加速 GPU 传输 |
drop_last | bool | False | 是否丢弃最后一个不完整的批次 |
timeout | float | 0 | 子进程获取数据的超时时间 |
worker_init_fn | callable | None | 子进程初始化函数 |
prefetch_factor | int | 2 | 每个工作进程预取的批次数量 |
persistent_workers | bool | False | 是否在工作进程间保持持久化 |
batch_size
批次大小影响训练速度和模型性能:
# 小批次:梯度估计更精确,但训练慢
loader = DataLoader(dataset, batch_size=8)
# 大批次:训练快,但可能降低泛化能力
loader = DataLoader(dataset, batch_size=256)
# 常见选择:根据 GPU 显存调整
# MNIST/CIFAR: 64-128
# ImageNet: 32-256
# 大模型: 可能需要更小
shuffle
训练时通常打乱数据以减少过拟合,验证和测试时不打乱:
# 训练集:打乱
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 验证集/测试集:不打乱
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
num_workers
多进程并行加载数据可以显著提高数据加载速度:
# 单进程加载(调试时使用)
loader = DataLoader(dataset, batch_size=32, num_workers=0)
# 多进程加载
# 推荐:CPU 核心数 - 1 或 CPU 核心数
loader = DataLoader(dataset, batch_size=32, num_workers=4)
注意事项:
- Windows 上建议设置
num_workers=0,多进程可能有问题 - 过多的工作进程可能导致内存不足
- 使用多进程时,数据加载代码需要能正确序列化
pin_memory
当使用 GPU 训练时,pin_memory=True 可以加速数据从 CPU 到 GPU 的传输:
# GPU 训练时推荐启用
loader = DataLoader(dataset, batch_size=32, pin_memory=True)
# 数据自动从 pinned memory 传输到 GPU
for data, target in loader:
data = data.to(device, non_blocking=True) # 非阻塞传输
target = target.to(device, non_blocking=True)
Pinned memory(锁页内存)是 CPU 内存中不会被交换到磁盘的部分,GPU 可以直接通过 DMA(直接内存访问)从中读取数据,速度更快。
DataLoader 迭代方式
# 方式 1:for 循环(推荐)
for batch_idx, (data, target) in enumerate(train_loader):
# 处理批次数据
pass
# 方式 2:iter + next
data_iter = iter(train_loader)
first_batch = next(data_iter)
data, target = first_batch
# 方式 3:使用 range
num_batches = len(train_loader)
for i in range(num_batches):
batch = next(iter(train_loader)) # 不推荐,每次都创建新迭代器
获取 DataLoader 属性
# 批次数量
num_batches = len(train_loader)
# 每批样本数
batch_size = train_loader.batch_size
# 数据集大小
dataset_size = len(train_loader.dataset)
# 实际批次数量(考虑 drop_last)
if train_loader.drop_last:
actual_batches = dataset_size // batch_size
else:
actual_batches = (dataset_size + batch_size - 1) // batch_size
Sampler 采样器
Sampler 控制数据采样的顺序,DataLoader 的 shuffle 参数实际上是使用 RandomSampler 实现的。
内置采样器
from torch.utils.data import Sampler, SequentialSampler, RandomSampler
# SequentialSampler:顺序采样
sampler = SequentialSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# RandomSampler:随机采样
sampler = RandomSampler(dataset, replacement=False, num_samples=None)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# 等价于 shuffle=True
loader = DataLoader(dataset, batch_size=32, shuffle=True)
WeightedRandomSampler:处理类别不平衡
当各类别样本数量差异较大时,使用 WeightedRandomSampler 可以让每个类别的采样概率更均衡:
from torch.utils.data import WeightedRandomSampler
import numpy as np
# 假设有 1000 个样本,类别 0 有 900 个,类别 1 有 100 个
labels = np.array([0] * 900 + [1] * 100)
# 计算每个类别的权重(与样本数量成反比)
class_counts = np.bincount(labels)
class_weights = 1.0 / class_counts # [0.00111, 0.01]
# 计算每个样本的权重
sample_weights = class_weights[labels] # 每个样本对应的权重
# 创建采样器
sampler = WeightedRandomSampler(
weights=torch.from_numpy(sample_weights).float(),
num_samples=len(labels), # 每个 epoch 采样的样本数
replacement=True # 是否允许重复采样
)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
工作原理:
每个样本被采样的概率与其权重成正比。如果类别 0 有 900 个样本,类别 1 有 100 个样本:
- 类别 0 的采样权重:1/900 ≈ 0.00111
- 类别 1 的采样权重:1/100 = 0.01
在 replacement=True 时,每个 epoch 中类别 0 和类别 1 的样本数量会趋于相等。
DistributedSampler:分布式训练
多 GPU 分布式训练时,需要确保每个 GPU 获取不同的数据子集:
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
# 初始化分布式环境
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank()
world_size = dist.get_world_size()
# 创建分布式采样器
sampler = DistributedSampler(
dataset,
num_replicas=world_size, # 总 GPU 数
rank=local_rank, # 当前 GPU 编号
shuffle=True
)
train_loader = DataLoader(
dataset,
batch_size=32,
sampler=sampler,
num_workers=4,
pin_memory=True
)
# 每个 epoch 开始时需要设置 sampler 的 epoch
for epoch in range(num_epochs):
sampler.set_epoch(epoch) # 确保每个 epoch 的打乱不同
for batch in train_loader:
# 训练...
pass
自定义采样器
可以创建自定义采样器实现特定的采样策略:
from torch.utils.data import Sampler
class SubsetSampler(Sampler):
"""只采样指定索引的样本"""
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
class BalancedBatchSampler(Sampler):
"""每个批次包含相同数量的正负样本"""
def __init__(self, labels, batch_size):
self.labels = labels
self.batch_size = batch_size
# 分别获取正负样本索引
self.pos_indices = np.where(labels == 1)[0].tolist()
self.neg_indices = np.where(labels == 0)[0].tolist()
self.num_batches = min(len(self.pos_indices), len(self.neg_indices)) // (batch_size // 2)
def __iter__(self):
# 打乱正负样本索引
np.random.shuffle(self.pos_indices)
np.random.shuffle(self.neg_indices)
for i in range(self.num_batches):
batch = []
# 添加正样本
batch.extend(self.pos_indices[i*self.batch_size//2:(i+1)*self.batch_size//2])
# 添加负样本
batch.extend(self.neg_indices[i*self.batch_size//2:(i+1)*self.batch_size//2])
np.random.shuffle(batch) # 打乱批次内顺序
yield batch
def __len__(self):
return self.num_batches
collate_fn 自定义批处理
当数据集中的样本长度不一致(如变长序列)或需要特殊的批处理逻辑时,可以使用 collate_fn 自定义批处理方式。
默认 collate 行为
默认的 collate_fn 会将样本堆叠成一个批次:
# 假设数据集返回的样本形状一致
# Dataset 返回:(C, H, W) 和标量标签
# DataLoader 返回:(batch_size, C, H, W) 和 (batch_size,) 标签
# 默认行为相当于:
def default_collate(batch):
# batch 是一个列表:[(data1, label1), (data2, label2), ...]
data, labels = zip(*batch)
return torch.stack(data, dim=0), torch.stack(labels, dim=0)
处理变长序列
对于 RNN/LSTM 处理的变长文本序列,需要进行填充(Padding):
import torch
from torch.nn.utils.rnn import pad_sequence
def collate_fn_variable_length(batch):
"""
处理变长序列的自定义 collate_fn
Args:
batch: [(seq1, label1), (seq2, label2), ...]
每个序列长度可能不同
Returns:
padded_seqs: (batch_size, max_seq_len, feature_dim)
lengths: (batch_size,) 每个序列的实际长度
labels: (batch_size,)
"""
# 分离序列和标签
sequences, labels = zip(*batch)
# 获取每个序列的长度
lengths = torch.tensor([len(seq) for seq in sequences])
# 填充序列到相同长度
# pad_sequence 需要 (seq_len, feature_dim) 形状
padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)
# 堆叠标签
labels = torch.tensor(labels)
return padded_seqs, lengths, labels
# 使用示例
class TextDataset(Dataset):
def __init__(self, texts, labels):
self.texts = texts # 列表,每个元素是 tensor
self.labels = labels
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
return self.texts[idx], self.labels[idx]
# 变长文本数据
texts = [
torch.randn(5, 100), # 序列长度 5
torch.randn(8, 100), # 序列长度 8
torch.randn(3, 100), # 序列长度 3
]
labels = [0, 1, 0]
dataset = TextDataset(texts, labels)
loader = DataLoader(dataset, batch_size=3, collate_fn=collate_fn_variable_length)
padded_seqs, lengths, labels = next(iter(loader))
print(f"填充后形状: {padded_seqs.shape}") # (3, 8, 100)
print(f"序列长度: {lengths}") # tensor([5, 8, 3])
使用 pack_padded_sequence
对于 RNN/LSTM,使用 pack_padded_sequence 可以更高效地处理变长序列:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
def collate_fn_packed(batch):
"""返回 packed sequence 的 collate_fn"""
sequences, labels = zip(*batch)
lengths = torch.tensor([len(seq) for seq in sequences])
labels = torch.tensor(labels)
# 先填充
padded_seqs = pad_sequence(sequences, batch_first=True, padding_value=0)
# 按长度降序排序(pack_padded_sequence 要求)
lengths, sorted_idx = lengths.sort(descending=True)
padded_seqs = padded_seqs[sorted_idx]
labels = labels[sorted_idx]
# 打包
packed_seqs = pack_padded_sequence(padded_seqs, lengths, batch_first=True)
return packed_seqs, labels
# 在模型中使用
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
def forward(self, packed_seqs):
# LSTM 可以直接处理 packed sequence
packed_output, (hidden, cell) = self.rnn(packed_seqs)
return packed_output, hidden
处理字典类型的样本
当 Dataset 返回字典而非元组时:
def collate_fn_dict(batch):
"""
处理字典类型的样本
Args:
batch: [{'image': img1, 'label': label1, 'mask': mask1}, ...]
"""
return {
'image': torch.stack([item['image'] for item in batch]),
'label': torch.tensor([item['label'] for item in batch]),
'mask': torch.stack([item['mask'] for item in batch]),
}
class DictDataset(Dataset):
def __getitem__(self, idx):
return {
'image': self.images[idx],
'label': self.labels[idx],
'mask': self.masks[idx],
}
数据变换 (Transforms)
数据变换是预处理数据的关键步骤,包括数据增强、归一化、格式转换等。
torchvision.transforms
torchvision.transforms 提供了丰富的图像变换:
from torchvision import transforms
# 组合多个变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整大小
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转 ±15 度
transforms.ColorJitter( # 颜色抖动
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
transforms.ToTensor(), # 转为张量,归一化到 [0, 1]
transforms.Normalize( # 标准化
mean=[0.485, 0.456, 0.406], # ImageNet 均值
std=[0.229, 0.224, 0.225] # ImageNet 标准差
),
transforms.RandomErasing(p=0.5), # 随机擦除
])
# 验证/测试时通常不做数据增强
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
常用变换分类
几何变换
# 调整大小
transforms.Resize(256) # 短边调整为 256
transforms.Resize((224, 224)) # 固定大小
# 裁剪
transforms.CenterCrop(224) # 中心裁剪
transforms.RandomCrop(224) # 随机裁剪
transforms.RandomResizedCrop(224) # 随机裁剪并缩放
transforms.FiveCrop(224) # 五裁剪(四角 + 中心)
transforms.TenCrop(224) # 十裁剪(五裁剪 + 翻转)
# 翻转
transforms.RandomHorizontalFlip(p=0.5) # 随机水平翻转
transforms.RandomVerticalFlip(p=0.5) # 随机垂直翻转
# 旋转和仿射
transforms.RandomRotation(degrees=30) # 随机旋转
transforms.RandomAffine( # 随机仿射变换
degrees=10, # 旋转角度
translate=(0.1, 0.1), # 平移
scale=(0.9, 1.1), # 缩放
shear=10 # 剪切
)
transforms.RandomPerspective(distortion_scale=0.5) # 随机透视变换
颜色变换
# 亮度、对比度、饱和度、色调
transforms.ColorJitter(
brightness=0.2, # 亮度变化范围 [max(0, 1-brightness), 1+brightness]
contrast=0.2, # 对比度
saturation=0.2, # 饱和度
hue=0.1 # 色调变化范围 [-hue, hue]
)
# 灰度化
transforms.Grayscale(num_output_channels=1) # 转为单通道灰度
transforms.Grayscale(num_output_channels=3) # 保持三通道但为灰度
transforms.RandomGrayscale(p=0.1) # 随机灰度化
格式转换
# PIL Image 或 ndarray 转 Tensor
# 将像素值从 [0, 255] 归一化到 [0, 1]
# 将 HWC 格式转为 CHW 格式
transforms.ToTensor()
# Tensor 转 PIL Image
transforms.ToPILImage(mode='RGB')
# 标准化
# output = (input - mean) / std
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# 反标准化(用于可视化)
class UnNormalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
return tensor
unnorm = UnNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
数据增强
# 随机擦除
transforms.RandomErasing(
p=0.5, # 擦除概率
scale=(0.02, 0.33), # 擦除区域面积占比范围
ratio=(0.3, 3.3), # 擦除区域长宽比范围
value=0 # 擦除填充值
)
# AutoAugment:自动增强策略
transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET)
# RandAugment:随机增强
transforms.RandAugment(num_ops=2, magnitude=9)
# TrivialAugmentWide:简单增强
transforms.TrivialAugmentWide()
# MixUp 和 CutMix 通常在训练循环中实现
自定义 Transform
创建自定义变换只需要实现 __call__ 方法:
class GaussianNoise:
"""添加高斯噪声"""
def __init__(self, mean=0.0, std=0.1):
self.mean = mean
self.std = std
def __call__(self, tensor):
noise = torch.randn(tensor.size()) * self.std + self.mean
return tensor + noise
class Cutout:
"""Cutout 数据增强"""
def __init__(self, n_holes=1, length=16):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
# img: Tensor of shape (C, H, W)
h, w = img.size(1), img.size(2)
mask = torch.ones((h, w), dtype=torch.float32)
for _ in range(self.n_holes):
y = torch.randint(h, (1,)).item()
x = torch.randint(w, (1,)).item()
y1 = max(0, y - self.length // 2)
y2 = min(h, y + self.length // 2)
x1 = max(0, x - self.length // 2)
x2 = min(w, x + self.length // 2)
mask[y1:y2, x1:x2] = 0.
mask = mask.expand_as(img)
return img * mask
# 使用自定义变换
transform = transforms.Compose([
transforms.ToTensor(),
GaussianNoise(std=0.05),
Cutout(n_holes=1, length=16),
])
在线增强 vs 离线增强
在线增强(Online Augmentation):在训练时实时进行增强,每次 epoch 图像都不同。
# 在线增强(推荐)
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
transform=transforms.Compose([
transforms.RandomHorizontalFlip(), # 每次访问都随机
transforms.ToTensor(),
])
)
离线增强(Offline Augmentation):提前生成增强后的图像并保存,适合数据量很少的情况。
# 离线增强
import os
from PIL import Image
original_image = Image.open('original.jpg')
augmented_images = []
for i in range(10):
# 应用增强
aug_img = transforms.RandomRotation(30)(original_image)
aug_img.save(f'augmented_{i}.jpg')
完整示例:训练数据加载流程
下面是一个完整的图像分类训练数据加载示例:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. 定义数据变换
# 训练集:数据增强
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪并缩放
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # 颜色抖动
transforms.ToTensor(),
transforms.Normalize( # ImageNet 归一化
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
transforms.RandomErasing(p=0.5), # 随机擦除
])
# 验证集:不做数据增强
val_transform = transforms.Compose([
transforms.Resize(256), # 先放大
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
# 2. 加载数据集
train_dataset = datasets.ImageFolder(
root='./data/train',
transform=train_transform
)
val_dataset = datasets.ImageFolder(
root='./data/val',
transform=val_transform
)
# 打印数据集信息
print(f"训练集大小: {len(train_dataset)}")
print(f"验证集大小: {len(val_dataset)}")
print(f"类别: {train_dataset.classes}")
# 3. 创建 DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=False,
num_workers=4,
pin_memory=True,
)
# 4. 获取设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 5. 训练循环中使用
def train_epoch(model, loader, criterion, optimizer, device):
model.train()
total_loss = 0.0
correct = 0
total = 0
for images, labels in loader:
# 移动到设备
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 统计
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return total_loss / len(loader), 100. * correct / total
def validate(model, loader, criterion, device):
model.eval()
total_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return total_loss / len(loader), 100. * correct / total
高级数据加载技巧
处理类别不平衡
除了 WeightedRandomSampler,还可以在损失函数中设置类别权重:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# 获取所有标签
labels = [train_dataset.targets[i] for i in range(len(train_dataset))]
# 计算类别权重
class_weights = compute_class_weight(
'balanced',
classes=np.unique(labels),
y=labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32)
# 使用加权交叉熵损失
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
数据预取和缓存
使用 prefetch_factor 和 persistent_workers 优化数据加载:
loader = DataLoader(
dataset,
batch_size=64,
num_workers=4,
prefetch_factor=2, # 每个工作进程预取 2 个批次
persistent_workers=True, # 保持工作进程,避免重复启动
)
IterableDataset:流式数据
对于超大数据集无法全部加载到内存时,可以使用 IterableDataset:
from torch.utils.data import IterableDataset
class StreamDataset(IterableDataset):
"""流式数据集,适用于超大数据集"""
def __init__(self, file_path):
self.file_path = file_path
def __iter__(self):
# 打开文件并逐行读取
with open(self.file_path, 'r') as f:
for line in f:
# 处理每一行
data = self._process_line(line)
yield data
def _process_line(self, line):
# 解析数据
pass
# 使用方式相同
loader = DataLoader(stream_dataset, batch_size=32)
组合多个数据集
from torch.utils.data import ConcatDataset
# 合并多个数据集
dataset1 = datasets.CIFAR10(root='./data', train=True, transform=transform)
dataset2 = datasets.CIFAR100(root='./data', train=True, transform=transform)
combined_dataset = ConcatDataset([dataset1, dataset2])
print(f"合并后大小: {len(combined_dataset)}") # 100000 + 50000
loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)
数据集划分
from torch.utils.data import random_split
# 划分训练集和验证集
full_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
full_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42) # 固定随机种子
)
print(f"训练集大小: {len(train_dataset)}") # 40000
print(f"验证集大小: {len(val_dataset)}") # 10000
常见问题与解决方案
问题 1:DataLoader 迭代慢
原因:数据加载成为瓶颈
解决方案:
# 1. 增加 num_workers
loader = DataLoader(dataset, num_workers=8)
# 2. 启用 pin_memory
loader = DataLoader(dataset, pin_memory=True)
# 3. 启用持久化工作进程
loader = DataLoader(dataset, persistent_workers=True)
# 4. 增加预取
loader = DataLoader(dataset, prefetch_factor=4)
# 5. 使用更快的存储(SSD)
# 6. 减少数据变换复杂度
问题 2:GPU 显存不足
原因:batch_size 太大
解决方案:
# 1. 减小 batch_size
loader = DataLoader(dataset, batch_size=16)
# 2. 使用梯度累积
accumulation_steps = 4
for i, batch in enumerate(loader):
loss = model(batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
# 3. 及时清理缓存
torch.cuda.empty_cache()
问题 3:Windows 多进程问题
原因:Windows 使用 spawn 而非 fork 启动子进程
解决方案:
# Windows 上设置 num_workers=0 或使用 if __name__ == '__main__': 保护
if __name__ == '__main__':
loader = DataLoader(dataset, num_workers=4)
for batch in loader:
pass
# 或者
loader = DataLoader(dataset, num_workers=0) # 单进程
问题 4:数据损坏或读取失败
解决方案:
class SafeDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
self.corrupted = set()
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
if idx in self.corrupted:
# 返回一个替代样本
return self.dataset[0]
try:
return self.dataset[idx]
except Exception as e:
print(f"样本 {idx} 损坏: {e}")
self.corrupted.add(idx)
return self.dataset[0] # 返回替代样本
问题 5:内存不足
原因:数据集太大
解决方案:
# 1. 使用延迟加载(在 __getitem__ 中读取)
# 2. 使用 IterableDataset 流式处理
# 3. 使用内存映射文件
import numpy as np
class MMapDataset(Dataset):
def __init__(self, mmap_path):
# 内存映射,不加载到内存
self.data = np.load(mmap_path, mmap_mode='r')
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return torch.from_numpy(self.data[idx].copy())
小结
本章我们详细学习了 PyTorch 数据加载和处理的核心知识:
- Dataset:抽象数据集类,需要实现
__init__、__len__、__getitem__方法 - DataLoader:提供批量加载、打乱数据、多进程加载等功能
- Sampler:控制数据采样顺序,包括处理类别不平衡的 WeightedRandomSampler
- collate_fn:自定义批处理逻辑,处理变长序列等特殊情况
- Transforms:数据变换和数据增强,包括几何变换、颜色变换、格式转换等
- 高级技巧:内存缓存、流式处理、数据集划分、处理类别不平衡等
理解数据加载的原理和技巧,是构建高效训练流程的基础。下一章我们将学习完整的模型训练流程。