跳到主要内容

PyTorch 导出 ONNX

将 PyTorch 模型导出为 ONNX 格式是模型部署的第一步,也是关键一步。PyTorch 提供了 torch.onnx.export() 函数来完成这个转换。理解这个函数的工作原理和参数配置,对于成功导出模型至关重要。

导出原理:Tracing vs Scripting

PyTorch 导出 ONNX 有两种主要方式,理解它们的区别是正确导出模型的前提。

Tracing(追踪模式)

Tracing 是默认的导出方式。它的工作原理是:让模型执行一次前向传播,记录所有张量操作,然后将这些操作转换为 ONNX 算子。

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)

def forward(self, x):
return self.fc(x)

model = SimpleModel().eval()
dummy_input = torch.randn(1, 10)

# Tracing 模式导出
torch.onnx.export(model, dummy_input, "model.onnx")

Tracing 的优点是实现简单,几乎所有模型都能用这种方式导出。但它的局限在于:只记录实际执行的路径。如果模型包含条件分支(如 if 语句),Tracing 只会记录那次执行时走的分支。

class ConditionalModel(nn.Module):
def forward(self, x, flag):
if flag: # 条件分支
return x * 2
else:
return x + 1

model = ConditionalModel()
dummy_x = torch.randn(1, 10)
dummy_flag = torch.tensor(True) # 只会记录 flag=True 的分支

torch.onnx.export(model, (dummy_x, dummy_flag), "model.onnx")
# 导出的模型永远执行 x * 2,无论推理时 flag 是什么值

Scripting(脚本模式)

Scripting 通过分析 Python 源代码来构建计算图,能够保留条件分支和循环结构。但它的限制更多:模型代码必须使用 TorchScript 支持的 Python 子集。

# Scripting 模式导出
scripted_model = torch.jit.script(model)
torch.onnx.export(scripted_model, (dummy_x, dummy_flag), "model.onnx")

如何选择?

模型特点推荐方式
纯张量计算,无控制流Tracing(默认)
包含数据依赖的条件分支Scripting 或重构模型
包含循环(如 RNN)Scripting
复杂动态行为考虑拆分模型或使用自定义导出逻辑

实际经验表明,大多数视觉模型(CNN、Transformer)都可以用 Tracing 成功导出。遇到条件分支时,一个常见的做法是将分支逻辑移到模型外部,让模型只做纯粹的张量计算。

torch.onnx.export() 参数详解

PyTorch 2.6 以后,torch.onnx.export() 引入了新的基于 torch.export 的导出器(通过 dynamo=True 参数启用)。这是目前推荐的方式。

基础参数

torch.onnx.export(
model, # 要导出的模型
args, # 模型输入(元组形式)
f, # 输出文件路径
*,
# 核心参数
dynamo=True, # 使用新的导出器(PyTorch 2.6+ 推荐)
opset_version=None, # ONNX 算子集版本
input_names=None, # 输入节点名称列表
output_names=None, # 输出节点名称列表
dynamic_axes=None, # 动态维度配置
dynamic_shapes=None, # dynamo 模式下的动态形状(推荐)
export_params=True, # 是否导出权重
)

完整示例:CNN 模型导出

import torch
import torch.nn as nn

class ConvNet(nn.Module):
"""一个典型的 CNN 分类模型"""
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 8 * 8, 128),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(128, num_classes),
)

def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x

# 创建模型并加载权重(这里用随机权重示例)
model = ConvNet(num_classes=10)
model.eval() # 重要:切换到推理模式

# 构造示例输入
# batch_size=1, channels=3, height=32, width=32
dummy_input = torch.randn(1, 3, 32, 32)

# 导出模型
torch.onnx.export(
model,
dummy_input,
"convnet.onnx",
dynamo=True,
opset_version=17,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"}, # 第 0 维(batch)是动态的
"output": {0: "batch_size"},
},
verbose=False,
)

print("模型导出成功!")

动态维度配置

动态维度是导出时最关键的配置之一。它决定了模型是否能够处理不同大小的输入。

