模型保存与加载
训练好的模型需要保存下来以便后续使用或部署。本章将详细介绍 PyTorch 中模型保存和加载的各种方法、最佳实践以及常见问题。
保存加载基础
什么是模型状态?
在 PyTorch 中,一个神经网络模型包含以下关键组成部分:
- 模型参数:权重和偏置,存储在 state_dict 中
- 优化器状态:动量、学习率调度器状态等
- 模型架构:网络层的定义和连接方式
理解这些组成部分对于正确保存和加载模型至关重要。
state_dict 详解
state_dict 是一个 Python 字典,将每一层映射到其参数张量:
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """A minimal two-layer fully connected classifier (784 -> 256 -> 10)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        # fc1 -> ReLU -> fc2, returned as raw logits.
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNet()

# Learnable parameters, as exposed by named_parameters().
print("模型参数:")
for name, param in model.named_parameters():
    print(f" {name}: {param.shape}")

# The state_dict holds the same tensors keyed by qualified layer name;
# parameter-free layers such as ReLU do not appear in it.
print("\n模型 state_dict 的键:")
state = model.state_dict()
for key, tensor in state.items():
    print(f" {key}: {tensor.shape}")
输出结果:
模型参数:
fc1.weight: torch.Size([256, 784])
fc1.bias: torch.Size([256])
fc2.weight: torch.Size([10, 256])
fc2.bias: torch.Size([10])
模型 state_dict 的键:
fc1.weight: torch.Size([256, 784])
fc1.bias: torch.Size([256])
fc2.weight: torch.Size([10, 256])
fc2.bias: torch.Size([10])
注意 state_dict 只包含可学习参数,不包含激活函数(如 ReLU),因为激活函数没有参数。
保存和加载模型参数
只保存模型权重
这是最常用的保存方式,只保存模型的参数:
import torch
import torch.nn as nn
import torch.optim as optim
# Build a model (the optimizer only mirrors a realistic training setup here).
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Persist only the learnable parameters — the recommended approach.
torch.save(model.state_dict(), 'model_weights.pth')

# Restoring requires re-creating the architecture first, then loading weights.
model = SimpleNet()
state = torch.load('model_weights.pth')
model.load_state_dict(state)
model.eval()
print("模型加载成功!")
保存完整检查点
训练过程中通常需要保存更多信息以便断点续训:
import torch
import torch.nn as nn
import torch.optim as optim
# --- Saving a full training checkpoint ---
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

epoch = 10
loss = 0.5

# Bundle everything needed to resume training into one dictionary.
torch.save(
    {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    },
    'checkpoint.pth',
)

# --- Resuming: rebuild the objects first, then restore their states ---
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"从 epoch {epoch} 恢复,损失: {loss}")
保存多个模型
当需要保存多个模型时(如 GAN 的生成器和判别器):
import torch
import torch.nn as nn
class Generator(nn.Module):
    """GAN generator: maps a 100-dim input vector to a 784-dim output in [-1, 1]."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh(),  # Tanh squashes outputs into [-1, 1]
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class Discriminator(nn.Module):
    """GAN discriminator: scores a 784-dim input with a probability in [0, 1]."""

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # Sigmoid yields a single probability per sample
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
generator = Generator()
discriminator = Discriminator()

# A single file can hold several state_dicts under distinct keys.
torch.save(
    {
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
    },
    'gan_models.pth',
)

# Load: read the bundle once, then restore each model from its own key.
checkpoint = torch.load('gan_models.pth')
generator = Generator()
discriminator = Discriminator()
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
保存完整模型
保存整个模型对象
可以保存整个模型对象,包括架构和参数:
import torch
import torch.nn as nn
model = SimpleNet()

# Serialize the entire model object (architecture + parameters) via pickle.
torch.save(model, 'full_model.pth')

