跳到主要内容

ONNX 算子集

ONNX 的核心是一系列标准化的算子(Operators),每个算子定义了特定的计算操作。理解 ONNX 算子集对于解决导出和部署问题至关重要。

什么是算子?

算子是计算图中的基本单元,相当于深度学习框架中的"层"(Layer)。例如,PyTorch 中的 nn.Conv2d 在 ONNX 中对应 Conv 算子,nn.Linear 对应 Gemm 算子。

一个 ONNX 模型本质上是一系列算子节点的有向无环图(DAG),数据沿着图的边流动,经过各个算子的计算,最终产生输出。

Opset 版本机制

ONNX 通过 Opset(Operator Set)版本机制来管理算子的演进。每个 Opset 版本定义了一组算子的规范,包括输入、输出、属性以及计算语义。

版本兼容性规则

ONNX 采用前向兼容策略:使用 Opset 12 导出的模型可以在支持 Opset 15 的推理引擎上运行。但反过来不行——使用了 Opset 15 新特性的模型无法在只支持 Opset 12 的引擎上运行。

查看模型的 Opset

import onnx

model = onnx.load("model.onnx")

print("模型的 IR 版本:", model.ir_version)
print("\nOpset 导入:")

for opset in model.opset_import:
print(f" 域: {opset.domain if opset.domain else 'ai.onnx (默认)'}")
print(f" 版本: {opset.version}")

Opset 版本选择建议

场景推荐 Opset理由
最大兼容性11-12覆盖最广的硬件和推理引擎
通用选择14-15平衡兼容性和功能
Transformer 模型17+支持 LayerNormalization 等算子
最新特性18-21某些新算子,需确认推理引擎支持

常用算子详解

数学运算类

ONNX 算子PyTorch 对应说明
Add+, torch.add()逐元素加法,支持广播
Sub-, torch.sub()逐元素减法
Mul*, torch.mul()逐元素乘法
Div/, torch.div()逐元素除法
MatMultorch.matmul()矩阵乘法
Gemmnn.Linear广义矩阵乘法,融合偏置
Powtorch.pow()幂运算
Sqrttorch.sqrt()平方根
Exptorch.exp()指数函数
Logtorch.log()自然对数

神经网络层

ONNX 算子PyTorch 对应说明
Convnn.Conv1d/2d/3d卷积操作
ConvTransposenn.ConvTranspose2d转置卷积(反卷积)
BatchNormalizationnn.BatchNorm2d批归一化
InstanceNormalizationnn.InstanceNorm2d实例归一化
LayerNormalizationnn.LayerNorm层归一化(Opset 17+)
Dropoutnn.Dropout随机失活(推理时通常被移除)
Flattentorch.flatten()展平张量
Resizenn.Upsample上采样/插值

激活函数

ONNX 算子PyTorch 对应说明
Relunn.ReLU(), F.relu()整流线性单元
LeakyRelunn.LeakyReLU()带泄漏的 ReLU
PRelunn.PReLU()参数化 ReLU
Sigmoidnn.Sigmoid(), torch.sigmoid()Sigmoid 函数
Tanhnn.Tanh(), torch.tanh()双曲正切
Softmaxnn.Softmax(), F.softmax()Softmax 函数
Gelunn.GELU(), F.gelu()高斯误差线性单元

张量操作

ONNX 算子PyTorch 对应说明
Reshapetensor.view(), tensor.reshape()改变张量形状
Transposetensor.transpose()转置
Permutetensor.permute()维度重排
Squeezetensor.squeeze()移除大小为 1 的维度
Unsqueezetensor.unsqueeze()插入大小为 1 的维度
Concattorch.cat()沿指定维度拼接
Splittorch.split()沿指定维度分割
Slicetensor[...]切片操作
Gathertorch.gather()按索引收集
ScatterNDtensor.scatter_()按索引散射

池化和归约

ONNX 算子PyTorch 对应说明
MaxPoolnn.MaxPool2d最大池化
AveragePoolnn.AvgPool2d平均池化
GlobalAveragePool自适应平均池化全局平均池化
ReduceMeantorch.mean()沿维度求均值
ReduceSumtorch.sum()沿维度求和
ReduceMaxtorch.max()沿维度求最大值
ReduceMintorch.min()沿维度求最小值
ArgMaxtorch.argmax()返回最大值索引
ArgMintorch.argmin()返回最小值索引

算子版本演变

同一个算子在不同 Opset 版本中可能有不同的行为或属性。了解这些变化有助于排查兼容性问题。

Conv 算子的演变

# Opset 1: 基础版本
# Opset 11: 添加了 dilations 属性的默认值处理
# Opset 13: 支持更多数据类型

Resize 算子的重大变化

Resize 算子在 Opset 11 和 Opset 13 之间有重大变化,这是最常遇到的问题之一:

# Opset 11 及以前:使用 scales 参数
# scales = [1.0, 1.0, 2.0, 2.0] # 各维度的缩放因子

# Opset 13 及以后:使用 sizes 参数
# sizes = [1, 1, 448, 448] # 输出的具体大小

# 这导致使用 PyTorch 导出的模型(通常用 Opset 11+)在某些推理引擎上出错