为什么需要动态维度?

默认情况下,导出的 ONNX 模型会固定输入张量的形状。如果你用 batch_size=1 的示例输入导出模型,那么推理时只能接受 batch_size=1 的输入。

# 错误示例:没有配置动态维度
torch.onnx.export(model, torch.randn(1, 3, 32, 32), "model.onnx")

# 推理时会失败
# session.run(..., {"input": np.random.randn(4, 3, 32, 32)}) # batch_size=4 不匹配!

dynamic_axes 参数详解

dynamic_axes 参数告诉 ONNX:哪些维度是可以变化的。

# 方式一:字典形式(推荐,可自定义维度名称)
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"},
}

# 方式二:列表形式(自动生成名称)
dynamic_axes = {
"input": [0, 2, 3], # 第 0、2、3 维是动态的
"output": [0],
}

torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes=dynamic_axes,
)

dynamic_shapes 参数(PyTorch 2.6+)

dynamo=True 模式下,推荐使用 dynamic_shapes 参数。它更灵活,可以表达复杂的动态形状约束。

from torch.export import Dim

# 定义动态维度
batch_size = Dim("batch_size", min=1, max=64) # 范围限制
height = Dim("height", min=16)
width = Dim("width", min=16)

# 等价关系:height 和 width 总是相等
# width = Dim("width", min=16, max=256)

torch.onnx.export(
model,
dummy_input,
"model.onnx",
dynamo=True,
dynamic_shapes={
"input": {0: batch_size, 2: height, 3: width},
},
)

多输入模型的动态维度

class MultiInputModel(nn.Module):
def forward(self, image, text_embedding):
# image: [batch, channels, H, W]
# text_embedding: [batch, seq_len, dim]
return image.mean(dim=[1, 2, 3]) + text_embedding.mean(dim=1)

model = MultiInputModel().eval()

dummy_image = torch.randn(1, 3, 224, 224)
dummy_text = torch.randn(1, 10, 512)

torch.onnx.export(
model,
(dummy_image, dummy_text), # 多个输入用元组
"model.onnx",
input_names=["image", "text"],
output_names=["output"],
dynamic_axes={
"image": {0: "batch", 2: "height", 3: "width"},
"text": {0: "batch", 1: "seq_len"},
"output": {0: "batch"},
},
)

Opset 版本选择

Opset 版本决定了导出模型使用的算子集合。不同版本有不同的特点:

# 查看当前 PyTorch 支持的最高 Opset
import torch
print(f"默认 Opset 版本: {torch.onnx.producer_version}") # 实际使用 torch.onnx.export 时会自动选择

# 指定 Opset 版本
torch.onnx.export(model, dummy_input, "model.onnx", opset_version=17)

Opset 版本选择指南

Opset推荐场景注意事项
11最大兼容性某些新算子不支持
14-15通用选择平衡兼容性和功能
17Transformer 模型支持 LayerNormalization
18+最新特性确保目标推理引擎支持

实用建议:先用默认版本导出,如果推理引擎报错,再根据错误信息调整。

导出后验证

导出后立即验证是必不可少的步骤。验证的核心是确保 ONNX 模型与原始 PyTorch 模型的输出一致。

基本验证

import onnx
import onnxruntime as ort
import numpy as np

# 1. 检查 ONNX 模型结构完整性
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("✓ ONNX 模型结构检查通过")

# 2. 创建 ONNX Runtime 推理会话
session = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])

# 3. 获取输入输出信息
print("输入信息:")
for inp in session.get_inputs():
print(f" 名称: {inp.name}, 形状: {inp.shape}, 类型: {inp.type}")

print("输出信息:")
for out in session.get_outputs():
print(f" 名称: {out.name}, 形状: {out.shape}, 类型: {out.type}")

# 4. 对比 PyTorch 和 ONNX Runtime 的输出
test_input = np.random.randn(1, 3, 32, 32).astype(np.float32)

# PyTorch 推理
with torch.no_grad():
pytorch_output = model(torch.from_numpy(test_input)).numpy()

