模型部署
训练好的深度学习模型最终需要部署到生产环境中。本章将介绍 PyTorch 模型的各种部署方式,包括 TorchScript、ONNX、TensorRT 以及移动端部署。
部署概述
部署方式对比
| 部署方式 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| PyTorch 原生 | 简单直接 | 依赖 Python | 研究、原型开发 |
| TorchScript | 无需 Python | 调试困难 | 服务端部署 |
| ONNX | 跨框架 | 兼容性问题 | 跨平台部署 |
| TensorRT | 高性能 | 仅限 NVIDIA GPU | 高性能推理 |
| 移动端部署 | 轻量级 | 功能受限 | 移动应用 |
部署流程
训练模型 → 导出模型 → 优化模型 → 部署服务
↓ ↓ ↓ ↓
PyTorch TorchScript 量化/剪枝 服务端/移动端
ONNX TensorRT
TorchScript 部署
使用 torch.jit.trace
trace 通过记录实际执行的操作来创建模型:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    """Small CNN: two 3x3 convs, global average pooling, 10-way linear head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        # conv -> relu -> conv -> relu -> global-pool -> flatten -> fc
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)
# Trace the model by recording the ops executed on a representative input.
model = SimpleModel()
model.eval()

sample = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, sample)

print("Traced 模型结构:")
print(traced_model.graph)

# Serialize, reload, and run the traced module to confirm it round-trips.
traced_model.save('traced_model.pt')
loaded_model = torch.jit.load('traced_model.pt')
output = loaded_model(sample)
print(f"输出形状: {output.shape}")
使用 torch.jit.script
script 通过解析 Python 代码来创建模型,支持控制流:
import torch
import torch.nn as nn
class ScriptableModel(nn.Module):
    """Two-layer MLP whose forward has data-independent control flow,
    which torch.jit.script preserves (trace would bake one branch in)."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, use_relu: bool = True) -> torch.Tensor:
        hidden = self.fc1(x)
        if use_relu:
            hidden = self.relu(hidden)
        return self.fc2(hidden)
model = ScriptableModel(784, 256, 10)
model.eval()

# script compiles the Python source, so the if-branch survives export.
scripted_model = torch.jit.script(model)
print("Scripted 模型结构:")
print(scripted_model.code)

scripted_model.save('scripted_model.pt')
loaded_model = torch.jit.load('scripted_model.pt')

# The loaded module still honors the boolean flag at runtime.
sample = torch.randn(1, 784)
output1 = loaded_model(sample, True)
output2 = loaded_model(sample, False)
print(f"使用 ReLU 输出: {output1.shape}")
print(f"不使用 ReLU 输出: {output2.shape}")
trace 与 script 的选择
| 特性 | trace | script |
|---|---|---|
| 控制流 | 不支持 | 支持 |
| 动态形状 | 有限支持 | 支持 |
| 使用难度 | 简单 | 中等 |
| 调试 | 困难 | 较容易 |
| 适用场景 | 固定计算图 | 动态计算图 |
混合使用 trace 和 script
import torch
import torch.nn as nn
class FeatureExtractor(nn.Module):
    """Convolutional backbone producing a flat 128-dim feature vector."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        feat = self.relu(self.conv1(x))
        feat = self.relu(self.conv2(feat))
        pooled = self.pool(feat)
        # (N, 128, 1, 1) -> (N, 128)
        return pooled.view(pooled.size(0), -1)
class Classifier(nn.Module):
    """Linear head returning either raw logits or softmax probabilities."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x, return_probs: bool = True):
        scores = self.fc(x)
        if not return_probs:
            return scores
        return torch.softmax(scores, dim=-1)
class HybridModel(nn.Module):
    """Combines a TRACED feature extractor (fixed graph) with a SCRIPTED
    classifier (keeps its boolean control flow)."""

    def __init__(self):
        super().__init__()
        dummy = torch.randn(1, 3, 224, 224)
        self.features = torch.jit.trace(FeatureExtractor(), dummy)
        self.classifier = torch.jit.script(Classifier(128, 10))

    def forward(self, x, return_probs: bool = True):
        return self.classifier(self.features(x), return_probs)
model = HybridModel()
model.eval()

# The whole hybrid module can itself be scripted and saved as one artifact.
scripted_model = torch.jit.script(model)
scripted_model.save('hybrid_model.pt')