解决方案:确保推理引擎支持的 Opset 版本与模型一致,或在导出时指定合适的 Opset。

Softmax 算子的轴属性

Opset 13 对 Softmax 的轴属性处理进行了改进:

# Opset 11: axis 属性是必需的
# Opset 13: axis 有默认值 -1(最后一个维度)

# 导出时注意:
torch.onnx.export(
model,
dummy_input,
"model.onnx",
opset_version=13, # 使用新版 Opset
# ...
)

自定义算子

当模型使用了 ONNX 标准算子集中不存在的操作时,需要处理自定义算子。

识别不支持的算子

导出时如果遇到错误:

RuntimeError: ONNX export failed on an operator: my_custom_op

首先检查该操作是否可以用标准算子组合实现:

# 假设自定义操作是某种特殊的激活函数
class CustomActivation(nn.Module):
def forward(self, x):
# 尝试用标准操作实现
return torch.where(x > 0, x, 0.1 * x) # LeakyReLU 等价形式

注册自定义算子导出规则

如果确实需要自定义算子,可以注册导出规则:

from torch.onnx import register_custom_op_symbolic

def my_op_symbolic(g, input, param):
"""定义自定义操作的 ONNX 导出规则"""
# g 是 ONNX 图构建器
# 返回一个或多个 ONNX 节点
return g.op("MyDomain::MyCustomOp", input, param_i=param)

# 注册符号函数
register_custom_op_symbolic("my_namespace::my_op", my_op_symbolic, 1)

在推理引擎中实现自定义算子

导出后,还需要在目标推理引擎中实现自定义算子的计算逻辑:

ONNX Runtime Python

import onnxruntime as ort

# 自定义算子的实现
def my_custom_op_impl(x, param):
# 实际计算逻辑
return x * param

# 注册到 ONNX Runtime
# 注意:需要使用 ONNX Runtime 的扩展机制

ONNX Runtime C++

// 需要实现自定义算子内核并注册
// 参考 ONNX Runtime 文档中的 Custom Operator 部分

算子兼容性检查

使用 ONNX Checker

import onnx

# 加载并检查模型
model = onnx.load("model.onnx")

try:
onnx.checker.check_model(model)
print("模型检查通过")
except onnx.checker.ValidationError as e:
print(f"模型检查失败: {e}")

使用 ONNX Runtime 验证

import onnxruntime as ort

try:
session = ort.InferenceSession("model.onnx")
print("推理引擎可以加载模型")
except Exception as e:
print(f"加载失败: {e}")

查看算子列表

import onnx

model = onnx.load("model.onnx")

# 收集所有使用的算子
operators = set()
for node in model.graph.node:
operators.add(node.op_type)

print("模型使用的算子:")
for op in sorted(operators):
print(f" - {op}")

常见算子问题与解决

问题一:Resize/Upsample 兼容性

现象:在 TensorRT 或某些推理引擎上报错,提示 Resize 算子不支持。

原因:Resize 算子在不同 Opset 版本中有不同的参数格式。

解决

# 导出时使用较低 Opset,或检查推理引擎支持的版本
torch.onnx.export(
model,
dummy_input,
"model.onnx",
opset_version=11, # 使用兼容的版本
)

问题二:Gather/ScatterND 不支持

现象:目标设备(如某些 NPU)不支持 Gather 或 ScatterND 算子。

解决:将这些操作移到模型外部,在预处理或后处理中实现。

# 原模型
class ModelWithGather(nn.Module):
def forward(self, x, indices):
return x.gather(1, indices)

# 修改后:将 gather 移到外部
class ModelWithoutGather(nn.Module):
def forward(self, x):
return x # 只做基础计算

# 推理时在外部执行 gather
# outputs = model(inputs)
# result = outputs.gather(1, indices)

问题三:动态形状与算子不兼容

现象:某些算子在动态形状下行为不符合预期。

解决:检查导出时的 dynamic_axes 配置,某些操作可能不支持动态维度。

# 对不支持的算子固定维度
dynamic_axes = {
"input": {0: "batch_size"}, # 只让 batch_size 动态
# "input": {0: "batch", 2: "height", 3: "width"}, # 避免太多动态维度
}

问题四:类型不匹配

现象:推理时报类型错误。

原因:ONNX 对类型要求严格,某些操作需要特定类型。

解决

# 确保导出和推理时使用一致的数据类型
# PyTorch 导出
dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float32)

# ONNX Runtime 推理
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32) # 明确指定 float32

官方算子文档

最权威的算子参考是 ONNX 官方文档:

https://onnx.ai/onnx/operators/

每个算子的文档包含:

  • 算子简介和数学定义
  • 输入输出规范
  • 属性列表
  • 支持的数据类型
  • 版本变化历史

总结

理解 ONNX 算子集是成功部署模型的关键。遇到问题时:

  1. 首先确认模型使用的 Opset 版本
  2. 检查目标推理引擎支持的算子列表
  3. 使用 onnx-simplifier 尝试简化模型
  4. 必要时将不兼容的操作移到模型外部

下一章将介绍模型优化技术,进一步提升推理性能。