# ONNX Runtime 推理
input_name = session.get_inputs()[0].name
onnx_output = session.run(None, {input_name: test_input})[0]

# 比较输出
max_diff = np.abs(pytorch_output - onnx_output).max()
print(f"最大输出差异: {max_diff}")

if max_diff < 1e-5:
print("✓ 输出一致性验证通过")
else:
print("✗ 警告:输出存在显著差异,请检查导出配置")

验证动态维度

# 测试不同的 batch size
for batch_size in [1, 4, 8]:
test_input = np.random.randn(batch_size, 3, 32, 32).astype(np.float32)

# PyTorch
with torch.no_grad():
pt_out = model(torch.from_numpy(test_input)).numpy()

# ONNX
onnx_out = session.run(None, {input_name: test_input})[0]

diff = np.abs(pt_out - onnx_out).max()
print(f"batch_size={batch_size}: 最大差异 = {diff:.6f}")

常见问题与解决方案

问题一:模型不在 eval 模式

导出前必须调用 model.eval(),否则 BatchNorm、Dropout 等层的行为会导致输出不一致。

# 错误
model = MyModel() # 默认是 train 模式
torch.onnx.export(model, ...) # 导出的模型行为可能不正确

# 正确
model = MyModel().eval() # 切换到推理模式
torch.onnx.export(model, ...)

问题二:数据类型不匹配

PyTorch 默认使用 float32,但某些操作可能产生 float64。ONNX 对类型要求严格。

# 确保输入类型正确
dummy_input = torch.randn(1, 10, dtype=torch.float32) # 显式指定 float32

# 如果模型内部有类型转换,检查导出后的模型
model = onnx.load("model.onnx")
for node in model.graph.node:
print(f"{node.op_type}: {[i for i in node.input]}")

问题三:不支持的操作

某些 PyTorch 操作在 ONNX 中没有直接对应的算子。

# 错误示例
class UnsupportedModel(nn.Module):
def forward(self, x):
return torch.unique(x) # unique 算子支持有限

# 解决方案:在导出前用基础算子重新实现
class FixedModel(nn.Module):
def forward(self, x):
# 用其他方式实现相同功能
sorted_x, indices = torch.sort(x.flatten())
# ... 自定义实现
return sorted_x

问题四:大型模型导出

超过 2GB 的模型需要使用外部数据格式。

torch.onnx.export(
large_model,
dummy_input,
"large_model.onnx",
dynamo=True,
external_data=True, # 将权重存储在外部文件
)
# 会生成 large_model.onnx 和 large_model.onnx.data

进阶技巧

自定义算子导出

当模型使用了 ONNX 标准算子集之外的操作时,需要注册自定义符号函数。

# 注册自定义算子的导出规则
from torch.onnx import register_custom_op_symbolic

def my_custom_op_symbolic(g, input, dim):
# 定义如何将自定义操作转换为 ONNX 算子
return g.op("MyDomain::MyCustomOp", input, dim_i=dim)

register_custom_op_symbolic("my_namespace::custom_op", my_custom_op_symbolic, 1)

# 然后正常导出
torch.onnx.export(model_with_custom_op, dummy_input, "model.onnx")

导出模型元数据

添加模型描述信息有助于后续管理和追踪。

torch.onnx.export(
model,
dummy_input,
"model.onnx",
dynamo=True,
# 模型描述会嵌入到 ONNX 文件中
)

# 添加元数据
import onnx
onnx_model = onnx.load("model.onnx")
onnx_model.model_doc_string = "ResNet-50 for image classification"
onnx_model.producer_name = "MyTrainingPipeline"
onnx_model.producer_version = "1.0.0"
onnx.save(onnx_model, "model.onnx")

下一步

成功导出 ONNX 模型后,接下来的工作:

  1. 模型简化:使用 onnx-simplifier 减少冗余节点
  2. 性能优化:进行量化、算子融合等优化
  3. 部署推理:在 Python 或 C++ 中加载运行模型

下一章将介绍如何使用 ONNX Runtime 进行高效的 Python 推理。