loaded = torch.jit.load('hybrid_model.pt')
output = loaded(torch.randn(1, 3, 224, 224), True)
print(f"输出形状: {output.shape}")
ONNX 导出
基本导出
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    """Tiny CNN used to demonstrate ONNX export."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        feat = self.relu(self.conv1(x))
        feat = self.relu(self.conv2(feat))
        pooled = self.pool(feat)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
model = SimpleModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

# Export with a symbolic batch dimension so any batch size works at inference.
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,          # bake the weights into the file
    opset_version=11,
    do_constant_folding=True,    # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)
print("ONNX 模型导出成功!")
验证 ONNX 模型
import onnx

# Load the exported file and run ONNX's structural validator.
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")

# Loop variables renamed: the original used `input`/`output`, which
# shadow Python built-ins.
print("\n模型输入:")
for graph_input in onnx_model.graph.input:
    print(f" 名称: {graph_input.name}")
    print(f" 类型: {graph_input.type}")

print("\n模型输出:")
for graph_output in onnx_model.graph.output:
    print(f" 名称: {graph_output.name}")
ONNX Runtime 推理
import onnxruntime as ort
import numpy as np
import torch

ort_session = ort.InferenceSession('model.onnx')

# Inspect the session's I/O signature (names come from the export call).
print("ONNX Runtime 输入信息:")
for sess_input in ort_session.get_inputs():
    print(f" 名称: {sess_input.name}, 形状: {sess_input.shape}, 类型: {sess_input.type}")

print("\nONNX Runtime 输出信息:")
for sess_output in ort_session.get_outputs():
    print(f" 名称: {sess_output.name}, 形状: {sess_output.shape}, 类型: {sess_output.type}")

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)
print(f"\nONNX Runtime 输出形状: {ort_outputs[0].shape}")

# BUGFIX: compare against the SAME model instance that was exported to
# model.onnx. The original instantiated a fresh SimpleModel() here, whose
# random weights differ from the exported ones, so the reported "max diff"
# was meaningless.
model.eval()
torch_output = model(torch.from_numpy(input_data))
print(f"PyTorch 输出形状: {torch_output.shape}")

# With matching weights this should be ~1e-5 or smaller (float tolerance).
diff = np.abs(ort_outputs[0] - torch_output.detach().numpy()).max()
print(f"最大差异: {diff}")
模型量化
动态量化
动态量化在推理时动态地将权重从浮点数量化为整数:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
    """784 -> 256 -> 10 MLP; its Linear layers are the targets for
    dynamic quantization below."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
model = SimpleModel()
model.eval()

# Replace every nn.Linear with a version whose weights are stored as int8
# and quantized/dequantized on the fly at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

print("原始模型:")
print(model)
print("\n量化模型:")
print(quantized_model)

def count_parameters(model):
    """Total number of elements across all parameters."""
    return sum(p.numel() for p in model.parameters())

print(f"\n原始模型参数量: {count_parameters(model):,}")

# The quantized model keeps the same float-in / float-out interface.
x = torch.randn(1, 784)
output = quantized_model(x)
print(f"量化模型输出形状: {output.shape}")
静态量化
静态量化需要校准数据来确定量化参数:
import torch
import torch.nn as nn
class QuantizableModel(nn.Module):
    """MLP wrapped in Quant/DeQuant stubs so static quantization knows
    where the int8 region of the graph begins and ends."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 entry
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float exit

    def forward(self, x):
        out = self.quant(x)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return self.dequant(out)
model = QuantizableModel()
model.eval()

# Attach the server-side (x86 / fbgemm) quantization configuration.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibration: feed representative data so the inserted observers can
# record activation ranges used to pick the quantization parameters.
print("校准模型...")
for _ in range(100):
    model(torch.randn(1, 784))

# Swap float modules for their quantized counterparts.
torch.quantization.convert(model, inplace=True)
print("静态量化模型:")
print(model)

output = model(torch.randn(1, 784))
print(f"输出形状: {output.shape}")
量化感知训练
在训练过程中模拟量化效果:
import torch
import torch.nn as nn
class QATModel(nn.Module):
    """Quantization-aware-training MLP: identical topology to the static
    case, but prepared with fake-quant modules during training."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        h = self.quant(x)
        h = self.fc1(h)
        h = self.relu(h)
        h = self.fc2(h)
        return self.dequant(h)
