跳到主要内容

模型部署

训练好的深度学习模型最终需要部署到生产环境中。本章将介绍 PyTorch 模型的各种部署方式,包括 TorchScript、ONNX、TensorRT 以及移动端部署。

部署概述

部署方式对比

部署方式优点缺点适用场景
PyTorch 原生简单直接依赖 Python研究、原型开发
TorchScript无需 Python调试困难服务端部署
ONNX跨框架兼容性问题跨平台部署
TensorRT高性能仅限 NVIDIA GPU高性能推理
移动端部署轻量级功能受限移动应用

部署流程

训练模型 ──→ 导出模型 ──→ 优化模型 ──→ 部署服务
(PyTorch)    (TorchScript / ONNX)    (量化 / 剪枝 / TensorRT)    (服务端 / 移动端)

TorchScript 部署

使用 torch.jit.trace

trace 通过记录实际执行的操作来创建模型:

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Small conv net used to demonstrate TorchScript tracing."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)


model = SimpleModel()
model.eval()  # inference mode before export

# trace records the ops actually executed on this example input
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)

print("Traced 模型结构:")
print(traced_model.graph)

# serialize the TorchScript archive, then reload it (no Python class needed)
traced_model.save('traced_model.pt')
loaded_model = torch.jit.load('traced_model.pt')

output = loaded_model(example_input)
print(f"输出形状: {output.shape}")

使用 torch.jit.script

script 通过解析 Python 代码来创建模型,支持控制流:

import torch
import torch.nn as nn


class ScriptableModel(nn.Module):
    """MLP whose forward contains data-dependent control flow.

    ``torch.jit.script`` compiles the Python source itself, so the
    ``if use_relu`` branch survives into the exported graph — something
    ``torch.jit.trace`` cannot capture.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, use_relu: bool = True) -> torch.Tensor:
        hidden = self.fc1(x)
        if use_relu:
            hidden = self.relu(hidden)
        return self.fc2(hidden)


model = ScriptableModel(784, 256, 10)
model.eval()

scripted_model = torch.jit.script(model)

print("Scripted 模型结构:")
print(scripted_model.code)

scripted_model.save('scripted_model.pt')
loaded_model = torch.jit.load('scripted_model.pt')

# both branches of the compiled control flow still work after reload
x = torch.randn(1, 784)
output1 = loaded_model(x, True)
output2 = loaded_model(x, False)
print(f"使用 ReLU 输出: {output1.shape}")
print(f"不使用 ReLU 输出: {output2.shape}")

trace 与 script 的选择

特性tracescript
控制流不支持支持
动态形状有限支持支持
使用难度简单中等
调试困难较容易
适用场景固定计算图动态计算图

混合使用 trace 和 script

import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Convolutional backbone with a fixed computation graph (trace-friendly)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        feat = self.relu(self.conv1(x))
        feat = self.relu(self.conv2(feat))
        feat = self.pool(feat)
        return feat.view(feat.size(0), -1)


class Classifier(nn.Module):
    """Linear head with a runtime branch (script-friendly)."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x, return_probs: bool = True):
        logits = self.fc(x)
        if return_probs:
            return torch.softmax(logits, dim=-1)
        return logits


class HybridModel(nn.Module):
    """Traced backbone + scripted head composed into one scriptable module."""

    def __init__(self):
        super().__init__()
        # trace the static part, script the branching part
        self.features = torch.jit.trace(FeatureExtractor(), torch.randn(1, 3, 224, 224))
        self.classifier = torch.jit.script(Classifier(128, 10))

    def forward(self, x, return_probs: bool = True):
        return self.classifier(self.features(x), return_probs)


model = HybridModel()
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save('hybrid_model.pt')

loaded = torch.jit.load('hybrid_model.pt')
output = loaded(torch.randn(1, 3, 224, 224), True)
print(f"输出形状: {output.shape}")

ONNX 导出

基本导出

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(128, 10)

def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

model = SimpleModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
model,
dummy_input,
'model.onnx',
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)

print("ONNX 模型导出成功!")

验证 ONNX 模型

import onnx

# structural validation: checks graph topology, opset and tensor types
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")

print("\n模型输入:")
for graph_input in onnx_model.graph.input:
    print(f" 名称: {graph_input.name}")
    print(f" 类型: {graph_input.type}")

print("\n模型输出:")
for graph_output in onnx_model.graph.output:
    print(f" 名称: {graph_output.name}")

ONNX Runtime 推理

import onnxruntime as ort
import numpy as np
import torch

ort_session = ort.InferenceSession('model.onnx')

