跳到主要内容

ONNX 模型优化

模型导出后往往包含冗余节点、未优化的计算图,直接部署会浪费计算资源。本章介绍多种 ONNX 模型优化技术,包括图简化、算子融合、量化等,帮助你榨取硬件的最大性能。

优化概述

ONNX 模型优化通常遵循以下流程:

每个阶段都有对应的工具和技术:

阶段工具/技术目标
图简化onnx-simplifier消除冗余节点、常量折叠
算子融合onnx-graphsurgeon合并相邻算子、减少内存访问
量化压缩ONNX Runtime Quantization降低精度、减少内存和计算量
推理引擎优化TensorRT/OpenVINO硬件特定优化

图简化:onnx-simplifier

onnx-simplifier(简称 onnxsim)是最流行的 ONNX 模型简化工具,由腾讯开源。它通过常量折叠、死代码消除等技术简化模型。

安装

pip install onnx-simplifier

基本使用

# 命令行使用
python -m onnxsim input.onnx output.onnx

# 查看帮助
python -m onnxsim --help

Python API 使用

import onnx
from onnxsim import simplify

# 加载模型
model = onnx.load("model.onnx")

# 简化模型
# skip_shape_inference=True 可以跳过形状推断(某些模型需要)
model_simplified, check = simplify(model)

# check 表示简化后的模型是否通过验证
print(f"简化成功: {check}")

# 保存简化后的模型
onnx.save(model_simplified, "model_simplified.onnx")

处理动态形状

如果模型有动态输入形状,需要提供示例输入:

import onnx
from onnxsim import simplify
import numpy as np

model = onnx.load("model.onnx")

# 提供示例输入形状
input_shapes = {
"input": [1, 3, 224, 224]
}

model_simplified, check = simplify(
model,
input_shapes=input_shapes,
# skip_fuse_bn=False # 是否跳过 BatchNorm 融合
)

onnx.save(model_simplified, "model_simplified.onnx")

onnxsim 做了什么?

onnxsim 执行以下优化:

常量折叠:将可以在编译期计算的表达式提前计算。

优化前: Constant(2) -> Constant(3) -> Mul -> ...
优化后: Constant(6) -> ...

死代码消除:移除不影响输出的节点。

优化前: input -> Conv -> Relu -> output
-> Identity -> (未使用)
优化后: input -> Conv -> Relu -> output

冗余操作消除:合并相邻的相同操作。

优化前: input -> Reshape(A) -> Reshape(B) -> ...
优化后: input -> Reshape(A*B) -> ...

验证简化效果

import onnx

def count_nodes(model_path):
model = onnx.load(model_path)
return len(model.graph.node)

print(f"原始模型节点数: {count_nodes('model.onnx')}")
print(f"简化后节点数: {count_nodes('model_simplified.onnx')}")

图手术:onnx-graphsurgeon

onnx-graphsurgeon 是 NVIDIA 提供的 ONNX 图操作工具,可以精确地修改模型结构。它比 onnxsim 更底层,适合复杂修改。

安装

pip install onnx-graphsurgeon

基本操作

import onnx_graphsurgeon as gs
import onnx

# 加载模型
graph = gs.import_onnx(onnx.load("model.onnx"))

# 查看图信息
print(f"输入: {[inp.name for inp in graph.inputs]}")
print(f"输出: {[out.name for out in graph.outputs]}")
print(f"节点数: {len(graph.nodes)}")

# 遍历节点
for node in graph.nodes:
print(f"节点: {node.name}, 类型: {node.op}")
print(f" 输入: {[i.name for i in node.inputs]}")
print(f" 输出: {[o.name for o in node.outputs]}")

删除节点

import onnx_graphsurgeon as gs
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))

# 找到并删除特定节点
for node in graph.nodes:
if node.name == "redundant_node":
# 断开节点的输出连接
node.outputs.clear()
break

# 清理孤立节点
graph.cleanup()

# 保存修改后的模型
onnx.save(gs.export_onnx(graph), "model_modified.onnx")

替换节点

import onnx_graphsurgeon as gs
import onnx
import numpy as np

graph = gs.import_onnx(onnx.load("model.onnx"))

# 创建新的常量节点
new_const = gs.Constant(
"new_constant",
values=np.array([1.0, 2.0, 3.0], dtype=np.float32)
)

# 找到旧节点并替换
for node in graph.nodes:
if node.op == "OldOp":
# 替换操作类型
node.op = "NewOp"
# 修改属性
node.attrs["new_attr"] = 42
break

