数据处理与 tf.data
高效的数据处理是深度学习训练的关键环节。TensorFlow 提供了 tf.data API 来构建灵活高效的数据管道,能够处理大规模数据集并充分利用硬件资源。
tf.data.Dataset 基础
tf.data.Dataset 是 TensorFlow 中表示数据集的核心抽象,它代表了一个元素序列,每个元素可以是一个或多个张量。
创建数据集
从张量创建数据集是最基本的方式:
import tensorflow as tf
import numpy as np
# 从单个张量创建
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
for element in dataset:
print(element.numpy())
# 输出: 1, 2, 3, 4, 5
# 从多个张量创建(特征和标签配对)
features = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
labels = np.array([0, 1, 0, 1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
for feature, label in dataset:
print(f"特征: {feature.numpy()}, 标签: {label.numpy()}")
# 从字典创建(适合结构化数据)
data = {
'features': np.random.randn(100, 10).astype(np.float32),
'labels': np.random.randint(0, 2, 100)
}
dataset = tf.data.Dataset.from_tensor_slices(data)
从 NumPy 数组创建
实际项目中,数据通常以 NumPy 数组形式存储:
# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
print(f"训练集大小: {len(train_dataset)}") # 60000
print(f"测试集大小: {len(test_dataset)}") # 10000
从生成器创建
当数据无法一次性加载到内存时,可以使用生成器:
def data_generator():
for i in range(100):
yield np.random.randn(10).astype(np.float32), i % 2
dataset = tf.data.Dataset.from_generator(
data_generator,
output_signature=(
tf.TensorSpec(shape=(10,), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)
for features, label in dataset.take(3):
print(f"特征形状: {features.shape}, 标签: {label.numpy()}")
从 TFRecord 文件创建
TFRecord 是 TensorFlow 推荐的高效存储格式,特别适合大规模数据集:
# 写入 TFRecord 文件
def serialize_example(feature, label):
feature_dict = {
'feature': tf.train.Feature(float_list=tf.train.FloatList(value=feature)),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
return example.SerializeToString()
# 创建 TFRecord 文件
with tf.io.TFRecordWriter('data.tfrecord') as writer:
for i in range(100):
feature = np.random.randn(10).astype(np.float32)
label = i % 2
writer.write(serialize_example(feature, label))
# 读取 TFRecord 文件
feature_description = {
'feature': tf.io.FixedLenFeature([10], tf.float32),
'label': tf.io.FixedLenFeature([], tf.int64)
}
def parse_example(example_proto):
return tf.io.parse_single_example(example_proto, feature_description)
dataset = tf.data.TFRecordDataset('data.tfrecord')
dataset = dataset.map(parse_example)
for example in dataset.take(2):
print(f"特征: {example['feature'].numpy()}")
print(f"标签: {example['label'].numpy()}")
数据集转换操作
tf.data.Dataset 提供了丰富的转换操作,可以灵活地处理数据。
batch - 分批处理
将数据分成批次是训练神经网络的基本操作:
dataset = tf.data.Dataset.from_tensor_slices(
np.arange(20).reshape(20, 1)
)
# 分成每批 5 个样本
batched = dataset.batch(5)
for i, batch in enumerate(batched):
print(f"批次 {i}: {batch.numpy().flatten()}")
# 批次 0: [0 1 2 3 4]
# 批次 1: [5 6 7 8 9]
# ...
# 不完整的批次处理
dataset = tf.data.Dataset.from_tensor_slices(np.arange(7))
batched = dataset.batch(3, drop_remainder=False) # 保留不完整批次
# [0, 1, 2], [3, 4, 5], [6]
batched = dataset.batch(3, drop_remainder=True) # 丢弃不完整批次
# [0, 1, 2], [3, 4, 5]
shuffle - 打乱数据
打乱数据可以避免模型学习到数据的顺序模式:
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
# buffer_size 决定打乱程度
# buffer_size 越大,打乱越彻底,但内存占用越高
shuffled = dataset.shuffle(buffer_size=10)
print("打乱后:", list(shuffled.as_numpy_iterator()))
# 对于大数据集,buffer_size 通常设为批次数量的若干倍
shuffled = dataset.shuffle(buffer_size=1000, seed=42, reshuffle_each_iteration=True)
shuffle 的工作原理:从缓冲区中随机抽取元素,然后从数据源补充新元素到缓冲区。buffer_size 设为数据集大小时可以实现完全随机打乱。
map - 数据预处理
map 方法可以对每个元素应用自定义函数,是数据预处理的核心操作:
# 加载图像数据
image_paths = ['img1.jpg', 'img2.jpg', 'img3.jpg']
labels = [0, 1, 0]
dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
def preprocess_image(path, label):
# 读取图像文件
image = tf.io.read_file(path)
# 解码图像
image = tf.image.decode_jpeg(image, channels=3)
# 调整大小
image = tf.image.resize(image, [224, 224])
# 归一化
image = tf.cast(image, tf.float32) / 255.0
return image, label
# 应用预处理
dataset = dataset.map(preprocess_image)
# 并行处理加速
dataset = dataset.map(
preprocess_image,
num_parallel_calls=tf.data.AUTOTUNE # 自动选择并行数
)
对于 MNIST 数据集的完整预处理示例:
def preprocess_mnist(x, y):
# 转换数据类型并归一化
x = tf.cast(x, tf.float32) / 255.0
# 添加通道维度
x = tf.expand_dims(x, axis=-1)
return x, y
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.map(preprocess_mnist, num_parallel_calls=tf.data.AUTOTUNE)
filter - 过滤数据
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
# 只保留偶数
filtered = dataset.filter(lambda x: x % 2 == 0)
print(list(filtered.as_numpy_iterator())) # [0, 2, 4, 6, 8]
# 过滤无效数据
def is_valid(features, label):
return tf.reduce_all(tf.math.is_finite(features))
dataset = dataset.filter(is_valid)
take 和 skip - 截取和跳过
dataset = tf.data.Dataset.from_tensor_slices(np.arange(100))
# 取前 10 个元素
first_10 = dataset.take(10)
print("前 10 个:", list(first_10.as_numpy_iterator()))
# 跳过前 10 个,取后面的
rest = dataset.skip(10).take(10)
print("第 11-20 个:", list(rest.as_numpy_iterator()))
# 分割训练集和验证集
train_data = dataset.take(80)
val_data = dataset.skip(80)
repeat - 重复数据
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
# 重复 3 次
repeated = dataset.repeat(3)
print(list(repeated.as_numpy_iterator())) # [1, 2, 3, 1, 2, 3, 1, 2, 3]
# 无限重复(训练时常用)
infinite = dataset.repeat()
prefetch - 预取数据
预取可以让数据准备和模型训练并行进行,显著提升训练效率:
dataset = tf.data.Dataset.from_tensor_slices(np.arange(100))
# 预取一个批次的数据
dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)
prefetch 的工作原理:当 GPU 在训练当前批次时,CPU 同时准备下一个批次的数据。AUTOTUNE 参数让 TensorFlow 自动选择最优的预取数量。
构建高效数据管道
将上述操作组合起来,可以构建高效的数据管道:
def create_data_pipeline(x, y, batch_size=32, shuffle=True, training=True):
"""
创建高效的数据管道
参数:
x: 特征数据
y: 标签数据
batch_size: 批次大小
shuffle: 是否打乱数据
training: 是否为训练模式
"""
dataset = tf.data.Dataset.from_tensor_slices((x, y))
if training and shuffle:
# 打乱数据,buffer_size 通常设为批次数量的若干倍
dataset = dataset.shuffle(buffer_size=10000)
# 数据预处理
def preprocess(x, y):
x = tf.cast(x, tf.float32) / 255.0
x = tf.expand_dims(x, axis=-1)
return x, y
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
# 分批
dataset = dataset.batch(batch_size)
if training:
# 训练时重复数据
dataset = dataset.repeat()
# 预取
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
# 使用示例
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
train_dataset = create_data_pipeline(x_train, y_train, batch_size=64, training=True)
test_dataset = create_data_pipeline(x_test, y_test, batch_size=64, training=False)
# 训练模型
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练时需要指定 steps_per_epoch
model.fit(
train_dataset,
epochs=5,
steps_per_epoch=len(x_train) // 64,
validation_data=test_dataset,
validation_steps=len(x_test) // 64
)
数据增强
图像数据增强是提高模型泛化能力的有效手段:
使用 Keras 预处理层
from tensorflow.keras import layers
# 定义数据增强层
data_augmentation = tf.keras.Sequential([
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
layers.RandomZoom(0.1),
layers.RandomContrast(0.1),
])
# 在数据管道中应用
def augment_image(image, label):
image = data_augmentation(image, training=True)
return image, label
# 创建数据集
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(50000)
train_dataset = train_dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(64).prefetch(tf.data.AUTOTUNE)
自定义数据增强函数
def custom_augment(image, label):
# 随机水平翻转
image = tf.image.random_flip_left_right(image)
# 随机调整亮度
image = tf.image.random_brightness(image, max_delta=0.2)
# 随机调整对比度
image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
# 随机裁剪后调整回原尺寸
image = tf.image.random_crop(image, size=[24, 24, 3])
image = tf.image.resize(image, [32, 32])
# 确保像素值在 [0, 1] 范围
image = tf.clip_by_value(image, 0.0, 1.0)
return image, label
处理文本数据
文本数据需要特殊的预处理流程:
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization
# 示例文本数据
texts = [
"这是一个很好的电影",
"服务态度太差了",
"产品质量不错",
"再也不买了",
]
labels = [1, 0, 1, 0] # 1: 正面, 0: 负面
# 创建文本向量化层
vectorize_layer = TextVectorization(
max_tokens=1000,
output_mode='int',
output_sequence_length=20
)
# 适配词汇表
vectorize_layer.adapt(texts)
# 创建数据集
dataset = tf.data.Dataset.from_tensor_slices((texts, labels))
def preprocess_text(text, label):
text = vectorize_layer(text)
return text, label
dataset = dataset.map(preprocess_text)
dataset = dataset.batch(2).prefetch(tf.data.AUTOTUNE)
# 构建模型
model = tf.keras.Sequential([
tf.keras.layers.Embedding(1000, 64),
tf.keras.layers.GlobalAveragePooling1D(),
tf.keras.layers.Dense(16, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
处理时序数据
时序数据需要创建滑动窗口:
def create_time_series_dataset(data, window_size, batch_size, shuffle=True):
"""
创建时序数据集
参数:
data: 时序数据 (numpy array)
window_size: 窗口大小
batch_size: 批次大小
shuffle: 是否打乱
"""
dataset = tf.data.Dataset.from_tensor_slices(data)
# 创建滑动窗口
dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
# 将窗口展平为批次
dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
# 分割特征和标签
dataset = dataset.map(lambda window: (window[:-1], window[-1]))
if shuffle:
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
return dataset
# 示例:预测下一个时间点的值
import numpy as np
# 生成模拟时序数据
time = np.arange(1000)
data = np.sin(time * 0.1) + np.random.randn(1000) * 0.1
train_dataset = create_time_series_dataset(data, window_size=20, batch_size=32)
# 构建 LSTM 模型
model = tf.keras.Sequential([
tf.keras.layers.LSTM(32, return_sequences=True, input_shape=(20, 1)),
tf.keras.layers.LSTM(16),
tf.keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')
model.fit(train_dataset, epochs=10)
性能优化技巧
数据管道性能优化清单
def optimized_pipeline(x, y, batch_size=32):
dataset = tf.data.Dataset.from_tensor_slices((x, y))
# 1. 打乱数据
dataset = dataset.shuffle(10000)
# 2. 并行预处理
dataset = dataset.map(
preprocess_fn,
num_parallel_calls=tf.data.AUTOTUNE
)
# 3. 缓存预处理结果(如果内存允许)
dataset = dataset.cache()
# 4. 分批
dataset = dataset.batch(batch_size)
# 5. 预取
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
使用 cache 提升性能
对于需要多次遍历的数据集,缓存可以避免重复计算:
# 缓存到内存
dataset = dataset.cache()
# 缓存到文件(适合大数据集)
dataset = dataset.cache('/path/to/cache')
# 注意:cache 应该在 shuffle 之前,否则每次 epoch 数据顺序相同
# 正确顺序:map -> cache -> shuffle -> batch -> prefetch
性能对比
import time
def benchmark(dataset, num_epochs=2):
start_time = time.time()
for epoch in range(num_epochs):
for batch in dataset:
pass
return time.time() - start_time
# 无优化
dataset_naive = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_naive = dataset_naive.batch(32)
# 优化后
dataset_optimized = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_optimized = dataset_optimized.shuffle(10000)
dataset_optimized = dataset_optimized.batch(32)
dataset_optimized = dataset_optimized.cache()
dataset_optimized = dataset_optimized.prefetch(tf.data.AUTOTUNE)
print(f"无优化: {benchmark(dataset_naive):.2f}秒")
print(f"优化后: {benchmark(dataset_optimized):.2f}秒")
常见问题与解决方案
内存不足
# 使用生成器而非加载全部数据
def large_data_generator(file_list):
for file_path in file_list:
data = np.load(file_path)
yield data
dataset = tf.data.Dataset.from_generator(
lambda: large_data_generator(file_list),
output_signature=tf.TensorSpec(shape=(None, 10), dtype=tf.float32)
)
数据不平衡
# 过采样少数类
def oversample(dataset, target_count):
# 统计各类别数量
# 对少数类进行重复采样
pass
# 或者在训练时使用类别权重
class_weights = {0: 1.0, 1: 5.0} # 少数类权重更高
model.fit(dataset, class_weight=class_weights)
调试数据管道
# 检查数据形状
for batch in dataset.take(1):
print(f"批次形状: {batch[0].shape}")
print(f"标签形状: {batch[1].shape}")
# 可视化数据
import matplotlib.pyplot as plt
for images, labels in dataset.take(1):
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
ax.imshow(images[i].numpy().squeeze(), cmap='gray')
ax.set_title(labels[i].numpy())
ax.axis('off')
plt.tight_layout()
plt.show()
小结
本章介绍了 TensorFlow 数据处理的核心内容:
- 创建数据集:from_tensor_slices、from_generator、TFRecord
- 数据转换:batch、shuffle、map、filter、take、skip、repeat
- 性能优化:prefetch、cache、并行处理
- 数据增强:图像增强、文本处理、时序数据
- 最佳实践:合理的操作顺序、内存管理、调试技巧
高效的数据管道是深度学习训练的基础,掌握 tf.data API 对于构建生产级应用至关重要。下一章我们将学习循环神经网络(RNN/LSTM)处理序列数据。