print("ONNX Runtime 输入信息:")
for inp in ort_session.get_inputs():  # renamed: don't shadow builtin `input`
    print(f" 名称: {inp.name}, 形状: {inp.shape}, 类型: {inp.type}")

print("\nONNX Runtime 输出信息:")
for out in ort_session.get_outputs():
    print(f" 名称: {out.name}, 形状: {out.shape}, 类型: {out.type}")

# ONNX Runtime consumes plain numpy arrays, not torch tensors
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)

print(f"\nONNX Runtime 输出形状: {ort_outputs[0].shape}")

# BUG FIX: the original re-created `SimpleModel()` here, which draws fresh
# random weights — comparing that against the exported graph is meaningless.
# The PyTorch reference must be the very `model` instance that was exported
# to ONNX in the previous section.
with torch.no_grad():
    torch_output = model(torch.from_numpy(input_data))

print(f"PyTorch 输出形状: {torch_output.shape}")

# with matching weights this difference should be ~1e-5 (float32 noise)
diff = np.abs(ort_outputs[0] - torch_output.numpy()).max()
print(f"最大差异: {diff}")

模型量化

动态量化

动态量化在推理时动态地将权重从浮点数量化为整数:

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Plain MLP used to demonstrate dynamic quantization."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


model = SimpleModel()
model.eval()

# Dynamic quantization: weights are stored as int8, activations are
# quantized on the fly at inference time. Only the listed module types
# (here nn.Linear) are converted.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},
    dtype=torch.qint8
)

print("原始模型:")
print(model)

print("\n量化模型:")
print(quantized_model)


def count_parameters(model):
    """Total number of scalar elements across all parameters."""
    return sum(p.numel() for p in model.parameters())


print(f"\n原始模型参数量: {count_parameters(model):,}")

x = torch.randn(1, 784)
output = quantized_model(x)
print(f"量化模型输出形状: {output.shape}")

静态量化

静态量化需要校准数据来确定量化参数:

import torch
import torch.nn as nn


class QuantizableModel(nn.Module):
    """MLP wrapped in Quant/DeQuant stubs so it can be statically quantized."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 entry
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float exit

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.dequant(x)


model = QuantizableModel()
model.eval()

# fbgemm: the x86 server-side quantization backend
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# insert observers that record activation value ranges
torch.quantization.prepare(model, inplace=True)

# calibration: run representative data through the model so the observers
# can choose scale / zero-point for each activation
print("校准模型...")
for _ in range(100):
    model(torch.randn(1, 784))

# swap float modules for their quantized counterparts
torch.quantization.convert(model, inplace=True)

print("静态量化模型:")
print(model)

output = model(torch.randn(1, 784))
print(f"输出形状: {output.shape}")

量化感知训练

在训练过程中模拟量化效果:

import torch
import torch.nn as nn


class QATModel(nn.Module):
    """MLP with Quant/DeQuant stubs for quantization-aware training."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.dequant(x)


model = QATModel()
model.train()  # QAT happens in training mode

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# insert fake-quant modules that simulate int8 rounding during training
torch.quantization.prepare_qat(model, inplace=True)

print("量化感知训练模型:")
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# toy training loop on random data — just exercises the fake-quant ops
for epoch in range(5):
    x = torch.randn(32, 784)
    y = torch.randint(0, 10, (32,))

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# after training, convert to a真正的 int8 model for deployment
model.eval()
quantized_model = torch.quantization.convert(model)

print("\n最终量化模型:")
print(quantized_model)

TensorRT 加速

使用 Torch-TensorRT

import torch
import torch.nn as nn

import torch_tensorrt


class SimpleModel(nn.Module):
    """Conv net compiled to a TensorRT engine below."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)


model = SimpleModel().cuda().eval()

# min/opt/max shapes let TensorRT build one engine that accepts a range of
# batch sizes; kernels are tuned for opt_shape.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(
        min_shape=[1, 3, 224, 224],
        opt_shape=[8, 3, 224, 224],
        max_shape=[32, 3, 224, 224],
        dtype=torch.float32
    )],
    # allowing fp16 lets TensorRT pick faster half-precision kernels
    enabled_precisions={torch.float32, torch.float16}
)

print("TensorRT 模型编译成功!")

x = torch.randn(8, 3, 224, 224).cuda()
output = trt_model(x)
print(f"输出形状: {output.shape}")

性能对比

import torch
import time


def benchmark(model, input_shape, num_iterations=100, device='cuda'):
    """Measure the average forward-pass latency of ``model``.

    Args:
        model: nn.Module to benchmark (moved to ``device`` and set to eval).
        input_shape: shape of the random input tensor, e.g. (1, 3, 224, 224).
        num_iterations: number of timed forward passes.
        device: 'cuda' or 'cpu'.

    Returns:
        Average latency per forward pass, in milliseconds.
    """
    model = model.to(device).eval()
    x = torch.randn(input_shape).to(device)

    # BUG FIX: the timed loop now also runs under no_grad(); the original
    # only wrapped the warm-up, so the benchmark measured autograd-graph
    # construction as well as inference.
    with torch.no_grad():
        # warm-up: lets cuDNN pick algorithms and caches fill up
        for _ in range(10):
            model(x)

        # CUDA kernels launch asynchronously; synchronize so the timer
        # brackets the actual GPU work, not just the kernel launches.
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, higher resolution than time.time()
        for _ in range(num_iterations):
            model(x)
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.perf_counter()

    return (end - start) / num_iterations * 1000


print("性能对比测试:")
print(" 原始 PyTorch 模型: 需要 GPU 环境")
print(" TensorRT 模型: 需要 GPU 和 TensorRT 环境")

服务端部署

Flask API 服务

from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from PIL import Image
import io
import base64

app = Flask(__name__)


class SimpleModel(nn.Module):
    """Conv net served by this API (weights loaded from model_weights.pth)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