# NOTE: since PyTorch 2.6, torch.load defaults to weights_only=True, which
# rejects pickled model objects. Loading a whole model therefore requires
# weights_only=False — only do this for files from a trusted source.
loaded_model = torch.load('full_model.pth', weights_only=False)
loaded_model.eval()
print("完整模型加载成功!")
这种方式的优缺点
优点:
- 加载简单,不需要重新定义模型类
- 适合快速原型开发
缺点:
- 文件较大
- 依赖模型类的定义位置
- 可能存在版本兼容性问题
- 不推荐用于生产环境
import pickle
class MyModel(nn.Module):
    """Toy single-layer regressor used to demonstrate whole-model pickling."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


model = MyModel()
torch.save(model, 'model.pkl')

# Whole-model loading can fail when the class definition has moved or the
# PyTorch/pickle versions differ, so guard it.
try:
    loaded = torch.load('model.pkl')
except Exception as e:
    print(f"加载失败: {e}")
跨设备保存加载
在 GPU 上训练,在 CPU 上加载
import torch
import torch.nn as nn
model = SimpleNet()
# Move to GPU only when one is available; saving works the same either way.
if torch.cuda.is_available():
    model = model.cuda()
torch.save(model.state_dict(), 'gpu_model.pth')

# map_location='cpu' remaps any CUDA tensors in the file onto the CPU.
model_cpu = SimpleNet()
state = torch.load('gpu_model.pth', map_location='cpu')
model_cpu.load_state_dict(state)
model_cpu.eval()
print("GPU 模型加载到 CPU 成功!")
在 CPU 上训练,在 GPU 上加载
model = SimpleNet()
torch.save(model.state_dict(), 'cpu_model.pth')

# Pick the target device defensively: a bare torch.device('cuda') would make
# the subsequent torch.load fail on a CPU-only machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_gpu = SimpleNet()
# map_location places the loaded tensors directly on the target device...
model_gpu.load_state_dict(torch.load('cpu_model.pth', map_location=device))
# ...but the module itself still has to be moved there as well.
model_gpu = model_gpu.to(device)
print("CPU 模型加载到 GPU 成功!")
多 GPU 模型加载到单 GPU
当使用 DataParallel 训练的模型加载到单 GPU 时:
import torch.nn as nn
# DataParallel wraps the network, so every state_dict key gains a 'module.' prefix.
model = nn.DataParallel(SimpleNet())
if torch.cuda.is_available():
    model = model.cuda()
torch.save(model.state_dict(), 'multi_gpu_model.pth')

model_single = SimpleNet()
state_dict = torch.load('multi_gpu_model.pth', map_location='cpu')

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip only a *leading* 'module.' — k.replace('module.', '') would also
    # corrupt any key that merely contains that substring mid-name.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
model_single.load_state_dict(new_state_dict)
print("多 GPU 模型加载到单 GPU 成功!")
TorchScript 模型
什么是 TorchScript?
TorchScript 是 PyTorch 模型的中间表示,可以在没有 Python 环境的情况下运行,适合生产部署。
使用 torch.jit.trace
通过追踪模型执行来创建 TorchScript:
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Two-layer MLP (784 -> 256 -> 10) used for the tracing demo."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


model = SimpleNet()
model.eval()  # tracing should record inference-mode behavior

# Tracing runs the model once on an example input and records the executed ops.
example_input = torch.randn(1, 784)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('traced_model.pt')

# The saved TorchScript module can be reloaded without the Python class.
loaded_model = torch.jit.load('traced_model.pt')
output = loaded_model(example_input)
print(f"输出形状: {output.shape}")
使用 torch.jit.script
通过脚本化模型来创建 TorchScript,支持控制流:
import torch
import torch.nn as nn


class ControlFlowNet(nn.Module):
    """MLP whose forward pass contains data-independent control flow."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x, use_relu: bool = True):
        # The explicit `bool` annotation tells the TorchScript compiler this
        # parameter is a Python bool; unannotated parameters are assumed to
        # be Tensors, which would break calls like loaded_model(x, False).
        x = self.fc1(x)
        if use_relu:
            x = torch.relu(x)
        x = self.fc2(x)
        return x


model = ControlFlowNet()

# torch.jit.script compiles the code itself, so the if-branch is preserved.
scripted_model = torch.jit.script(model)
scripted_model.save('scripted_model.pt')

loaded_model = torch.jit.load('scripted_model.pt')
output1 = loaded_model(torch.randn(1, 784), True)
output2 = loaded_model(torch.randn(1, 784), False)
print("脚本化模型加载成功!")
trace 与 script 的区别
| 特性 | torch.jit.trace | torch.jit.script |
|---|---|---|
| 控制流 | 不支持 | 支持 |
| 使用方式 | 追踪执行路径 | 解析代码 |
| 适用场景 | 固定输入形状 | 动态控制流 |
| 性能 | 略快 | 略慢 |
ONNX 导出
导出为 ONNX 格式
ONNX 是开放的模型格式,可以在不同框架间转换:
import torch
import torch.nn as nn

