跳到主要内容

模型保存与加载

训练好的模型需要保存下来以便后续使用或部署。本章将详细介绍 PyTorch 中模型保存和加载的各种方法、最佳实践以及常见问题。

保存加载基础

什么是模型状态?

在 PyTorch 中,一个神经网络模型包含以下关键组成部分:

  • 模型参数:权重和偏置,存储在 state_dict
  • 优化器状态:动量、学习率调度器状态等
  • 模型架构:网络层的定义和连接方式

理解这些组成部分对于正确保存和加载模型至关重要。

state_dict 详解

state_dict 是一个 Python 字典,将每一层映射到其参数张量:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Minimal fully connected classifier: 784 -> 256 -> ReLU -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()

print("模型参数:")
for name, param in model.named_parameters():
    print(f" {name}: {param.shape}")

print("\n模型 state_dict 的键:")
# Fetch the state_dict once; the original rebuilt it on every loop iteration.
state_dict = model.state_dict()
for key in state_dict.keys():
    print(f" {key}: {state_dict[key].shape}")

输出结果:

模型参数:
fc1.weight: torch.Size([256, 784])
fc1.bias: torch.Size([256])
fc2.weight: torch.Size([10, 256])
fc2.bias: torch.Size([10])

模型 state_dict 的键:
fc1.weight: torch.Size([256, 784])
fc1.bias: torch.Size([256])
fc2.weight: torch.Size([10, 256])
fc2.bias: torch.Size([10])

注意 state_dict 只包含可学习参数,不包含激活函数(如 ReLU),因为激活函数没有参数。

保存和加载模型参数

只保存模型权重

这是最常用的保存方式,只保存模型的参数:

import torch
import torch.nn as nn
import torch.optim as optim

# Build the network and an optimizer (the optimizer state is not saved here).
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Persist only the learnable parameters — the recommended approach.
weights_path = 'model_weights.pth'
torch.save(model.state_dict(), weights_path)

# Re-create the architecture, then restore the saved parameters into it.
model = SimpleNet()
model.load_state_dict(torch.load(weights_path))
model.eval()  # switch to inference mode (affects dropout / batch-norm)

print("模型加载成功!")

保存完整检查点

训练过程中通常需要保存更多信息以便断点续训:

import torch
import torch.nn as nn
import torch.optim as optim

model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

epoch = 10
loss = 0.5

# Bundle everything needed to resume training into a single dictionary.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')

# Resuming: rebuild the objects first, then restore each saved state.
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"从 epoch {epoch} 恢复,损失: {loss}")

保存多个模型

当需要保存多个模型时(如 GAN 的生成器和判别器):

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Maps a 100-dim noise vector to a 784-dim sample in [-1, 1] (Tanh)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """Maps a 784-dim sample to a real/fake probability in [0, 1] (Sigmoid)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

generator = Generator()
discriminator = Discriminator()

# Save both networks under distinct keys in one checkpoint file.
torch.save({
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
}, 'gan_models.pth')

# Load: rebuild each architecture, then restore its own state_dict.
checkpoint = torch.load('gan_models.pth')
generator = Generator()
discriminator = Discriminator()
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

保存完整模型

保存整个模型对象

可以保存整个模型对象,包括架构和参数:

import torch
import torch.nn as nn

model = SimpleNet()

# Serialize the entire module object (architecture + parameters) via pickle.
torch.save(model, 'full_model.pth')

# Un-pickling a whole module requires weights_only=False on PyTorch >= 2.6,
# where torch.load defaults to the safer weights_only=True and would refuse
# to load arbitrary pickled objects.
loaded_model = torch.load('full_model.pth', weights_only=False)
loaded_model.eval()

print("完整模型加载成功!")

这种方式的优缺点

优点

  • 加载简单,不需要重新定义模型类
  • 适合快速原型开发

缺点

  • 文件较大
  • 依赖模型类的定义位置
  • 可能存在版本兼容性问题
  • 不推荐用于生产环境
import pickle  # torch.save uses pickle under the hood for whole-module saves

class MyModel(nn.Module):
    """Tiny one-layer model used to demonstrate whole-object serialization."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = MyModel()
torch.save(model, 'model.pkl')