# load the model once at import time, not per request
model = SimpleModel()
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


def preprocess_image(image_data):
    """Decode raw image bytes into a normalized (1, 3, 224, 224) tensor.

    Applies the standard ImageNet mean/std normalization.
    (NOTE(review): torchvision.transforms would be faster than
    list(image.getdata()) — kept as-is to avoid a new dependency.)
    """
    image = Image.open(io.BytesIO(image_data)).convert('RGB')
    image = image.resize((224, 224))
    image = torch.tensor(list(image.getdata())).float()
    image = image.view(224, 224, 3).permute(2, 0, 1) / 255.0
    image = (image - torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) / \
        torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return image.unsqueeze(0)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify a base64-encoded image posted as JSON {'image': ...}."""
    try:
        data = request.json
        image_data = base64.b64decode(data['image'])

        image = preprocess_image(image_data).to(device)

        with torch.no_grad():
            output = model(image)
            probabilities = torch.softmax(output, dim=1)
            predicted_class = output.argmax(dim=1).item()

        return jsonify({
            'predicted_class': predicted_class,
            'probabilities': probabilities[0].tolist()
        })

    except Exception as e:
        # API boundary: surface the failure to the client as a 400
        return jsonify({'error': str(e)}), 400


@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'healthy'})


if __name__ == '__main__':
    print("启动模型服务...")
    print(" 使用设备:", device)
    # BUG FIX: the original printed the banner and exited without ever
    # starting the server. For production, prefer a WSGI server
    # (gunicorn/uwsgi) over the Flask development server.
    app.run(host='0.0.0.0', port=5000)

FastAPI 服务

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import torch
import torch.nn as nn
from PIL import Image
import io

app = FastAPI(title="PyTorch Model API")


class SimpleModel(nn.Module):
    """Conv net served by this API."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        feat = self.relu(self.conv1(x))
        feat = self.relu(self.conv2(feat))
        feat = self.pool(feat)
        feat = feat.view(feat.size(0), -1)
        return self.fc(feat)


# populated once at startup; module-level so the request handlers see them
model = None
device = None