model = SimpleNet()
model.eval()  # export the inference-mode graph

dummy_input = torch.randn(1, 784)

# dynamic_axes marks dim 0 as variable so the ONNX model accepts any batch size.
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,        # embed the trained weights in the file
    opset_version=11,
    do_constant_folding=True,  # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)
print("ONNX 模型导出成功!")
验证 ONNX 模型
import onnx

# Structural validation: checks the graph, opset and tensor metadata.
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")

import onnxruntime as ort
import numpy as np

# Numerical smoke test: run one inference through ONNX Runtime.
ort_session = ort.InferenceSession('model.onnx')
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 784).astype(np.float32)}
ort_outputs = ort_session.run(None, ort_inputs)
print(f"ONNX Runtime 输出形状: {ort_outputs[0].shape}")
最佳实践
文件命名规范
import os
from datetime import datetime
def save_checkpoint(model, optimizer, epoch, loss, directory='checkpoints'):
    """Save a training checkpoint under *directory* with a timestamped name.

    The file bundles model and optimizer state plus the epoch, loss and
    timestamp. Returns the path of the file that was written.
    """
    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filepath = os.path.join(directory, f'model_epoch_{epoch}_{timestamp}.pth')
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'timestamp': timestamp,
        },
        filepath,
    )
    print(f"检查点已保存: {filepath}")
    return filepath
只保存最佳模型
import torch
import os
class ModelSaver:
    """Persist model weights every epoch and track the best score so far.

    mode='min' treats lower scores as better (e.g. a loss); mode='max'
    treats higher scores as better (e.g. an accuracy).
    """

    def __init__(self, model, save_dir='checkpoints', mode='min'):
        self.model = model
        self.save_dir = save_dir
        self.mode = mode
        # Start from the worst possible score for the chosen direction.
        self.best_score = float('inf') if mode == 'min' else float('-inf')
        os.makedirs(save_dir, exist_ok=True)

    def __call__(self, score, epoch):
        """Record *score* for *epoch*; write best_model.pth on improvement."""
        if self.mode == 'min':
            improved = score < self.best_score
        else:
            improved = score > self.best_score

        if improved:
            self.best_score = score
            torch.save(
                self.model.state_dict(),
                os.path.join(self.save_dir, 'best_model.pth'),
            )
            print(f"新的最佳模型!Epoch {epoch}, Score: {score:.4f}")

        # A per-epoch snapshot is always written, best or not.
        torch.save(
            self.model.state_dict(),
            os.path.join(self.save_dir, f'model_epoch_{epoch}.pth'),
        )
# Usage sketch: track the best validation loss across a training run.
saver = ModelSaver(model, mode='min')
for epoch in range(100):
    val_loss = train_one_epoch()  # placeholder for the real training step
    saver(val_loss, epoch)
