模型保存与加载
训练好的模型需要保存以便后续使用。本章介绍 TensorFlow 中模型保存和加载的各种方式。
保存整个模型
保存为 Keras 格式
import tensorflow as tf
from tensorflow import keras
# 训练好的模型
model = keras.Sequential([
keras.layers.Dense(64, activation='relu', input_shape=(784,)),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# 保存模型
model.save('my_model.keras')
# 加载模型
loaded_model = keras.models.load_model('my_model.keras')
# 验证
import numpy as np
x = np.random.randn(1, 784).astype(np.float32)
print(np.allclose(model.predict(x), loaded_model.predict(x))) # True
保存为 SavedModel 格式
# 保存为 SavedModel 格式
model.save('my_model')
# 加载模型
loaded_model = keras.models.load_model('my_model')
# SavedModel 也可以用 tf.saved_model 加载
loaded = tf.saved_model.load('my_model')
两种格式对比
| 特性 | Keras (.keras) | SavedModel |
|---|---|---|
| 文件结构 | 单文件 | 目录 |
| 跨平台 | TensorFlow/Keras | TensorFlow 全平台 |
| 部署 | 需要完整环境 | 支持 TensorFlow Serving |
| 自定义对象 | 需要注册 | 自动处理 |
只保存权重
如果只需要保存模型参数,可以只保存权重:
# 保存权重
model.save_weights('model_weights.weights.h5')
# 加载权重(需要先构建相同结构的模型)
model.load_weights('model_weights.weights.h5')
# 保存为 TensorFlow 格式
model.save_weights('model_weights')
model.load_weights('model_weights')
权重保存的应用场景
# 场景 1:训练中断后恢复
model = create_model()
model.load_weights('checkpoint.weights.h5') # 加载之前的权重
model.fit(x_train, y_train, epochs=10) # 继续训练
# 场景 2:迁移学习
base_model = keras.applications.ResNet50(weights='imagenet', include_top=False)
# ... 添加自定义层
model.load_weights('fine_tuned.weights.h5')
模型检查点
使用 ModelCheckpoint 回调在训练过程中自动保存模型:
# 每个 epoch 保存一次
checkpoint = keras.callbacks.ModelCheckpoint(
'model_epoch_{epoch}.keras',
save_freq='epoch'
)
# 只保存最佳模型
checkpoint_best = keras.callbacks.ModelCheckpoint(
'best_model.keras',
monitor='val_loss',
save_best_only=True,
mode='min'
)
# 只保存权重
checkpoint_weights = keras.callbacks.ModelCheckpoint(
'weights_epoch_{epoch}.weights.h5',
save_weights_only=True,
save_freq='epoch'
)
# 训练时使用
model.fit(
x_train, y_train,
epochs=10,
validation_data=(x_val, y_val),
callbacks=[checkpoint_best]
)
恢复训练
import os
# 检查是否有检查点
checkpoint_path = "training_checkpoint.weights.h5"
if os.path.exists(checkpoint_path):
model.load_weights(checkpoint_path)
print("从检查点恢复训练")
else:
print("从头开始训练")
# 训练并保存检查点
model.fit(
x_train, y_train,
epochs=10,
callbacks=[
keras.callbacks.ModelCheckpoint(
checkpoint_path,
save_weights_only=True,
save_freq='epoch'
)
]
)
保存模型架构
可以单独保存模型的结构:
# 获取模型配置
config = model.get_config()
# 从配置重建模型
new_model = keras.Model.from_config(config)
# 使用 JSON 格式
json_config = model.to_json()
with open('model_config.json', 'w') as f:
f.write(json_config)
# 从 JSON 加载
with open('model_config.json', 'r') as f:
json_config = f.read()
new_model = keras.models.model_from_json(json_config)
自定义对象的保存
如果模型包含自定义层或函数,需要注册:
# 自定义层
class CustomLayer(keras.layers.Layer):
def __init__(self, units=32, **kwargs):
super().__init__(**kwargs)
self.units = units
def build(self, input_shape):
self.w = self.add_weight(
shape=(input_shape[-1], self.units),
initializer='random_normal',
trainable=True
)
def call(self, inputs):
return tf.matmul(inputs, self.w)
# 创建模型
model = keras.Sequential([
CustomLayer(64),
keras.layers.Dense(10)
])
# 保存模型
model.save('custom_model.keras')
# 加载时注册自定义对象
with keras.utils.custom_object_scope({'CustomLayer': CustomLayer}):
loaded_model = keras.models.load_model('custom_model.keras')
# 或者使用字典方式
loaded_model = keras.models.load_model(
'custom_model.keras',
custom_objects={'CustomLayer': CustomLayer}
)
自定义损失函数
def custom_loss(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))
model.compile(optimizer='adam', loss=custom_loss)
model.save('model_with_custom_loss.keras')
# 加载
loaded_model = keras.models.load_model(
'model_with_custom_loss.keras',
custom_objects={'custom_loss': custom_loss}
)
SavedModel 详细说明
SavedModel 是 TensorFlow 的标准保存格式,支持多种部署场景:
# 保存
tf.saved_model.save(model, 'saved_model_dir')
# 加载
loaded = tf.saved_model.load('saved_model_dir')
# 查看签名
print(list(loaded.signatures.keys())) # ['serving_default']
# 使用签名进行预测
infer = loaded.signatures['serving_default']
result = infer(input_tensor)
查看 SavedModel 结构
# 使用命令行工具
saved_model_cli show --dir saved_model_dir --all
导出为 TensorFlow Lite
# 转换为 TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# 保存
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
导出为 TensorFlow.js
# 安装转换工具
pip install tensorflowjs
# 转换
tensorflowjs_converter --input_format keras model.keras tfjs_model
完整示例
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
# 创建模型
def create_model():
model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model
# 准备数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
# 创建检查点目录
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
# 创建模型
model = create_model()
# 检查是否有保存的权重
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
model.load_weights(latest)
print(f"从 {latest} 恢复权重")
# 设置回调
callbacks = [
keras.callbacks.ModelCheckpoint(
filepath=os.path.join(checkpoint_dir, 'ckpt_epoch_{epoch}.weights.h5'),
save_weights_only=True,
save_freq='epoch'
),
keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=3,
restore_best_weights=True
)
]
# 训练
history = model.fit(
x_train, y_train,
epochs=10,
batch_size=64,
validation_split=0.1,
callbacks=callbacks
)
# 评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"测试准确率: {test_acc:.4f}")
# 保存最终模型
model.save('final_model.keras')
# 保存权重
model.save_weights('final_weights.weights.h5')
# 保存配置
config = model.get_config()
# 加载并验证
loaded_model = keras.models.load_model('final_model.keras')
test_loss, test_acc = loaded_model.evaluate(x_test, y_test)
print(f"加载后测试准确率: {test_acc:.4f}")
小结
本章介绍了 TensorFlow 中模型保存和加载的各种方式:
- 保存整个模型:使用
model.save()保存完整模型 - 只保存权重:使用
save_weights()只保存参数 - 模型检查点:训练过程中自动保存
- 自定义对象:注册自定义层和函数
- SavedModel 格式:用于生产部署
选择合适的保存方式取决于使用场景:
- 开发调试:使用 Keras 格式
- 生产部署:使用 SavedModel 格式
- 迁移学习:只保存权重
下一章我们将创建一个实战项目,综合运用所学知识。