@app.on_event("startup")
async def startup_event():
    """Load the model once per process instead of once per request."""
    global model, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleModel().to(device)
    model.load_state_dict(torch.load('model_weights.pth', map_location=device))
    model.eval()
    print(f"模型加载完成,使用设备: {device}")


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image; return class index plus probabilities."""
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert('RGB')
    image = image.resize((224, 224))

    # HWC uint8 -> CHW float in [0, 1], then ImageNet mean/std normalization
    tensor = torch.tensor(list(image.getdata())).float()
    tensor = tensor.view(224, 224, 3).permute(2, 0, 1) / 255.0
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    tensor = ((tensor - mean) / std).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = output.argmax(dim=1).item()

    return JSONResponse({
        "predicted_class": predicted_class,
        "probabilities": probabilities[0].tolist()
    })


@app.get("/health")
async def health():
    return {"status": "healthy", "device": str(device)}

移动端部署

导出移动端模型

import torch
import torch.nn as nn
import torchvision.models as models

# small backbone suited to on-device inference
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.eval()

# lite-interpreter archives (.ptl) are what the PyTorch Mobile runtime loads
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model._save_for_lite_interpreter('mobile_model.ptl')

print("移动端模型导出成功!")
print("文件: mobile_model.ptl")

优化移动端模型

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# NOTE(review): `models` is assumed to come from the previous snippet's
# `import torchvision.models as models` — confirm when assembling a script.
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.eval()

traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

# fuses conv+bn+relu, freezes weights and strips training-only ops
optimized_model = optimize_for_mobile(traced_model)
optimized_model._save_for_lite_interpreter('optimized_mobile_model.ptl')

print("优化后的移动端模型已保存")

部署最佳实践

模型版本管理

import torch
import json
import os
from datetime import datetime


def save_model_with_version(model, save_dir, version, metadata=None,
                            example_input=None):
    """Save a model's weights, a traced TorchScript copy, and JSON metadata.

    Args:
        model: nn.Module to save (should be in eval mode for tracing).
        save_dir: directory that will hold all versioned artifacts.
        version: version string used in the file names, e.g. '1.0.0'.
        metadata: optional extra key/value pairs merged into the JSON file.
        example_input: tensor used for tracing. Defaults to a
            (1, 3, 224, 224) image batch for backward compatibility with
            the original hard-coded shape; pass your own for other models.

    Returns:
        Tuple of (state-dict path, traced-model path).
    """
    os.makedirs(save_dir, exist_ok=True)

    # raw weights: for retraining / fine-tuning
    model_path = os.path.join(save_dir, f'model_v{version}.pth')
    torch.save(model.state_dict(), model_path)

    # self-contained TorchScript archive: for deployment
    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)
    traced_model = torch.jit.trace(model, example_input)
    traced_path = os.path.join(save_dir, f'model_v{version}.pt')
    traced_model.save(traced_path)

    meta = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'model_path': model_path,
        'traced_path': traced_path,
        **(metadata or {})
    }

    meta_path = os.path.join(save_dir, f'metadata_v{version}.json')
    with open(meta_path, 'w') as f:
        json.dump(meta, f, indent=2)

    print(f"模型版本 {version} 已保存")
    return model_path, traced_path

# Version the backbone together with its evaluation metadata.
model = models.resnet18()
save_model_with_version(
    model,
    'model_versions',
    version='1.0.0',
    metadata={'accuracy': 0.95, 'dataset': 'CIFAR10'},
)

推理性能监控

import torch
import time
import json
import functools
from collections import defaultdict


class InferenceMonitor:
    """Decorator object that records per-call inference latency.

    Apply an instance as ``@monitor`` to any function; call
    :meth:`get_stats` to retrieve aggregate latency statistics.
    """

    def __init__(self):
        # metric name -> list of samples (milliseconds)
        self.metrics = defaultdict(list)

    def __call__(self, func):
        # BUG FIX: functools.wraps preserves the wrapped function's
        # __name__/__doc__, which the original decorator clobbered.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()  # monotonic, unlike time.time()
            result = func(*args, **kwargs)
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            self.metrics['inference_time'].append(elapsed_ms)
            return result
        return wrapper

    def get_stats(self):
        """Return count/avg/min/max/p95 of recorded latencies in ms."""
        times = self.metrics['inference_time']
        if not times:
            return {}

        return {
            'count': len(times),
            'avg_time_ms': sum(times) / len(times),
            'min_time_ms': min(times),
            'max_time_ms': max(times),
            # crude p95; only meaningful once there are enough samples
            'p95_time_ms': sorted(times)[int(len(times) * 0.95)] if len(times) > 20 else max(times)
        }


monitor = InferenceMonitor()


@monitor
def inference(model, input_tensor):
    with torch.no_grad():
        return model(input_tensor)

# Exercise the decorated inference function, then dump the statistics.
model = models.resnet18().eval()
for _ in range(100):
    inference(model, torch.randn(1, 3, 224, 224))

print("推理性能统计:")
print(json.dumps(monitor.get_stats(), indent=2))

常见问题

问题 1:模型导出失败

症状:TorchScript 导出报错

解决方案

  • 检查模型是否有不支持的 Python 控制流
  • 使用 script 替代 trace
  • 确保所有操作都是 PyTorch 原生操作

问题 2:推理速度慢

症状:生产环境推理延迟高

解决方案

  • 使用模型量化
  • 使用 TensorRT 加速
  • 启用半精度推理
  • 优化数据预处理

问题 3:内存占用高

症状:服务内存占用过大

解决方案

  • 使用模型量化减少内存
  • 及时释放中间变量
  • 使用 torch.no_grad()
  • 考虑模型剪枝

小结

本章我们学习了:

  1. TorchScript:trace 和 script 两种导出方式
  2. ONNX:跨框架模型格式导出和验证
  3. 模型量化:动态量化、静态量化、量化感知训练
  4. TensorRT:NVIDIA GPU 高性能推理
  5. 服务端部署:Flask 和 FastAPI 服务
  6. 移动端部署:PyTorch Lite 模型导出
  7. 最佳实践:版本管理、性能监控

参考资源