跳到主要内容

ONNX 知识速查表

本章提供 ONNX 开发中常用的命令、API 和配置参数的速查参考。

环境安装

# 核心库
pip install onnx onnxruntime

# GPU 推理
pip install onnxruntime-gpu

# 模型简化
pip install onnx-simplifier

# TensorFlow 转换
pip install tf2onnx

# 图操作
pip install onnx-graphsurgeon

# FP16 转换
pip install onnxconverter-common

PyTorch 导出 ONNX

基础导出

import torch

torch.onnx.export(
model, # PyTorch 模型
dummy_input, # 示例输入
"model.onnx", # 输出路径
dynamo=True, # PyTorch 2.6+ 推荐
opset_version=17, # ONNX Opset 版本
input_names=["input"], # 输入名称
output_names=["output"], # 输出名称
)

动态维度配置

# 方式一:dynamic_axes(通用)
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"},
}

# 方式二:dynamic_shapes(PyTorch 2.6+)
from torch.export import Dim
batch_size = Dim("batch_size", min=1, max=64)
dynamic_shapes = {"input": {0: batch_size}}

多输入模型

torch.onnx.export(
model,
(input1, input2), # 元组形式的多输入
"model.onnx",
input_names=["image", "text"],
output_names=["output"],
)

ONNX 模型验证

import onnx
import onnxruntime as ort
import numpy as np

# 检查模型结构
model = onnx.load("model.onnx")
onnx.checker.check_model(model)

# 验证输出一致性
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = session.run([output_name], {input_name: input_data})[0]

ONNX Runtime Python API

创建推理会话

import onnxruntime as ort

# CPU 推理
session = ort.InferenceSession("model.onnx")

# GPU 推理
session = ort.InferenceSession(
"model.onnx",
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

# 查看可用提供者
print(ort.get_available_providers())

获取模型信息

# 输入信息
for inp in session.get_inputs():
print(f"{inp.name}: shape={inp.shape}, type={inp.type}")

# 输出信息
for out in session.get_outputs():
print(f"{out.name}: shape={out.shape}, type={out.type}")

执行推理

# 基础推理
outputs = session.run(None, {"input": input_data})

# 指定输出
outputs = session.run(["output"], {"input": input_data})

# 获取第一个输出
result = outputs[0]

SessionOptions 配置

options = ort.SessionOptions()

# 线程数
options.intra_op_num_threads = 4
options.inter_op_num_threads = 1

# 图优化级别
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# 保存优化后的模型
options.optimized_model_filepath = "optimized.onnx"

session = ort.InferenceSession("model.onnx", sess_options=options)

GPU IOBinding

import torch

# 创建 IOBinding
io_binding = session.io_binding()

# 绑定 GPU 输入
input_tensor = torch.randn(1, 3, 224, 224, device="cuda")
io_binding.bind_input(
name="input",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=tuple(input_tensor.shape),
buffer_ptr=input_tensor.data_ptr()
)

# 绑定 GPU 输出
output_tensor = torch.empty(1, 1000, device="cuda")
io_binding.bind_output(
name="output",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=tuple(output_tensor.shape),
buffer_ptr=output_tensor.data_ptr()
)

# 执行推理
session.run_with_iobinding(io_binding)

模型简化

命令行

# 基础简化
python -m onnxsim input.onnx output.onnx

# 指定输入形状
python -m onnxsim input.onnx output.onnx --input-shape input:1,3,224,224

# 跳过形状推断
python -m onnxsim input.onnx output.onnx --skip-shape-inference

Python API

from onnxsim import simplify
import onnx

model = onnx.load("model.onnx")

# 基础简化
model_simp, check = simplify(model)

# 指定输入形状
model_simp, check = simplify(
model,
input_shapes={"input": [1, 3, 224, 224]}
)

onnx.save(model_simp, "simplified.onnx")

量化

动态量化

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
model_input="model.onnx",
model_output="model_int8.onnx",
weight_type=QuantType.QInt8,
)

静态量化

from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, CalibrationDataReader

# 实现校准数据读取器
class MyDataReader(CalibrationDataReader):
def __init__(self):
self.data = [np.random.randn(1, 3, 224, 224).astype(np.float32) for _ in range(100)]
self.index = 0

def get_next(self):
if self.index >= len(self.data):
return None
data = {"input": self.data[self.index]}
self.index += 1
return data

# 预处理
from onnxruntime.quantization.shape_inference import quant_pre_process
quant_pre_process("model.onnx", "model_preprocessed.onnx")