model = QATModel()
model.train()  # QAT inserts fake-quant modules that need training mode

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

print("量化感知训练模型:")
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Short demo loop on random data: fake-quant ops simulate int8 rounding in
# the forward pass, so the weights learn to tolerate quantization error.
for epoch in range(5):
    inputs = torch.randn(32, 784)
    targets = torch.randint(0, 10, (32,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# After training, convert to a real int8 model for deployment.
model.eval()
quantized_model = torch.quantization.convert(model)
print("\n最终量化模型:")
print(quantized_model)
TensorRT 加速
使用 Torch-TensorRT
import torch
import torch.nn as nn
import torch_tensorrt
class SimpleModel(nn.Module):
    """CNN used as the compilation target for Torch-TensorRT."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        h = self.relu(self.conv1(x))
        h = self.relu(self.conv2(h))
        h = self.pool(h).view(x.size(0), -1)
        return self.fc(h)
model = SimpleModel().cuda().eval()

# Compile for a dynamic batch dimension (1..32, tuned for batch 8) and let
# TensorRT choose fp32 or fp16 kernels per layer.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[8, 3, 224, 224],
        max_shape=[32, 3, 224, 224],
        dtype=torch.float32
    )],
    enabled_precisions={torch.float32, torch.float16}
)
print("TensorRT 模型编译成功!")

batch = torch.randn(8, 3, 224, 224).cuda()
output = trt_model(batch)
print(f"输出形状: {output.shape}")
性能对比
import torch
import time
def benchmark(model, input_shape, num_iterations=100, device='cuda', warmup=10):
    """Measure average forward-pass latency in milliseconds.

    Args:
        model: nn.Module to time.
        input_shape: shape of the random input tensor fed to the model.
        num_iterations: timed iterations to average over.
        device: 'cuda' or 'cpu'.
        warmup: untimed iterations run first (new, default keeps old behavior).

    Returns:
        Average latency per forward pass, in milliseconds (float).
    """
    model = model.to(device).eval()
    x = torch.randn(input_shape).to(device)

    with torch.no_grad():
        # Warm-up: first CUDA calls include kernel compilation/caching cost.
        for _ in range(warmup):
            model(x)
        # CUDA launches are asynchronous; synchronize so timing is accurate.
        # (The original used `expr if cond else None` as a statement — an
        # antipattern; plain `if` is clearer.)
        if device == 'cuda':
            torch.cuda.synchronize()

        # perf_counter is monotonic and higher-resolution than time.time.
        # The timed loop now also runs under no_grad, so autograd bookkeeping
        # no longer inflates the measurement.
        start = time.perf_counter()
        for _ in range(num_iterations):
            model(x)
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.perf_counter()

    return (end - start) / num_iterations * 1000
# Illustrative only: running the actual comparison requires CUDA hardware
# (and TensorRT for the second row), so we only describe the setup here.
print("性能对比测试:")
print(" 原始 PyTorch 模型: 需要 GPU 环境")
print(" TensorRT 模型: 需要 GPU 和 TensorRT 环境")
服务端部署
Flask API 服务
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from PIL import Image
import io
import base64
app = Flask(__name__)
class SimpleModel(nn.Module):
    """CNN served by the Flask endpoint below; outputs 10 class logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        activ = self.relu(self.conv1(x))
        activ = self.relu(self.conv2(activ))
        activ = self.pool(activ)
        activ = activ.view(activ.size(0), -1)
        return self.fc(activ)
# Load trained weights once at import time and move to the best device.
# NOTE(review): 'model_weights.pth' must exist next to the service — confirm
# the deployment bundles it.
model = SimpleModel()
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()  # inference mode: disables dropout, freezes BN statistics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
def preprocess_image(image_data):
    """Decode raw image bytes into a normalized (1, 3, 224, 224) tensor.

    Applies the standard ImageNet mean/std normalization.

    Args:
        image_data: raw encoded image bytes (any format PIL can open).

    Returns:
        Float tensor of shape (1, 3, 224, 224), normalized.
    """
    image = Image.open(io.BytesIO(image_data)).convert('RGB')
    image = image.resize((224, 224))
    # Read the raw RGB bytes directly. The original built a Python list of
    # ~150k ints via list(image.getdata()), which is slow and memory-heavy;
    # torch.frombuffer avoids that round-trip (bytearray makes it writable).
    tensor = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
    tensor = tensor.float().view(224, 224, 3).permute(2, 0, 1) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    tensor = (tensor - mean) / std
    return tensor.unsqueeze(0)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify a base64-encoded image posted as JSON: {"image": "<b64>"}."""
    try:
        payload = request.json
        raw_bytes = base64.b64decode(payload['image'])
        batch = preprocess_image(raw_bytes).to(device)

        with torch.no_grad():
            logits = model(batch)
            probabilities = torch.softmax(logits, dim=1)
            predicted_class = logits.argmax(dim=1).item()

        return jsonify({
            'predicted_class': predicted_class,
            'probabilities': probabilities[0].tolist()
        })
    except Exception as e:
        # Bad payloads (missing key, invalid base64, undecodable image)
        # surface as a 400 with the error message.
        return jsonify({'error': str(e)}), 400
# Liveness probe for load balancers / container orchestrators.
@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'healthy'})
if __name__ == '__main__':
    print("启动模型服务...")
    print(" 使用设备:", device)
    # BUGFIX: the original printed the banner but never started the server.
    # Bind to all interfaces on port 5000 for local testing; use a real WSGI
    # server (gunicorn/uwsgi) in production instead of Flask's dev server.
    app.run(host='0.0.0.0', port=5000)
FastAPI 服务
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
from PIL import Image
import io
app = FastAPI(title="PyTorch Model API")
class SimpleModel(nn.Module):
    """CNN served by the FastAPI endpoint below; outputs 10 class logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        maps = self.relu(self.conv2(self.relu(self.conv1(x))))
        vec = self.pool(maps).view(maps.size(0), -1)
        return self.fc(vec)
# Populated once at startup; module-level so request handlers can reuse them.
model = None
device = None

# NOTE(review): @app.on_event("startup") is deprecated in recent FastAPI
# versions in favor of lifespan handlers — confirm the installed version.
@app.on_event("startup")
async def startup_event():
    """Load the model and weights exactly once when the server starts."""
    global model, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleModel().to(device)
    model.load_state_dict(torch.load('model_weights.pth', map_location=device))
    model.eval()  # inference mode for serving
    print(f"模型加载完成,使用设备: {device}")
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image file; returns class index + probabilities."""
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert('RGB')
    image = image.resize((224, 224))

    # Read the raw RGB bytes directly instead of list(image.getdata()),
    # which built a huge temporary Python list per request.
    image_tensor = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
    image_tensor = image_tensor.float().view(224, 224, 3).permute(2, 0, 1) / 255.0
    # Standard ImageNet normalization.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image_tensor = ((image_tensor - mean) / std).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = output.argmax(dim=1).item()

    return JSONResponse({
        "predicted_class": predicted_class,
        "probabilities": probabilities[0].tolist()
    })
# Liveness endpoint; also reports which device serves inference.
@app.get("/health")
async def health():
    return {"status": "healthy", "device": str(device)}
移动端部署
导出移动端模型
import torch
import torch.nn as nn
import torchvision.models as models

# Start from a pretrained, mobile-friendly backbone.
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.eval()

# Trace, then save in the lite-interpreter format used by PyTorch Mobile.
sample = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, sample)
traced_model._save_for_lite_interpreter('mobile_model.ptl')

print("移动端模型导出成功!")
print("文件: mobile_model.ptl")
优化移动端模型
import torch
# BUGFIX: this snippet used `models` without importing it.
import torchvision.models as models
from torch.utils.mobile_optimizer import optimize_for_mobile

model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.eval()

traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

# Fuses conv+bn+relu, freezes weights, and strips training-only ops.
optimized_model = optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter('optimized_mobile_model.ptl')
print("优化后的移动端模型已保存")
部署最佳实践
模型版本管理
import torch
import json
import os
from datetime import datetime
def save_model_with_version(model, save_dir, version, metadata=None,
                            example_input=None):
    """Persist a model's weights, a traced copy, and JSON metadata.

    Args:
        model: nn.Module to save (should be in eval mode before tracing).
        save_dir: directory that receives the versioned artifacts.
        version: version string used in every filename.
        metadata: optional dict merged into the metadata JSON.
        example_input: tensor used to trace the model. Defaults to a
            (1, 3, 224, 224) image batch for backward compatibility; pass
            your own for models with a different input shape (previously
            this shape was hard-coded).

    Returns:
        Tuple of (state-dict path, traced-model path).
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Raw weights (state_dict) — for retraining / fine-tuning.
    model_path = os.path.join(save_dir, f'model_v{version}.pth')
    torch.save(model.state_dict(), model_path)

    # 2. Traced TorchScript module — for Python-free deployment.
    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)
    traced_model = torch.jit.trace(model, example_input)
    traced_path = os.path.join(save_dir, f'model_v{version}.pt')
    traced_model.save(traced_path)

    # 3. Metadata sidecar: version, timestamp, paths, plus caller extras.
    meta = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'model_path': model_path,
        'traced_path': traced_path,
        **(metadata or {})
    }
    meta_path = os.path.join(save_dir, f'metadata_v{version}.json')
    with open(meta_path, 'w') as f:
        json.dump(meta, f, indent=2)

    print(f"模型版本 {version} 已保存")
    return model_path, traced_path
