ONNX 模型优化
模型导出后往往包含冗余节点、未优化的计算图,直接部署会浪费计算资源。本章介绍多种 ONNX 模型优化技术,包括图简化、算子融合、量化等,帮助你榨取硬件的最大性能。
优化概述
ONNX 模型优化通常遵循以下流程:
每个阶段都有对应的工具和技术:
| 阶段 | 工具/技术 | 目标 |
|---|---|---|
| 图简化 | onnx-simplifier | 消除冗余节点、常量折叠 |
| 算子融合 | onnx-graphsurgeon | 合并相邻算子、减少内存访问 |
| 量化压缩 | ONNX Runtime Quantization | 降低精度、减少内存和计算量 |
| 推理引擎优化 | TensorRT/OpenVINO | 硬件特定优化 |
图简化:onnx-simplifier
onnx-simplifier(简称 onnxsim)是最流行的 ONNX 模型简化工具,由腾讯开源。它通过常量折叠、死代码消除等技术简化模型。
安装
pip install onnx-simplifier
基本使用
# 命令行使用
python -m onnxsim input.onnx output.onnx
# 查看帮助
python -m onnxsim --help
Python API 使用
import onnx
from onnxsim import simplify
# 加载模型
model = onnx.load("model.onnx")
# 简化模型
# skip_shape_inference=True 可以跳过形状推断(某些模型需要)
model_simplified, check = simplify(model)
# check 表示简化后的模型是否通过验证
print(f"简化成功: {check}")
# 保存简化后的模型
onnx.save(model_simplified, "model_simplified.onnx")
处理动态形状
如果模型有动态输入形状,需要提供示例输入:
import onnx
from onnxsim import simplify
import numpy as np
model = onnx.load("model.onnx")
# 提供示例输入形状
input_shapes = {
"input": [1, 3, 224, 224]
}
model_simplified, check = simplify(
model,
input_shapes=input_shapes,
# skip_fuse_bn=False # 是否跳过 BatchNorm 融合
)
onnx.save(model_simplified, "model_simplified.onnx")
onnxsim 做了什么?
onnxsim 执行以下优化:
常量折叠:将可以在编译期计算的表达式提前计算。
优化前: Constant(2) -> Constant(3) -> Mul -> ...
优化后: Constant(6) -> ...
死代码消除:移除不影响输出的节点。
优化前: input -> Conv -> Relu -> output
-> Identity -> (未使用)
优化后: input -> Conv -> Relu -> output
冗余操作消除:合并相邻的相同操作。
优化前: input -> Reshape(A) -> Reshape(B) -> ...
优化后: input -> Reshape(A*B) -> ...
验证简化效果
import onnx
def count_nodes(model_path):
model = onnx.load(model_path)
return len(model.graph.node)
print(f"原始模型节点数: {count_nodes('model.onnx')}")
print(f"简化后节点数: {count_nodes('model_simplified.onnx')}")
图手术:onnx-graphsurgeon
onnx-graphsurgeon 是 NVIDIA 提供的 ONNX 图操作工具,可以精确地修改模型结构。它比 onnxsim 更底层,适合复杂修改。
安装
pip install onnx-graphsurgeon
基本操作
import onnx_graphsurgeon as gs
import onnx
# 加载模型
graph = gs.import_onnx(onnx.load("model.onnx"))
# 查看图信息
print(f"输入: {[inp.name for inp in graph.inputs]}")
print(f"输出: {[out.name for out in graph.outputs]}")
print(f"节点数: {len(graph.nodes)}")
# 遍历节点
for node in graph.nodes:
print(f"节点: {node.name}, 类型: {node.op}")
print(f" 输入: {[i.name for i in node.inputs]}")
print(f" 输出: {[o.name for o in node.outputs]}")
删除节点
import onnx_graphsurgeon as gs
import onnx
graph = gs.import_onnx(onnx.load("model.onnx"))
# 找到并删除特定节点
for node in graph.nodes:
if node.name == "redundant_node":
# 断开节点的输出连接
node.outputs.clear()
break
# 清理孤立节点
graph.cleanup()
# 保存修改后的模型
onnx.save(gs.export_onnx(graph), "model_modified.onnx")
替换节点
import onnx_graphsurgeon as gs
import onnx
import numpy as np
graph = gs.import_onnx(onnx.load("model.onnx"))
# 创建新的常量节点
new_const = gs.Constant(
"new_constant",
values=np.array([1.0, 2.0, 3.0], dtype=np.float32)
)
# 找到旧节点并替换
for node in graph.nodes:
if node.op == "OldOp":
# 替换操作类型
node.op = "NewOp"
# 修改属性
node.attrs["new_attr"] = 42
break
graph.cleanup()
onnx.save(gs.export_onnx(graph), "model_modified.onnx")
插入节点
import onnx_graphsurgeon as gs
import onnx
graph = gs.import_onnx(onnx.load("model.onnx"))
# 假设想在某个卷积后插入 ReLU
for node in graph.nodes:
if node.name == "target_conv":
# 获取卷积的输出张量
conv_output = node.outputs[0]
# 创建 ReLU 节点
relu_output = gs.Variable("relu_output")
relu_node = gs.Node(
op="Relu",
name="inserted_relu",
inputs=[conv_output],
outputs=[relu_output]
)
# 将 ReLU 节点添加到图中
graph.nodes.append(relu_node)
# 更新下游节点的输入
for downstream in graph.nodes:
if downstream != node and conv_output in downstream.inputs:
idx = downstream.inputs.index(conv_output)
downstream.inputs[idx] = relu_output
break
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_modified.onnx")
量化
量化是将模型从高精度(如 FP32)转换为低精度(如 FP16、INT8)的过程,可以显著减少模型大小和推理延迟。
量化的基本原理
量化将浮点数映射到整数范围:
其中:
scale是缩放因子zero_point是零点偏移
反量化是逆过程:
动态量化
动态量化在推理时动态计算激活值的量化参数,实现简单但有一定的运行时开销。
from onnxruntime.quantization import quantize_dynamic, QuantType
# 动态量化
quantize_dynamic(
model_input="model.onnx",
model_output="model_quantized.onnx",
weight_type=QuantType.QInt8, # 权重量化类型
# op_types_to_quantize=["MatMul", "Gemm"], # 可选:指定要量化的算子类型
)
print("动态量化完成")
静态量化
静态量化需要校准数据来预先计算激活值的量化参数,推理时没有额外开销,精度损失通常更小。
import os
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader
# 准备校准数据
class ImageCalibrationDataReader(CalibrationDataReader):
def __init__(self, calibration_dir, batch_size=1):
self.calibration_dir = calibration_dir
self.batch_size = batch_size
self.data_list = self._load_data()
self.index = 0
def _load_data(self):
# 加载校准数据
# 这里用随机数据示例,实际应该加载真实数据
data_list = []
for i in range(100): # 100 张校准图片
data = np.random.randn(1, 3, 224, 224).astype(np.float32)
data_list.append(data)
return data_list
def get_next(self):
if self.index >= len(self.data_list):
return None
batch = self.data_list[self.index]
self.index += 1
return {"input": batch}
def rewind(self):
self.index = 0
# 预处理模型(推荐)
from onnxruntime.quantization.shape_inference import quant_pre_process
quant_pre_process("model.onnx", "model_preprocessed.onnx")
# 静态量化
calibration_reader = ImageCalibrationDataReader("calibration_data")
quantize_static(
model_input="model_preprocessed.onnx",
model_output="model_quantized.onnx",
calibration_data_reader=calibration_reader,
quant_format=QuantFormat.QDQ, # QDQ 格式
per_channel=False, # 是否按通道量化
weight_type=QuantType.QInt8,
)
print("静态量化完成")
量化格式选择
ONNX 支持两种量化表示格式:
QOperator 格式:使用专门的量化算子,如 QLinearConv、MatMulInteger。
QDQ 格式:在原算子前后插入 QuantizeLinear 和 DeQuantizeLinear 节点。
from onnxruntime.quantization import QuantFormat
# QDQ 格式(推荐,兼容性更好)
quantize_static(
...,
quant_format=QuantFormat.QDQ,
)
# QOperator 格式
quantize_static(
...,
quant_format=QuantFormat.QOperator,
)
量化数据类型选择
| 数据类型 | 说明 | 推荐场景 |
|---|---|---|
QInt8 | 有符号 8 位整数 | CPU 推理(默认) |
QUInt8 | 无符号 8 位整数 | 某些激活函数需要 |
QFloat16 | 16 位浮点 | GPU 推理,精度损失小 |
from onnxruntime.quantization import QuantType
# INT8 量化(CPU 推荐)
quantize_static(..., weight_type=QuantType.QInt8)
# FP16 量化(GPU 推荐)
# 需要使用不同的方法
FP16 量化
对于 GPU 推理,FP16 是更好的选择:
import onnx
from onnxconverter_common import float16
# 加载模型
model = onnx.load("model.onnx")
# 转换为 FP16
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
# 保存
onnx.save(model_fp16, "model_fp16.onnx")
量化精度验证
量化后需要验证精度损失:
import onnxruntime as ort
import numpy as np
def compare_outputs(model_fp32, model_quant, input_data):
# FP32 推理
session_fp32 = ort.InferenceSession(model_fp32)
output_fp32 = session_fp32.run(None, {"input": input_data})[0]
# 量化模型推理
session_quant = ort.InferenceSession(model_quant)
output_quant = session_quant.run(None, {"input": input_data})[0]
# 计算差异
diff = np.abs(output_fp32 - output_quant)
print(f"最大差异: {diff.max():.6f}")
print(f"平均差异: {diff.mean():.6f}")
return diff.max()
# 测试
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
compare_outputs("model.onnx", "model_quantized.onnx", input_data)
量化调试
如果量化后精度损失过大,可以使用调试工具定位问题:
from onnxruntime.quantization.qdq_loss_debug import (
create_weight_matching,
modify_model_output_intermediate_tensors,
collect_activations,
create_activation_matching,
)
# 1. 匹配权重
weight_matching = create_weight_matching(
onnx.load("model.onnx"),
onnx.load("model_quantized.onnx")
)
# 2. 收集激活值
model_augmented = modify_model_output_intermediate_tensors(
onnx.load("model.onnx")
)
# 3. 比较激活值差异
# ... 使用 collect_activations 收集数据后比较
ONNX Runtime 图优化
ONNX Runtime 内置了三级图优化,在加载模型时自动执行:
import onnxruntime as ort
options = ort.SessionOptions()
# 基础优化:冗余节点消除、常量折叠
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
# 扩展优化:算子融合(Conv+BN、Conv+ReLU 等)
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# 全部优化:包括布局优化
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession("model.onnx", sess_options=options)
保存优化后的模型
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.optimized_model_filepath = "model_optimized.onnx"
session = ort.InferenceSession("model.onnx", sess_options=options)
# 模型在加载时会被优化并保存到指定路径
性能对比与基准测试
模型大小对比
import os
def get_file_size_mb(filepath):
return os.path.getsize(filepath) / (1024 * 1024)
print(f"原始模型: {get_file_size_mb('model.onnx'):.2f} MB")
print(f"简化后: {get_file_size_mb('model_simplified.onnx'):.2f} MB")
print(f"量化后: {get_file_size_mb('model_quantized.onnx'):.2f} MB")
推理速度对比
import onnxruntime as ort
import numpy as np
import time
def benchmark(model_path, num_iterations=100):
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 预热
for _ in range(10):
session.run(None, {input_name: input_data})
# 计时
times = []
for _ in range(num_iterations):
start = time.perf_counter()
session.run(None, {input_name: input_data})
end = time.perf_counter()
times.append((end - start) * 1000)
return np.mean(times), np.std(times)
# 对比
mean_orig, std_orig = benchmark("model.onnx")
mean_opt, std_opt = benchmark("model_quantized.onnx")
print(f"原始模型: {mean_orig:.2f} ± {std_opt:.2f} ms")
print(f"优化后: {mean_opt:.2f} ± {std_opt:.2f} ms")
print(f"加速比: {mean_orig / mean_opt:.2f}x")
优化最佳实践
推荐的优化流程
各场景推荐方案
| 场景 | 推荐优化 | 理由 |
|---|---|---|
| 快速原型验证 | onnx-simplifier | 无精度损失,立即可用 |
| CPU 部署 | 简化 + INT8 量化 | CPU INT8 加速效果好 |
| GPU 部署 | 简化 + FP16 | GPU FP16 性能好,精度损失小 |
| 边缘设备 | 简化 + INT8 + 引擎特定优化 | 内存和计算资源有限 |
| 实时推理 | 所有优化 + TensorRT/OpenVINO | 追求极致性能 |
避免过度优化
优化不是越多越好,需要注意:
- 精度损失:量化可能带来精度下降,需要验证
- 兼容性问题:某些优化后的模型可能在特定推理引擎上不兼容
- 优化时间成本:静态量化需要校准数据,增加部署复杂度
总结
模型优化是提升推理性能的关键步骤。选择合适的优化策略需要考虑:
- 目标硬件:CPU 偏向 INT8,GPU 偏向 FP16
- 精度要求:对精度敏感的应用需要谨慎选择量化方案
- 部署复杂度:简单场景优先使用 onnx-simplifier
下一章将介绍模型验证和调试技术,确保优化后的模型仍然正确。