# Loading can fail if the class definition moved or the PyTorch version
# differs — that fragility is exactly what this demo illustrates.
try:
    loaded = torch.load('model.pkl')
except Exception as e:
    print(f"加载失败: {e}")

跨设备保存加载

在 GPU 上训练,在 CPU 上加载

import torch
import torch.nn as nn

model = SimpleNet()

# Move to GPU when one is available; saving happens either way so the demo
# also works on CPU-only machines.
if torch.cuda.is_available():
    model = model.cuda()
torch.save(model.state_dict(), 'gpu_model.pth')

# map_location='cpu' remaps every stored tensor onto the CPU at load time,
# so a file written from CUDA tensors loads on a machine without a GPU.
model_cpu = SimpleNet()
model_cpu.load_state_dict(torch.load('gpu_model.pth', map_location='cpu'))
model_cpu.eval()

print("GPU 模型加载到 CPU 成功!")

在 CPU 上训练,在 GPU 上加载

model = SimpleNet()
torch.save(model.state_dict(), 'cpu_model.pth')

# Remap the stored (CPU) tensors directly onto the GPU while loading,
# then move the module's own parameters to the same device.
device = torch.device('cuda')
model_gpu = SimpleNet()
state = torch.load('cpu_model.pth', map_location=device)
model_gpu.load_state_dict(state)
model_gpu = model_gpu.to(device)

print("CPU 模型加载到 GPU 成功!")

多 GPU 模型加载到单 GPU

当使用 DataParallel 训练的模型加载到单 GPU 时:

import torch.nn as nn

# DataParallel wraps the network and prefixes every parameter key with
# 'module.' (e.g. 'fc1.weight' becomes 'module.fc1.weight').
model = nn.DataParallel(SimpleNet())
if torch.cuda.is_available():
    model = model.cuda()
torch.save(model.state_dict(), 'multi_gpu_model.pth')

model_single = SimpleNet()

# Strip the 'module.' prefix so the keys match the plain, unwrapped model.
state_dict = torch.load('multi_gpu_model.pth', map_location='cpu')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k.replace('module.', '')
    new_state_dict[name] = v

model_single.load_state_dict(new_state_dict)

print("多 GPU 模型加载到单 GPU 成功!")

TorchScript 模型

什么是 TorchScript?

TorchScript 是 PyTorch 模型的中间表示,可以在没有 Python 环境的情况下运行,适合生产部署。

使用 torch.jit.trace

通过追踪模型执行来创建 TorchScript:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Minimal fully connected classifier: 784 -> 256 -> ReLU -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()
model.eval()  # trace in inference mode so the recorded graph is stable

example_input = torch.randn(1, 784)

# trace() runs the model once and records the executed operations; the
# resulting module can run without the Python class definition.
traced_model = torch.jit.trace(model, example_input)

traced_model.save('traced_model.pt')

loaded_model = torch.jit.load('traced_model.pt')

output = loaded_model(example_input)
print(f"输出形状: {output.shape}")

使用 torch.jit.script

通过脚本化模型来创建 TorchScript,支持控制流:

import torch
import torch.nn as nn

