ONNX 知识速查表
本章提供 ONNX 开发中常用的命令、API 和配置参数的速查参考。
环境安装
# 核心库
pip install onnx onnxruntime
# GPU 推理
pip install onnxruntime-gpu
# 模型简化
pip install onnx-simplifier
# TensorFlow 转换
pip install tf2onnx
# 图操作
pip install onnx-graphsurgeon
# FP16 转换
pip install onnxconverter-common
PyTorch 导出 ONNX
基础导出
import torch
torch.onnx.export(
model, # PyTorch 模型
dummy_input, # 示例输入
"model.onnx", # 输出路径
dynamo=True, # PyTorch 2.6+ 推荐
opset_version=17, # ONNX Opset 版本
input_names=["input"], # 输入名称
output_names=["output"], # 输出名称
)
动态维度配置
# 方式一:dynamic_axes(通用)
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size"},
}
# 方式二:dynamic_shapes(PyTorch 2.6+)
from torch.export import Dim
batch_size = Dim("batch_size", min=1, max=64)
dynamic_shapes = {"input": {0: batch_size}}
多输入模型
torch.onnx.export(
model,
(input1, input2), # 元组形式的多输入
"model.onnx",
input_names=["image", "text"],
output_names=["output"],
)
ONNX 模型验证
import onnx
import onnxruntime as ort
import numpy as np
# 检查模型结构
model = onnx.load("model.onnx")
onnx.checker.check_model(model)
# 验证输出一致性
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
output = session.run([output_name], {input_name: input_data})[0]
ONNX Runtime Python API
创建推理会话
import onnxruntime as ort
# CPU 推理
session = ort.InferenceSession("model.onnx")
# GPU 推理
session = ort.InferenceSession(
"model.onnx",
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
# 查看可用提供者
print(ort.get_available_providers())
获取模型信息
# 输入信息
for inp in session.get_inputs():
print(f"{inp.name}: shape={inp.shape}, type={inp.type}")
# 输出信息
for out in session.get_outputs():
print(f"{out.name}: shape={out.shape}, type={out.type}")
执行推理
# 基础推理
outputs = session.run(None, {"input": input_data})
# 指定输出
outputs = session.run(["output"], {"input": input_data})
# 获取第一个输出
result = outputs[0]
SessionOptions 配置
options = ort.SessionOptions()
# 线程数
options.intra_op_num_threads = 4
options.inter_op_num_threads = 1
# 图优化级别
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# 保存优化后的模型
options.optimized_model_filepath = "optimized.onnx"
session = ort.InferenceSession("model.onnx", sess_options=options)
GPU IOBinding
import torch
# 创建 IOBinding
io_binding = session.io_binding()
# 绑定 GPU 输入
input_tensor = torch.randn(1, 3, 224, 224, device="cuda")
io_binding.bind_input(
name="input",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=tuple(input_tensor.shape),
buffer_ptr=input_tensor.data_ptr()
)
# 绑定 GPU 输出
output_tensor = torch.empty(1, 1000, device="cuda")
io_binding.bind_output(
name="output",
device_type="cuda",
device_id=0,
element_type=np.float32,
shape=tuple(output_tensor.shape),
buffer_ptr=output_tensor.data_ptr()
)
# 执行推理
session.run_with_iobinding(io_binding)
模型简化
命令行
# 基础简化
python -m onnxsim input.onnx output.onnx
# 指定输入形状
python -m onnxsim input.onnx output.onnx --input-shape input:1,3,224,224
# 跳过形状推断
python -m onnxsim input.onnx output.onnx --skip-shape-inference
Python API
from onnxsim import simplify
import onnx
model = onnx.load("model.onnx")
# 基础简化
model_simp, check = simplify(model)
# 指定输入形状
model_simp, check = simplify(
model,
input_shapes={"input": [1, 3, 224, 224]}
)
onnx.save(model_simp, "simplified.onnx")
量化
动态量化
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic(
model_input="model.onnx",
model_output="model_int8.onnx",
weight_type=QuantType.QInt8,
)
静态量化
from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, CalibrationDataReader
# 实现校准数据读取器
class MyDataReader(CalibrationDataReader):
def __init__(self):
self.data = [np.random.randn(1, 3, 224, 224).astype(np.float32) for _ in range(100)]
self.index = 0
def get_next(self):
if self.index >= len(self.data):
return None
data = {"input": self.data[self.index]}
self.index += 1
return data
# 预处理
from onnxruntime.quantization.shape_inference import quant_pre_process
quant_pre_process("model.onnx", "model_preprocessed.onnx")
# 量化
quantize_static(
model_input="model_preprocessed.onnx",
model_output="model_int8.onnx",
calibration_data_reader=MyDataReader(),
quant_format=QuantFormat.QDQ,
weight_type=QuantType.QInt8,
)
FP16 转换
import onnx
from onnxconverter_common import float16
model = onnx.load("model.onnx")
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "model_fp16.onnx")
TensorFlow 转换 ONNX
# SavedModel 格式
python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx
# Checkpoint 格式
python -m tf2onnx.convert --checkpoint model.ckpt --output model.onnx --inputs input:0 --outputs output:0
# GraphDef 格式
python -m tf2onnx.convert --graphdef model.pb --output model.onnx --inputs input:0 --outputs output:0
# TFLite 格式
python -m tf2onnx.convert --tflite model.tflite --output model.onnx
# 指定 Opset
python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx --opset 17
Python API
import tf2onnx
# Keras 模型
model_proto, _ = tf2onnx.convert.from_keras(
keras_model,
input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)],
opset=17,
output_path="model.onnx"
)
# TensorFlow 函数
model_proto, _ = tf2onnx.convert.from_function(
tf_function,
input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)],
output_path="model.onnx"
)
图操作 (onnx-graphsurgeon)
import onnx_graphsurgeon as gs
import onnx
# 加载图
graph = gs.import_onnx(onnx.load("model.onnx"))
# 遍历节点
for node in graph.nodes:
print(f"{node.name}: {node.op}")
# 查找节点
conv_nodes = [n for n in graph.nodes if n.op == "Conv"]
# 修改节点属性
for node in graph.nodes:
if node.op == "Conv":
node.attrs["kernel_shape"] = [3, 3]
# 保存
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "modified.onnx")
C++ API 速查
头文件和链接
# CMakeLists.txt
include_directories(${ONNXRUNTIME_DIR}/include)
link_directories(${ONNXRUNTIME_DIR}/lib)
target_link_libraries(your_target onnxruntime)
基础推理
#include <onnxruntime_cxx_api.h>
#include <vector>
// 创建环境
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "app");
// 创建会话
Ort::SessionOptions session_options;
Ort::Session session(env, "model.onnx", session_options);
// 创建输入张量
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<int64_t> shape = {1, 3, 224, 224};
std::vector<float> input_data(1 * 3 * 224 * 224, 0.5f);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info, input_data.data(), input_data.size(),
shape.data(), shape.size()
);
// 执行推理
const char* input_names[] = {"input"};
const char* output_names[] = {"output"};
auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1);
// 获取输出
float* output_data = outputs[0].GetTensorMutableData<float>();
常见问题排查
输出不一致
# 检查 eval 模式
model.eval()
# 检查数据类型
input_data = input_data.astype(np.float32)
# 检查动态维度
print(session.get_inputs()[0].shape)
算子不支持
# 查看模型算子
import onnx
model = onnx.load("model.onnx")
ops = set(node.op_type for node in model.graph.node)
print(ops)
# 检查 Opset 版本
for opset in model.opset_import:
print(f"{opset.domain}: {opset.version}")
GPU 推理失败
# 检查 CUDA 提供者
print(ort.get_available_providers())
# 显式指定
session = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
性能基准测试模板
import time
import numpy as np
import onnxruntime as ort
def benchmark(model_path, input_shape=(1, 3, 224, 224), iterations=100, warmup=10):
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
input_data = np.random.randn(*input_shape).astype(np.float32)
# 预热
for _ in range(warmup):
session.run(None, {input_name: input_data})
# 计时
times = []
for _ in range(iterations):
start = time.perf_counter()
session.run(None, {input_name: input_data})
times.append((time.perf_counter() - start) * 1000)
times = np.array(times)
print(f"模型: {model_path}")
print(f" 平均延迟: {times.mean():.2f} ms")
print(f" P99 延迟: {np.percentile(times, 99):.2f} ms")
print(f" 吞吐量: {1000 / times.mean() * input_shape[0]:.1f} samples/s")
benchmark("model.onnx")
模型信息查看
import onnx
model = onnx.load("model.onnx")
# IR 版本
print(f"IR 版本: {model.ir_version}")
# Opset 版本
for opset in model.opset_import:
print(f"Opset: {opset.domain or 'default'} v{opset.version}")
# 生产者信息
print(f"生产者: {model.producer_name} {model.producer_version}")
# 输入
for inp in model.graph.input:
print(f"输入: {inp.name}")
# 输出
for out in model.graph.output:
print(f"输出: {out.name}")
# 节点数
print(f"节点数: {len(model.graph.node)}")