graph.cleanup()
onnx.save(gs.export_onnx(graph), "model_modified.onnx")

插入节点

import onnx_graphsurgeon as gs
import onnx

graph = gs.import_onnx(onnx.load("model.onnx"))

# 假设想在某个卷积后插入 ReLU
for node in graph.nodes:
if node.name == "target_conv":
# 获取卷积的输出张量
conv_output = node.outputs[0]

# 创建 ReLU 节点
relu_output = gs.Variable("relu_output")
relu_node = gs.Node(
op="Relu",
name="inserted_relu",
inputs=[conv_output],
outputs=[relu_output]
)

# 将 ReLU 节点添加到图中
graph.nodes.append(relu_node)

# 更新下游节点的输入
for downstream in graph.nodes:
if downstream != node and conv_output in downstream.inputs:
idx = downstream.inputs.index(conv_output)
downstream.inputs[idx] = relu_output

break

graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_modified.onnx")

量化

量化是将模型从高精度(如 FP32)转换为低精度(如 FP16、INT8)的过程,可以显著减少模型大小和推理延迟。

量化的基本原理

量化将浮点数映射到整数范围:

xquantized=round(xfloatscale)+zero_pointx_{quantized} = \text{round}\left(\frac{x_{float}}{scale}\right) + zero\_point

其中:

  • scale 是缩放因子
  • zero_point 是零点偏移

反量化是逆过程:

xfloat=scale×(xquantizedzero_point)x_{float} = scale \times (x_{quantized} - zero\_point)

动态量化

动态量化在推理时动态计算激活值的量化参数,实现简单但有一定的运行时开销。

from onnxruntime.quantization import quantize_dynamic, QuantType

# 动态量化
quantize_dynamic(
model_input="model.onnx",
model_output="model_quantized.onnx",
weight_type=QuantType.QInt8, # 权重量化类型
# op_types_to_quantize=["MatMul", "Gemm"], # 可选:指定要量化的算子类型
)

print("动态量化完成")

静态量化

静态量化需要校准数据来预先计算激活值的量化参数,推理时没有额外开销,精度损失通常更小。

import os
import numpy as np
from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader

# 准备校准数据
class ImageCalibrationDataReader(CalibrationDataReader):
def __init__(self, calibration_dir, batch_size=1):
self.calibration_dir = calibration_dir
self.batch_size = batch_size
self.data_list = self._load_data()
self.index = 0

def _load_data(self):
# 加载校准数据
# 这里用随机数据示例,实际应该加载真实数据
data_list = []
for i in range(100): # 100 张校准图片
data = np.random.randn(1, 3, 224, 224).astype(np.float32)
data_list.append(data)
return data_list

def get_next(self):
if self.index >= len(self.data_list):
return None

batch = self.data_list[self.index]
self.index += 1
return {"input": batch}

def rewind(self):
self.index = 0

# 预处理模型(推荐)
from onnxruntime.quantization.shape_inference import quant_pre_process
quant_pre_process("model.onnx", "model_preprocessed.onnx")

# 静态量化
calibration_reader = ImageCalibrationDataReader("calibration_data")

quantize_static(
model_input="model_preprocessed.onnx",
model_output="model_quantized.onnx",
calibration_data_reader=calibration_reader,
quant_format=QuantFormat.QDQ, # QDQ 格式
per_channel=False, # 是否按通道量化
weight_type=QuantType.QInt8,
)

print("静态量化完成")

量化格式选择

ONNX 支持两种量化表示格式:

QOperator 格式:使用专门的量化算子,如 QLinearConvMatMulInteger

QDQ 格式:在原算子前后插入 QuantizeLinearDeQuantizeLinear 节点。

from onnxruntime.quantization import QuantFormat

# QDQ 格式(推荐,兼容性更好)
quantize_static(
...,
quant_format=QuantFormat.QDQ,
)

# QOperator 格式
quantize_static(
...,
quant_format=QuantFormat.QOperator,
)

量化数据类型选择

数据类型说明推荐场景
QInt8有符号 8 位整数CPU 推理(默认)
QUInt8无符号 8 位整数某些激活函数需要
QFloat1616 位浮点GPU 推理,精度损失小
from onnxruntime.quantization import QuantType

# INT8 量化(CPU 推荐)
quantize_static(..., weight_type=QuantType.QInt8)

# FP16 量化(GPU 推荐)
# 需要使用不同的方法

FP16 量化

对于 GPU 推理,FP16 是更好的选择:

