跳到主要内容

模型保存与加载

训练好的模型需要保存以便后续使用。本章介绍 TensorFlow 中模型保存和加载的各种方式。

保存整个模型

保存为 Keras 格式

import tensorflow as tf
from tensorflow import keras

# 训练好的模型
model = keras.Sequential([
keras.layers.Dense(64, activation='relu', input_shape=(784,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# 保存模型
model.save('my_model.keras')

# 加载模型
loaded_model = keras.models.load_model('my_model.keras')

# 验证
import numpy as np
x = np.random.randn(1, 784).astype(np.float32)
print(np.allclose(model.predict(x), loaded_model.predict(x))) # True

保存为 SavedModel 格式

# 保存为 SavedModel 格式
model.save('my_model')

# 加载模型
loaded_model = keras.models.load_model('my_model')

# SavedModel 也可以用 tf.saved_model 加载
loaded = tf.saved_model.load('my_model')

两种格式对比

特性Keras (.keras)SavedModel
文件结构单文件目录
跨平台TensorFlow/KerasTensorFlow 全平台
部署需要完整环境支持 TensorFlow Serving
自定义对象需要注册自动处理

只保存权重

如果只需要保存模型参数,可以只保存权重:

# 保存权重
model.save_weights('model_weights.weights.h5')

# 加载权重(需要先构建相同结构的模型)
model.load_weights('model_weights.weights.h5')

# 保存为 TensorFlow 格式
model.save_weights('model_weights')
model.load_weights('model_weights')

权重保存的应用场景

# 场景 1:训练中断后恢复
model = create_model()
model.load_weights('checkpoint.weights.h5') # 加载之前的权重
model.fit(x_train, y_train, epochs=10) # 继续训练

# 场景 2:迁移学习
base_model = keras.applications.ResNet50(weights='imagenet', include_top=False)
# ... 添加自定义层
model.load_weights('fine_tuned.weights.h5')

模型检查点

使用 ModelCheckpoint 回调在训练过程中自动保存模型:

# 每个 epoch 保存一次
checkpoint = keras.callbacks.ModelCheckpoint(
'model_epoch_{epoch}.keras',
save_freq='epoch'
)

# 只保存最佳模型
checkpoint_best = keras.callbacks.ModelCheckpoint(
'best_model.keras',
monitor='val_loss',
save_best_only=True,
mode='min'
)

# 只保存权重
checkpoint_weights = keras.callbacks.ModelCheckpoint(
'weights_epoch_{epoch}.weights.h5',
save_weights_only=True,
save_freq='epoch'
)

# 训练时使用
model.fit(
x_train, y_train,
epochs=10,
validation_data=(x_val, y_val),
callbacks=[checkpoint_best]
)

恢复训练

import os

# 检查是否有检查点
checkpoint_path = "training_checkpoint.weights.h5"

if os.path.exists(checkpoint_path):
model.load_weights(checkpoint_path)
print("从检查点恢复训练")
else:
print("从头开始训练")

# 训练并保存检查点
model.fit(
x_train, y_train,
epochs=10,
callbacks=[
keras.callbacks.ModelCheckpoint(
checkpoint_path,
save_weights_only=True,
save_freq='epoch'
)
]
)

保存模型架构

可以单独保存模型的结构:

# 获取模型配置
config = model.get_config()

# 从配置重建模型
new_model = keras.Model.from_config(config)

# 使用 JSON 格式
json_config = model.to_json()
with open('model_config.json', 'w') as f:
f.write(json_config)

# 从 JSON 加载
with open('model_config.json', 'r') as f:
json_config = f.read()
new_model = keras.models.model_from_json(json_config)

自定义对象的保存

如果模型包含自定义层或函数,需要注册:

# 自定义层
class CustomLayer(keras.layers.Layer):
def __init__(self, units=32, **kwargs):
super().__init__(**kwargs)
self.units = units

def build(self, input_shape):
self.w = self.add_weight(
shape=(input_shape[-1], self.units),
initializer='random_normal',
trainable=True
)

def call(self, inputs):
return tf.matmul(inputs, self.w)

# 创建模型
model = keras.Sequential([
CustomLayer(64),
keras.layers.Dense(10)
])

# 保存模型
model.save('custom_model.keras')

# 加载时注册自定义对象
with keras.utils.custom_object_scope({'CustomLayer': CustomLayer}):
loaded_model = keras.models.load_model('custom_model.keras')

# 或者使用字典方式
loaded_model = keras.models.load_model(
'custom_model.keras',
custom_objects={'CustomLayer': CustomLayer}
)

自定义损失函数

def custom_loss(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))

model.compile(optimizer='adam', loss=custom_loss)
model.save('model_with_custom_loss.keras')

# 加载
loaded_model = keras.models.load_model(
'model_with_custom_loss.keras',
custom_objects={'custom_loss': custom_loss}
)

SavedModel 详细说明

SavedModel 是 TensorFlow 的标准保存格式,支持多种部署场景:

# 保存
tf.saved_model.save(model, 'saved_model_dir')

# 加载
loaded = tf.saved_model.load('saved_model_dir')

# 查看签名
print(list(loaded.signatures.keys())) # ['serving_default']

# 使用签名进行预测
infer = loaded.signatures['serving_default']
result = infer(input_tensor)

查看 SavedModel 结构

# 使用命令行工具
saved_model_cli show --dir saved_model_dir --all

导出为 TensorFlow Lite

# 转换为 TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# 保存
with open('model.tflite', 'wb') as f:
f.write(tflite_model)

导出为 TensorFlow.js

# 安装转换工具
pip install tensorflowjs

# 转换
tensorflowjs_converter --input_format keras model.keras tfjs_model

完整示例

import tensorflow as tf
from tensorflow import keras
import numpy as np
import os

# 创建模型
def create_model():
model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model

# 准备数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

# 创建检查点目录
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# 创建模型
model = create_model()

# 检查是否有保存的权重
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
model.load_weights(latest)
print(f"从 {latest} 恢复权重")

# 设置回调
callbacks = [
keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt_epoch_{epoch}.weights.h5'),
save_weights_only=True,
save_freq='epoch'
),
keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=3,
restore_best_weights=True
)
]

# 训练
history = model.fit(
x_train, y_train,
epochs=10,
batch_size=64,
validation_split=0.1,
callbacks=callbacks
)

# 评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc:.4f}")

# 保存最终模型
model.save('final_model.keras')

# 保存权重
model.save_weights('final_weights.weights.h5')

# 保存配置
config = model.get_config()

# 加载并验证
loaded_model = keras.models.load_model('final_model.keras')
test_loss, test_acc = loaded_model.evaluate(x_test, y_test)
print(f"加载后测试准确率: {test_acc:.4f}")

小结

本章介绍了 TensorFlow 中模型保存和加载的各种方式:

  1. 保存整个模型:使用 model.save() 保存完整模型
  2. 只保存权重:使用 save_weights() 只保存参数
  3. 模型检查点:训练过程中自动保存
  4. 自定义对象:注册自定义层和函数
  5. SavedModel 格式:用于生产部署

选择合适的保存方式取决于使用场景:

  • 开发调试:使用 Keras 格式
  • 生产部署:使用 SavedModel 格式
  • 迁移学习:只保存权重

下一章我们将创建一个实战项目,综合运用所学知识。