# Example: version a torchvision ResNet-18 together with its metrics.
# NOTE(review): relies on `torchvision.models as models` being imported in
# an earlier snippet of this chapter — confirm when running standalone.
model = models.resnet18()
save_model_with_version(
    model,
    'model_versions',
    version='1.0.0',
    metadata={'accuracy': 0.95, 'dataset': 'CIFAR10'}
)
推理性能监控
import torch
import time
import json
from collections import defaultdict
class InferenceMonitor:
    """Decorator object that records the wall-clock latency of each call.

    Usage:
        monitor = InferenceMonitor()

        @monitor
        def inference(...): ...
    """

    def __init__(self):
        # metric name -> list of observations (milliseconds)
        self.metrics = defaultdict(list)

    def __call__(self, func):
        # Local import keeps this tutorial snippet self-contained.
        from functools import wraps

        @wraps(func)  # FIX: preserve the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and higher-resolution than time.time,
            # which can jump backwards on clock adjustments.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            self.metrics['inference_time'].append(elapsed_ms)
            return result
        return wrapper

    def get_stats(self):
        """Summarize recorded latencies; returns {} when nothing recorded."""
        times = self.metrics['inference_time']
        if not times:
            return {}
        return {
            'count': len(times),
            'avg_time_ms': sum(times) / len(times),
            'min_time_ms': min(times),
            'max_time_ms': max(times),
            # p95 only meaningful with a reasonable sample size
            'p95_time_ms': sorted(times)[int(len(times) * 0.95)] if len(times) > 20 else max(times)
        }