class ControlFlowNet(nn.Module):
    """Network whose forward() contains data-dependent control flow."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x, use_relu: bool = True):
        x = self.fc1(x)
        if use_relu:
            x = torch.relu(x)
        x = self.fc2(x)
        return x

model = ControlFlowNet()

# script() compiles the Python source itself, so the if-branch is preserved
# (trace() would bake in only the path taken for the example input).
scripted_model = torch.jit.script(model)

scripted_model.save('scripted_model.pt')

loaded_model = torch.jit.load('scripted_model.pt')

output1 = loaded_model(torch.randn(1, 784), True)
output2 = loaded_model(torch.randn(1, 784), False)
print("脚本化模型加载成功!")

trace 与 script 的区别

| 特性 | torch.jit.trace | torch.jit.script |
| --- | --- | --- |
| 控制流 | 不支持 | 支持 |
| 使用方式 | 追踪执行路径 | 解析代码 |
| 适用场景 | 固定输入形状 | 动态控制流 |
| 性能 | 略快 | 略慢 |

ONNX 导出

导出为 ONNX 格式

ONNX 是开放的模型格式,可以在不同框架间转换:

import torch
import torch.nn as nn

model = SimpleNet()
model.eval()  # export in inference mode

dummy_input = torch.randn(1, 784)

# Mark the batch dimension as dynamic so the exported model accepts any
# batch size at inference time.
dynamic_axes = {
    'input': {0: 'batch_size'},
    'output': {0: 'batch_size'},
}

torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,         # embed the trained weights in the file
    opset_version=11,
    do_constant_folding=True,   # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=dynamic_axes,
)

print("ONNX 模型导出成功!")

验证 ONNX 模型

import onnx

# Structural validation: checks graph consistency and operator conformance.
loaded = onnx.load('model.onnx')
onnx.checker.check_model(loaded)
print("ONNX 模型验证通过!")

import onnxruntime as ort
import numpy as np

# Numerical smoke test through ONNX Runtime.
session = ort.InferenceSession('model.onnx')

input_name = session.get_inputs()[0].name
feed = {input_name: np.random.randn(1, 784).astype(np.float32)}
results = session.run(None, feed)

print(f"ONNX Runtime 输出形状: {results[0].shape}")

最佳实践

文件命名规范

import os
from datetime import datetime

def save_checkpoint(model, optimizer, epoch, loss, directory='checkpoints'):
    """Save a timestamped training checkpoint and return its file path.

    The checkpoint bundles model and optimizer state_dicts together with the
    epoch, loss and a creation timestamp, under a collision-free file name
    of the form ``model_epoch_{epoch}_{YYYYmmdd_HHMMSS}.pth``.
    """
    os.makedirs(directory, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f'model_epoch_{epoch}_{timestamp}.pth'
    filepath = os.path.join(directory, filename)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'timestamp': timestamp,
    }

    torch.save(checkpoint, filepath)
    print(f"检查点已保存: {filepath}")

    return filepath

只保存最佳模型

import torch
import os

class ModelSaver:
    """Saves a checkpoint every epoch and keeps a copy of the best model.

    ``mode='min'`` treats lower scores as better (e.g. loss);
    ``mode='max'`` treats higher scores as better (e.g. accuracy).
    """

    def __init__(self, model, save_dir='checkpoints', mode='min'):
        self.model = model
        self.save_dir = save_dir
        self.mode = mode
        # Initialize so the first reported score always counts as "best".
        self.best_score = float('inf') if mode == 'min' else float('-inf')
        os.makedirs(save_dir, exist_ok=True)

    def __call__(self, score, epoch):
        is_best = False
        if self.mode == 'min':
            is_best = score < self.best_score
        else:
            is_best = score > self.best_score

        if is_best:
            self.best_score = score
            path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(self.model.state_dict(), path)
            print(f"新的最佳模型!Epoch {epoch}, Score: {score:.4f}")

        # Also keep a per-epoch checkpoint regardless of the score.
        path = os.path.join(self.save_dir, f'model_epoch_{epoch}.pth')
        torch.save(self.model.state_dict(), path)

# Usage sketch: `model` and `train_one_epoch` are assumed to be provided by
# the surrounding training code (train_one_epoch is not defined in this file).
saver = ModelSaver(model, mode='min')
for epoch in range(100):
    val_loss = train_one_epoch()
    saver(val_loss, epoch)

模型版本管理

import torch
import json
import os
from datetime import datetime

def save_model_with_metadata(model, save_dir, metrics, hyperparams):
    """Save the model under a new version directory with a metadata sidecar.

    Creates ``save_dir/v{N}/`` (N = number of existing ``v*`` entries + 1)
    containing ``model.pth`` (state_dict) and ``metadata.json`` (version,
    ISO timestamp, metrics, hyperparameters). Returns the version directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Next version number = count of existing 'v*' entries + 1.
    version = len([f for f in os.listdir(save_dir) if f.startswith('v')]) + 1
    version_dir = os.path.join(save_dir, f'v{version}')
    os.makedirs(version_dir)

    torch.save(model.state_dict(), os.path.join(version_dir, 'model.pth'))

    metadata = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'metrics': metrics,
        'hyperparams': hyperparams,
    }

    with open(os.path.join(version_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"模型已保存: {version_dir}")
    return version_dir

常见问题

问题 1:加载模型时键不匹配

原因:模型定义发生变化,或使用了 DataParallel

解决方案

# Keys saved from a DataParallel model carry a 'module.' prefix.
state_dict = torch.load('model.pth', map_location='cpu')

# Option A: strip the prefix, then load strictly so any remaining mismatch
# is still reported as an error.
new_state_dict = {}
for k, v in state_dict.items():
    name = k.replace('module.', '')
    new_state_dict[name] = v

model.load_state_dict(new_state_dict, strict=True)

# Option B: strict=False silently ignores missing/unexpected keys — skipped
# parameters keep their randomly initialized values, so use with care.
model.load_state_dict(state_dict, strict=False)

问题 2:加载后模型行为异常

原因:忘记调用 model.eval()

解决方案

model.load_state_dict(torch.load('model.pth'))
model.eval()  # disable dropout and switch batch-norm to running statistics

with torch.no_grad():  # skip autograd bookkeeping during inference
    output = model(input)

问题 3:保存的模型文件过大

原因:保存了不必要的梯度信息

解决方案

model.eval()
# NOTE(review): freezing parameters stops future gradient tracking, but a
# state_dict stores plain tensors, so this does not by itself shrink the
# saved file — confirm the intent here.
for param in model.parameters():
    param.requires_grad = False

torch.save(model.state_dict(), 'model.pth')

# Zipfile-based serialization (the default since PyTorch 1.6) compresses
# the tensors and usually produces smaller files.
torch.save(model.state_dict(), 'model.pth', _use_new_zipfile_serialization=True)

问题 4:跨版本兼容性

原因:PyTorch 版本不同

解决方案

# Load on CPU first so the file opens regardless of the saving device.
checkpoint = torch.load('model.pth', map_location='cpu')

# Accept both formats: a full checkpoint dict or a bare state_dict.
if 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint

# strict=False tolerates keys added/removed between versions.
model.load_state_dict(state_dict, strict=False)

完整示例:训练与保存

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

class Trainer:
    """Wraps a model with training, validation and checkpointing logic."""

    def __init__(self, model, save_dir='checkpoints'):
        self.model = model
        self.save_dir = save_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)

        os.makedirs(save_dir, exist_ok=True)
        self.best_loss = float('inf')

    def train_epoch(self, dataloader):
        """Run one training pass over dataloader; return the mean batch loss."""
        self.model.train()
        total_loss = 0

        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(self.device)
            batch_y = batch_y.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(batch_x)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(dataloader)

    def validate(self, dataloader):
        """Evaluate without gradients; return (mean batch loss, accuracy)."""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                outputs = self.model(batch_x)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()

                # Class prediction = argmax over logits.
                _, predicted = outputs.max(1)
                total += batch_y.size(0)
                correct += predicted.eq(batch_y).sum().item()

        return total_loss / len(dataloader), correct / total

    def save_checkpoint(self, epoch, val_loss, is_best=False):
        """Save a per-epoch checkpoint; duplicate as best_model.pth if flagged."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
        }

        path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(checkpoint, path)

        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"保存最佳模型: val_loss={val_loss:.4f}")

    def load_checkpoint(self, path):
        """Restore model/optimizer state from path; return (epoch, val_loss)."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['val_loss']

# Synthetic data standing in for a real dataset.
x_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
x_val = torch.randn(200, 784)
y_val = torch.randint(0, 10, (200,))

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=32)

model = SimpleNet()
trainer = Trainer(model)

num_epochs = 10
for epoch in range(num_epochs):
    train_loss = trainer.train_epoch(train_loader)
    val_loss, val_acc = trainer.validate(val_loader)

    # Track the best validation loss and flag the checkpoint accordingly.
    is_best = val_loss < trainer.best_loss
    if is_best:
        trainer.best_loss = val_loss

    trainer.save_checkpoint(epoch, val_loss, is_best)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train Loss={train_loss:.4f}, "
          f"Val Loss={val_loss:.4f}, "
          f"Val Acc={val_acc:.4f}")

小结

本章我们学习了:

  1. state_dict:理解模型参数的存储结构
  2. 保存方式:只保存权重、保存完整检查点、保存完整模型
  3. 跨设备加载:GPU 与 CPU 之间的模型转换
  4. TorchScript:用于生产部署的模型格式
  5. ONNX 导出:跨框架模型转换
  6. 最佳实践:命名规范、版本管理、只保存最佳模型

参考资源