PyTorch 导出 ONNX
将 PyTorch 模型导出为 ONNX 格式是模型部署的第一步,也是关键一步。PyTorch 提供了 torch.onnx.export() 函数来完成这个转换。理解这个函数的工作原理和参数配置,对于成功导出模型至关重要。
导出原理:Tracing vs Scripting
PyTorch 导出 ONNX 有两种主要方式,理解它们的区别是正确导出模型的前提。
Tracing(追踪模式)
Tracing 是默认的导出方式。它的工作原理是:让模型执行一次前向传播,记录所有张量操作,然后将这些操作转换为 ONNX 算子。
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = SimpleModel().eval()
dummy_input = torch.randn(1, 10)
# Tracing 模式导出
torch.onnx.export(model, dummy_input, "model.onnx")
Tracing 的优点是实现简单,几乎所有模型都能用这种方式导出。但它的局限在于:只记录实际执行的路径。如果模型包含条件分支(如 if 语句),Tracing 只会记录那次执行时走的分支。
class ConditionalModel(nn.Module):
def forward(self, x, flag):
if flag: # 条件分支
return x * 2
else:
return x + 1
model = ConditionalModel()
dummy_x = torch.randn(1, 10)
dummy_flag = torch.tensor(True) # 只会记录 flag=True 的分支
torch.onnx.export(model, (dummy_x, dummy_flag), "model.onnx")
# 导出的模型永远执行 x * 2,无论推理时 flag 是什么值
Scripting(脚本模式)
Scripting 通过分析 Python 源代码来构建计算图,能够保留条件分支和循环结构。但它的限制更多:模型代码必须使用 TorchScript 支持的 Python 子集。
# Scripting 模式导出
scripted_model = torch.jit.script(model)
torch.onnx.export(scripted_model, (dummy_x, dummy_flag), "model.onnx")
如何选择?
| 模型特点 | 推荐方式 |
|---|---|
| 纯张量计算,无控制流 | Tracing(默认) |
| 包含数据依赖的条件分支 | Scripting 或重构模型 |
| 包含循环(如 RNN) | Scripting |
| 复杂动态行为 | 考虑拆分模型或使用自定义导出逻辑 |
实际经验表明,大多数视觉模型(CNN、Transformer)都可以用 Tracing 成功导出。遇到条件分支时,一个常见的做法是将分支逻辑移到模型外部,让模型只做纯粹的张量计算。
torch.onnx.export() 参数详解
PyTorch 2.6 以后,torch.onnx.export() 引入了新的基于 torch.export 的导出器(通过 dynamo=True 参数启用)。这是目前推荐的方式。
基础参数
torch.onnx.export(
model, # 要导出的模型
args, # 模型输入(元组形式)
f, # 输出文件路径
*,
# 核心参数
dynamo=True, # 使用新的导出器(PyTorch 2.6+ 推荐)
opset_version=None, # ONNX 算子集版本
input_names=None, # 输入节点名称列表
output_names=None, # 输出节点名称列表
dynamic_axes=None, # 动态维度配置
dynamic_shapes=None, # dynamo 模式下的动态形状(推荐)
export_params=True, # 是否导出权重
)
完整示例:CNN 模型导出
import torch
import torch.nn as nn
class ConvNet(nn.Module):
"""一个典型的 CNN 分类模型"""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 8 * 8, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(128, num_classes),
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
# 创建模型并加载权重(这里用随机权重示例)
model = ConvNet(num_classes=10)
model.eval() # 重要:切换到推理模式
# 构造示例输入
# batch_size=1, channels=3, height=32, width=32
dummy_input = torch.randn(1, 3, 32, 32)
# 导出模型
torch.onnx.export(
model,
dummy_input,
"convnet.onnx",
dynamo=True,
opset_version=17,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"}, # 第 0 维(batch)是动态的
"output": {0: "batch_size"},
},
verbose=False,
)
print("模型导出成功!")
动态维度配置
动态维度是导出时最关键的配置之一。它决定了模型是否能够处理不同大小的输入。
为什么需要动态维度?
默认情况下,导出的 ONNX 模型会固定输入张量的形状。如果你用 batch_size=1 的示例输入导出模型,那么推理时只能接受 batch_size=1 的输入。
# 错误示例:没有配置动态维度
torch.onnx.export(model, torch.randn(1, 3, 32, 32), "model.onnx")
# 推理时会失败
# session.run(..., {"input": np.random.randn(4, 3, 32, 32)}) # batch_size=4 不匹配!
dynamic_axes 参数详解
dynamic_axes 参数告诉 ONNX:哪些维度是可以变化的。
# 方式一:字典形式(推荐,可自定义维度名称)
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"},
}
# 方式二:列表形式(自动生成名称)
dynamic_axes = {
"input": [0, 2, 3], # 第 0、2、3 维是动态的
"output": [0],
}
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes=dynamic_axes,
)
dynamic_shapes 参数(PyTorch 2.6+)
在 dynamo=True 模式下,推荐使用 dynamic_shapes 参数。它更灵活,可以表达复杂的动态形状约束。
from torch.export import Dim
# 定义动态维度
batch_size = Dim("batch_size", min=1, max=64) # 范围限制
height = Dim("height", min=16)
width = Dim("width", min=16)
# 等价关系:height 和 width 总是相等
# width = Dim("width", min=16, max=256)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
dynamo=True,
dynamic_shapes={
"input": {0: batch_size, 2: height, 3: width},
},
)
多输入模型的动态维度
class MultiInputModel(nn.Module):
def forward(self, image, text_embedding):
# image: [batch, channels, H, W]
# text_embedding: [batch, seq_len, dim]
return image.mean(dim=[1, 2, 3]) + text_embedding.mean(dim=1)
model = MultiInputModel().eval()
dummy_image = torch.randn(1, 3, 224, 224)
dummy_text = torch.randn(1, 10, 512)
torch.onnx.export(
model,
(dummy_image, dummy_text), # 多个输入用元组
"model.onnx",
input_names=["image", "text"],
output_names=["output"],
dynamic_axes={
"image": {0: "batch", 2: "height", 3: "width"},
"text": {0: "batch", 1: "seq_len"},
"output": {0: "batch"},
},
)
Opset 版本选择
Opset 版本决定了导出模型使用的算子集合。不同版本有不同的特点:
# 查看当前 PyTorch 支持的最高 Opset
import torch
print(f"默认 Opset 版本: {torch.onnx.producer_version}") # 实际使用 torch.onnx.export 时会自动选择
# 指定 Opset 版本
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=17)
Opset 版本选择指南
| Opset | 推荐场景 | 注意事项 |
|---|---|---|
| 11 | 最大兼容性 | 某些新算子不支持 |
| 14-15 | 通用选择 | 平衡兼容性和功能 |
| 17 | Transformer 模型 | 支持 LayerNormalization |
| 18+ | 最新特性 | 确保目标推理引擎支持 |
实用建议:先用默认版本导出,如果推理引擎报错,再根据错误信息调整。
导出后验证
导出后立即验证是必不可少的步骤。验证的核心是确保 ONNX 模型与原始 PyTorch 模型的输出一致。
基本验证
import onnx
import onnxruntime as ort
import numpy as np
# 1. 检查 ONNX 模型结构完整性
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("✓ ONNX 模型结构检查通过")
# 2. 创建 ONNX Runtime 推理会话
session = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
# 3. 获取输入输出信息
print("输入信息:")
for inp in session.get_inputs():
print(f" 名称: {inp.name}, 形状: {inp.shape}, 类型: {inp.type}")
print("输出信息:")
for out in session.get_outputs():
print(f" 名称: {out.name}, 形状: {out.shape}, 类型: {out.type}")
# 4. 对比 PyTorch 和 ONNX Runtime 的输出
test_input = np.random.randn(1, 3, 32, 32).astype(np.float32)
# PyTorch 推理
with torch.no_grad():
pytorch_output = model(torch.from_numpy(test_input)).numpy()
# ONNX Runtime 推理
input_name = session.get_inputs()[0].name
onnx_output = session.run(None, {input_name: test_input})[0]
# 比较输出
max_diff = np.abs(pytorch_output - onnx_output).max()
print(f"最大输出差异: {max_diff}")
if max_diff < 1e-5:
print("✓ 输出一致性验证通过")
else:
print("✗ 警告:输出存在显著差异,请检查导出配置")
验证动态维度
# 测试不同的 batch size
for batch_size in [1, 4, 8]:
test_input = np.random.randn(batch_size, 3, 32, 32).astype(np.float32)
# PyTorch
with torch.no_grad():
pt_out = model(torch.from_numpy(test_input)).numpy()
# ONNX
onnx_out = session.run(None, {input_name: test_input})[0]
diff = np.abs(pt_out - onnx_out).max()
print(f"batch_size={batch_size}: 最大差异 = {diff:.6f}")
常见问题与解决方案
问题一:模型不在 eval 模式
导出前必须调用 model.eval(),否则 BatchNorm、Dropout 等层的行为会导致输出不一致。
# 错误
model = MyModel() # 默认是 train 模式
torch.onnx.export(model, ...) # 导出的模型行为可能不正确
# 正确
model = MyModel().eval() # 切换到推理模式
torch.onnx.export(model, ...)
问题二:数据类型不匹配
PyTorch 默认使用 float32,但某些操作可能产生 float64。ONNX 对类型要求严格。
# 确保输入类型正确
dummy_input = torch.randn(1, 10, dtype=torch.float32) # 显式指定 float32
# 如果模型内部有类型转换,检查导出后的模型
model = onnx.load("model.onnx")
for node in model.graph.node:
print(f"{node.op_type}: {[i for i in node.input]}")
问题三:不支持的操作
某些 PyTorch 操作在 ONNX 中没有直接对应的算子。
# 错误示例
class UnsupportedModel(nn.Module):
def forward(self, x):
return torch.unique(x) # unique 算子支持有限
# 解决方案:在导出前用基础算子重新实现
class FixedModel(nn.Module):
def forward(self, x):
# 用其他方式实现相同功能
sorted_x, indices = torch.sort(x.flatten())
# ... 自定义实现
return sorted_x
问题四:大型模型导出
超过 2GB 的模型需要使用外部数据格式。
torch.onnx.export(
large_model,
dummy_input,
"large_model.onnx",
dynamo=True,
external_data=True, # 将权重存储在外部文件
)
# 会生成 large_model.onnx 和 large_model.onnx.data
进阶技巧
自定义算子导出
当模型使用了 ONNX 标准算子集之外的操作时,需要注册自定义符号函数。
# 注册自定义算子的导出规则
from torch.onnx import register_custom_op_symbolic
def my_custom_op_symbolic(g, input, dim):
# 定义如何将自定义操作转换为 ONNX 算子
return g.op("MyDomain::MyCustomOp", input, dim_i=dim)
register_custom_op_symbolic("my_namespace::custom_op", my_custom_op_symbolic, 1)
# 然后正常导出
torch.onnx.export(model_with_custom_op, dummy_input, "model.onnx")
导出模型元数据
添加模型描述信息有助于后续管理和追踪。
torch.onnx.export(
model,
dummy_input,
"model.onnx",
dynamo=True,
# 模型描述会嵌入到 ONNX 文件中
)
# 添加元数据
import onnx
onnx_model = onnx.load("model.onnx")
onnx_model.model_doc_string = "ResNet-50 for image classification"
onnx_model.producer_name = "MyTrainingPipeline"
onnx_model.producer_version = "1.0.0"
onnx.save(onnx_model, "model.onnx")
下一步
成功导出 ONNX 模型后,接下来的工作:
- 模型简化:使用 onnx-simplifier 减少冗余节点
- 性能优化:进行量化、算子融合等优化
- 部署推理:在 Python 或 C++ 中加载运行模型
下一章将介绍如何使用 ONNX Runtime 进行高效的 Python 推理。