monitor = InferenceMonitor()

@monitor
def inference(model, input_tensor):
    """One timed forward pass without gradient tracking."""
    with torch.no_grad():
        return model(input_tensor)

# Collect latency samples over 100 runs, then report the summary.
model = models.resnet18().eval()
for _ in range(100):
    inference(model, torch.randn(1, 3, 224, 224))

print("推理性能统计:")
print(json.dumps(monitor.get_stats(), indent=2))
常见问题
问题 1:模型导出失败
症状:TorchScript 导出报错
解决方案:
- 检查模型是否有不支持的 Python 控制流
- 使用 script 替代 trace
- 确保所有操作都是 PyTorch 原生操作
问题 2:推理速度慢
症状:生产环境推理延迟高
解决方案:
- 使用模型量化
- 使用 TensorRT 加速
- 启用半精度推理
- 优化数据预处理
问题 3:内存占用高
症状:服务内存占用过大
解决方案:
- 使用模型量化减少内存
- 及时释放中间变量
- 使用 torch.no_grad()
- 考虑模型剪枝
小结
本章我们学习了:
- TorchScript:trace 和 script 两种导出方式
- ONNX:跨框架模型格式导出和验证
- 模型量化:动态量化、静态量化、量化感知训练
- TensorRT:NVIDIA GPU 高性能推理
- 服务端部署:Flask 和 FastAPI 服务
- 移动端部署:PyTorch Lite 模型导出
- 最佳实践:版本管理、性能监控