模型版本管理
import torch
import json
import os
from datetime import datetime
def save_model_with_metadata(model, save_dir, metrics, hyperparams):
    """Save *model*'s weights plus a JSON metadata file in a fresh version dir.

    The next version number is derived from the highest existing 'v<N>' entry
    rather than from a count of entries: counting collides after any old
    version is deleted (e.g. v1,v3 present -> count 2 -> 'v3' already exists)
    and also miscounts unrelated files that happen to start with 'v'.
    Returns the new version directory path.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Highest existing version number; non-'v<digits>' entries are ignored.
    existing = [
        int(entry[1:])
        for entry in os.listdir(save_dir)
        if entry.startswith('v') and entry[1:].isdigit()
    ]
    version = max(existing, default=0) + 1

    version_dir = os.path.join(save_dir, f'v{version}')
    os.makedirs(version_dir)  # deliberately fails if the dir already exists
    torch.save(model.state_dict(), os.path.join(version_dir, 'model.pth'))

    metadata = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'metrics': metrics,
        'hyperparams': hyperparams,
    }
    with open(os.path.join(version_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"模型已保存: {version_dir}")
    return version_dir
常见问题
问题 1:加载模型时键不匹配
原因:模型定义发生变化,或使用了 DataParallel
解决方案:
# Remap keys saved from a DataParallel model ('module.' prefix) back onto a
# plain model, then load strictly.
state_dict = torch.load('model.pth', map_location='cpu')
new_state_dict = {}
for k, v in state_dict.items():
    # Strip only a *leading* 'module.'; replace() would also mangle keys
    # that merely contain the substring somewhere in the middle.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict, strict=True)

# Alternative: strict=False ignores missing/unexpected keys entirely.
model.load_state_dict(state_dict, strict=False)
问题 2:加载后模型行为异常
原因:忘记调用 model.eval()
解决方案:
# Restore the weights, then switch to evaluation mode so that layers with
# train/eval-dependent behavior act deterministically during inference.
model.load_state_dict(torch.load('model.pth'))
model.eval()
# Disable autograd bookkeeping for pure inference.
with torch.no_grad():
    output = model(input)
问题 3:保存的模型文件过大
原因:保存了不必要的梯度信息
解决方案:
# Freeze the model before export: eval mode plus requires_grad=False makes
# clear that the saved file is intended for inference only.
model.eval()
for param in model.parameters():
    param.requires_grad = False
torch.save(model.state_dict(), 'model.pth')

# Explicitly request the zip-based serialization format (already the default
# since PyTorch 1.6).
torch.save(model.state_dict(), 'model.pth', _use_new_zipfile_serialization=True)
问题 4:跨版本兼容性
原因:PyTorch 版本不同
解决方案:
# Load defensively: the file may hold either a full checkpoint dict or a
# bare state_dict, depending on how it was produced.
checkpoint = torch.load('model.pth', map_location='cpu')
# Unwrap a full checkpoint; otherwise treat the file as a raw state_dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)
# strict=False tolerates keys added or removed across PyTorch versions.
model.load_state_dict(state_dict, strict=False)
完整示例:训练与保存
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
class Trainer:
    """Minimal training harness with checkpoint save/restore support."""

    def __init__(self, model, save_dir='checkpoints'):
        self.model = model
        self.save_dir = save_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)
        os.makedirs(save_dir, exist_ok=True)
        self.best_loss = float('inf')  # updated externally by the training loop

    def train_epoch(self, dataloader):
        """Run one optimization pass over *dataloader*; return the mean batch loss."""
        self.model.train()
        running = 0.0
        for inputs, targets in dataloader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            self.optimizer.step()
            running += loss.item()
        return running / len(dataloader)

    def validate(self, dataloader):
        """Evaluate without gradients; return (mean batch loss, accuracy)."""
        self.model.eval()
        running = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(inputs)
                running += self.criterion(outputs, targets).item()
                predictions = outputs.max(1)[1]  # argmax over the class dim
                seen += targets.size(0)
                correct += predictions.eq(targets).sum().item()
        return running / len(dataloader), correct / seen

    def save_checkpoint(self, epoch, val_loss, is_best=False):
        """Write a per-epoch checkpoint; refresh best_model.pth when is_best."""
        payload = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
        }
        torch.save(payload, os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pth'))
        if is_best:
            torch.save(payload, os.path.join(self.save_dir, 'best_model.pth'))
            print(f"保存最佳模型: val_loss={val_loss:.4f}")

    def load_checkpoint(self, path):
        """Restore model and optimizer state from *path*; return (epoch, val_loss)."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['val_loss']
# Synthetic data: 784-dim inputs with 10 classes.
x_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
x_val = torch.randn(200, 784)
y_val = torch.randint(0, 10, (200,))

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32)

model = SimpleNet()
trainer = Trainer(model)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = trainer.train_epoch(train_loader)
    val_loss, val_acc = trainer.validate(val_loader)

    # Track the best validation loss seen so far and flag the checkpoint.
    is_best = val_loss < trainer.best_loss
    if is_best:
        trainer.best_loss = val_loss
    trainer.save_checkpoint(epoch, val_loss, is_best)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train Loss={train_loss:.4f}, "
          f"Val Loss={val_loss:.4f}, "
          f"Val Acc={val_acc:.4f}")
小结
本章我们学习了:
- state_dict:理解模型参数的存储结构
- 保存方式:只保存权重、保存完整检查点、保存完整模型
- 跨设备加载:GPU 与 CPU 之间的模型转换
- TorchScript:用于生产部署的模型格式
- ONNX 导出:跨框架模型转换
- 最佳实践:命名规范、版本管理、只保存最佳模型