# 量化
quantize_static(
model_input="model_preprocessed.onnx",
model_output="model_int8.onnx",
calibration_data_reader=MyDataReader(),
quant_format=QuantFormat.QDQ,
weight_type=QuantType.QInt8,
)

FP16 转换

import onnx
from onnxconverter_common import float16

model = onnx.load("model.onnx")
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "model_fp16.onnx")

TensorFlow 转换 ONNX

# SavedModel 格式
python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx

# Checkpoint 格式
python -m tf2onnx.convert --checkpoint model.ckpt --output model.onnx --inputs input:0 --outputs output:0

# GraphDef 格式
python -m tf2onnx.convert --graphdef model.pb --output model.onnx --inputs input:0 --outputs output:0

# TFLite 格式
python -m tf2onnx.convert --tflite model.tflite --output model.onnx

# 指定 Opset
python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx --opset 17

Python API

import tf2onnx

# Keras 模型
model_proto, _ = tf2onnx.convert.from_keras(
keras_model,
input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)],
opset=17,
output_path="model.onnx"
)

# TensorFlow 函数
model_proto, _ = tf2onnx.convert.from_function(
tf_function,
input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)],
output_path="model.onnx"
)

图操作 (onnx-graphsurgeon)

import onnx_graphsurgeon as gs
import onnx

# 加载图
graph = gs.import_onnx(onnx.load("model.onnx"))

# 遍历节点
for node in graph.nodes:
print(f"{node.name}: {node.op}")

# 查找节点
conv_nodes = [n for n in graph.nodes if n.op == "Conv"]

# 修改节点属性
for node in graph.nodes:
if node.op == "Conv":
node.attrs["kernel_shape"] = [3, 3]

# 保存
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "modified.onnx")

C++ API 速查

头文件和链接

# CMakeLists.txt
include_directories(${ONNXRUNTIME_DIR}/include)
link_directories(${ONNXRUNTIME_DIR}/lib)
target_link_libraries(your_target onnxruntime)

基础推理

#include <onnxruntime_cxx_api.h>
#include <vector>

// 创建环境
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "app");

// 创建会话
Ort::SessionOptions session_options;
Ort::Session session(env, "model.onnx", session_options);

// 创建输入张量
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<int64_t> shape = {1, 3, 224, 224};
std::vector<float> input_data(1 * 3 * 224 * 224, 0.5f);

Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info, input_data.data(), input_data.size(),
shape.data(), shape.size()
);

// 执行推理
const char* input_names[] = {"input"};
const char* output_names[] = {"output"};
auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1);

// 获取输出
float* output_data = outputs[0].GetTensorMutableData<float>();

常见问题排查

输出不一致

# 检查 eval 模式
model.eval()

# 检查数据类型
input_data = input_data.astype(np.float32)

# 检查动态维度
print(session.get_inputs()[0].shape)

算子不支持

# 查看模型算子
import onnx
model = onnx.load("model.onnx")
ops = set(node.op_type for node in model.graph.node)
print(ops)

# 检查 Opset 版本
for opset in model.opset_import:
print(f"{opset.domain}: {opset.version}")

GPU 推理失败

# 检查 CUDA 提供者
print(ort.get_available_providers())

# 显式指定
session = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])

性能基准测试模板

import time
import numpy as np
import onnxruntime as ort

def benchmark(model_path, input_shape=(1, 3, 224, 224), iterations=100, warmup=10):
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
input_data = np.random.randn(*input_shape).astype(np.float32)

# 预热
for _ in range(warmup):
session.run(None, {input_name: input_data})

# 计时
times = []
for _ in range(iterations):
start = time.perf_counter()
session.run(None, {input_name: input_data})
times.append((time.perf_counter() - start) * 1000)

times = np.array(times)
print(f"模型: {model_path}")
print(f" 平均延迟: {times.mean():.2f} ms")
print(f" P99 延迟: {np.percentile(times, 99):.2f} ms")
print(f" 吞吐量: {1000 / times.mean() * input_shape[0]:.1f} samples/s")

benchmark("model.onnx")

模型信息查看

import onnx

model = onnx.load("model.onnx")

# IR 版本
print(f"IR 版本: {model.ir_version}")

# Opset 版本
for opset in model.opset_import:
print(f"Opset: {opset.domain or 'default'} v{opset.version}")

# 生产者信息
print(f"生产者: {model.producer_name} {model.producer_version}")

# 输入
for inp in model.graph.input:
print(f"输入: {inp.name}")

# 输出
for out in model.graph.output:
print(f"输出: {out.name}")

# 节点数
print(f"节点数: {len(model.graph.node)}")