import onnx
from onnxconverter_common import float16

# 加载模型
model = onnx.load("model.onnx")

# 转换为 FP16
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)

# 保存
onnx.save(model_fp16, "model_fp16.onnx")

量化精度验证

量化后需要验证精度损失:

import onnxruntime as ort
import numpy as np

def compare_outputs(model_fp32, model_quant, input_data):
# FP32 推理
session_fp32 = ort.InferenceSession(model_fp32)
output_fp32 = session_fp32.run(None, {"input": input_data})[0]

# 量化模型推理
session_quant = ort.InferenceSession(model_quant)
output_quant = session_quant.run(None, {"input": input_data})[0]

# 计算差异
diff = np.abs(output_fp32 - output_quant)
print(f"最大差异: {diff.max():.6f}")
print(f"平均差异: {diff.mean():.6f}")

return diff.max()

# 测试
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
compare_outputs("model.onnx", "model_quantized.onnx", input_data)

量化调试

如果量化后精度损失过大,可以使用调试工具定位问题:

from onnxruntime.quantization.qdq_loss_debug import (
create_weight_matching,
modify_model_output_intermediate_tensors,
collect_activations,
create_activation_matching,
)

# 1. 匹配权重
weight_matching = create_weight_matching(
onnx.load("model.onnx"),
onnx.load("model_quantized.onnx")
)

# 2. 收集激活值
model_augmented = modify_model_output_intermediate_tensors(
onnx.load("model.onnx")
)

# 3. 比较激活值差异
# ... 使用 collect_activations 收集数据后比较

ONNX Runtime 图优化

ONNX Runtime 内置了三级图优化,在加载模型时自动执行:

import onnxruntime as ort

options = ort.SessionOptions()

# 基础优化:冗余节点消除、常量折叠
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# 扩展优化:算子融合(Conv+BN、Conv+ReLU 等)
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# 全部优化:包括布局优化
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession("model.onnx", sess_options=options)

保存优化后的模型

options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
options.optimized_model_filepath = "model_optimized.onnx"

session = ort.InferenceSession("model.onnx", sess_options=options)
# 模型在加载时会被优化并保存到指定路径

性能对比与基准测试

模型大小对比

import os

def get_file_size_mb(filepath):
return os.path.getsize(filepath) / (1024 * 1024)

print(f"原始模型: {get_file_size_mb('model.onnx'):.2f} MB")
print(f"简化后: {get_file_size_mb('model_simplified.onnx'):.2f} MB")
print(f"量化后: {get_file_size_mb('model_quantized.onnx'):.2f} MB")

推理速度对比

import onnxruntime as ort
import numpy as np
import time

def benchmark(model_path, num_iterations=100):
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 预热
for _ in range(10):
session.run(None, {input_name: input_data})

# 计时
times = []
for _ in range(num_iterations):
start = time.perf_counter()
session.run(None, {input_name: input_data})
end = time.perf_counter()
times.append((end - start) * 1000)

return np.mean(times), np.std(times)

# 对比
mean_orig, std_orig = benchmark("model.onnx")
mean_opt, std_opt = benchmark("model_quantized.onnx")

print(f"原始模型: {mean_orig:.2f} ± {std_opt:.2f} ms")
print(f"优化后: {mean_opt:.2f} ± {std_opt:.2f} ms")
print(f"加速比: {mean_orig / mean_opt:.2f}x")

优化最佳实践

推荐的优化流程

各场景推荐方案

场景推荐优化理由
快速原型验证onnx-simplifier无精度损失,立即可用
CPU 部署简化 + INT8 量化CPU INT8 加速效果好
GPU 部署简化 + FP16GPU FP16 性能好,精度损失小
边缘设备简化 + INT8 + 引擎特定优化内存和计算资源有限
实时推理所有优化 + TensorRT/OpenVINO追求极致性能

避免过度优化

优化不是越多越好,需要注意:

  1. 精度损失:量化可能带来精度下降,需要验证
  2. 兼容性问题:某些优化后的模型可能在特定推理引擎上不兼容
  3. 优化时间成本:静态量化需要校准数据,增加部署复杂度

总结

模型优化是提升推理性能的关键步骤。选择合适的优化策略需要考虑:

  • 目标硬件:CPU 偏向 INT8,GPU 偏向 FP16
  • 精度要求:对精度敏感的应用需要谨慎选择量化方案
  • 部署复杂度:简单场景优先使用 onnx-simplifier

下一章将介绍模型验证和调试技术,确保优化后的模型仍然正确。