跳到主要内容

TensorFlow 导出 ONNX

除了 PyTorch,TensorFlow/Keras 也是广泛使用的深度学习框架。tf2onnx 是将 TensorFlow 模型转换为 ONNX 格式的官方工具,支持 SavedModel、Checkpoint、GraphDef、TFLite 等多种格式。

安装

# 安装 TensorFlow 和 tf2onnx
pip install tensorflow tf2onnx onnx onnxruntime

转换方式概览

tf2onnx 支持多种 TensorFlow 模型格式的转换:

格式命令参数特点
SavedModel--saved-model推荐,包含完整模型信息
Checkpoint--checkpoint需要指定输入输出
GraphDef--graphdef需要指定输入输出
TFLite--tflite移动端模型转换
TensorFlow.js--tfjsWeb 端模型转换

SavedModel 转换

SavedModel 是 TensorFlow 推荐的模型保存格式,包含了完整的模型结构和权重,转换最为方便。

命令行转换

# 基础转换
python -m tf2onnx.convert --saved-model ./saved_model_dir --output model.onnx

# 指定 Opset 版本
python -m tf2onnx.convert --saved-model ./saved_model_dir --opset 17 --output model.onnx

# 大模型(超过 2GB)
python -m tf2onnx.convert --saved-model ./saved_model_dir --large_model --output model.onnx

Python API 转换

import tensorflow as tf
import tf2onnx

# 加载 SavedModel
model = tf.saved_model.load("./saved_model_dir")

# 转换为 ONNX
model_proto, external_tensor_storage = tf2onnx.convert.from_saved_model(
"./saved_model_dir",
opset=17,
output_path="model.onnx"
)

print(f"转换成功,输出文件: model.onnx")

Keras 模型转换

Keras 是 TensorFlow 的高级 API,模型转换更加简洁。

命令行转换

# 先保存为 SavedModel,再转换
# TensorFlow 2.x 推荐方式
python -c "
import tensorflow as tf
model = tf.keras.models.load_model('model.h5')
tf.saved_model.save(model, './keras_saved_model')
"

python -m tf2onnx.convert --saved-model ./keras_saved_model --output model.onnx

Python API 转换

import tensorflow as tf
import tf2onnx
import numpy as np

# 创建或加载 Keras 模型
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, activation='softmax')
])

# 编译模型(可选,用于生成输入签名)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# 定义输入签名
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]

# 转换为 ONNX
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="model.onnx"
)

print("Keras 模型转换完成")

自定义 Keras 模型转换

import tensorflow as tf
import tf2onnx

class CustomModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
self.pool = tf.keras.layers.GlobalAveragePooling2D()
self.fc = tf.keras.layers.Dense(10)

def call(self, x):
x = self.conv1(x)
x = self.pool(x)
return self.fc(x)

model = CustomModel()

# 构建模型(确定输入形状)
model.build(input_shape=(None, 224, 224, 3))

# 转换
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
output_path="custom_model.onnx"
)

tf.function 转换

对于使用 @tf.function 装饰的函数,可以使用 from_function 方法:

import tensorflow as tf
import tf2onnx

class MyModel(tf.Module):
def __init__(self):
self.weights = tf.Variable(tf.random.normal([10, 5]))

@tf.function(input_signature=[tf.TensorSpec([None, 10], tf.float32)])
def __call__(self, x):
return tf.matmul(x, self.weights)

model = MyModel()

# 转换
model_proto, _ = tf2onnx.convert.from_function(
model.__call__,
input_signature=[tf.TensorSpec([None, 10], tf.float32, name="input")],
output_path="model.onnx"
)

Checkpoint 转换

Checkpoint 格式只保存权重,需要提供输入输出节点名称:

# 需要知道输入输出节点的名称
python -m tf2onnx.convert \
--checkpoint model.ckpt.meta \
--output model.onnx \
--inputs "input:0" \
--outputs "output:0"

使用 summarize_graph 工具

如果不知道输入输出节点名称,可以使用 TensorFlow 的 summarize_graph 工具:

# 构建工具(需要从源码编译 TensorFlow)
bazel build tensorflow/tools/graph_transforms:summarize_graph

# 分析模型
./summarize_graph --in_graph=model.pb

GraphDef 转换

GraphDef 是 TensorFlow 的图定义格式(.pb 文件):

python -m tf2onnx.convert \
--graphdef model.pb \
--output model.onnx \
--inputs "input:0[1,224,224,3]" \
--outputs "output:0"

注意输入名称后面的 [1,224,224,3] 用于指定形状(可选)。

Python API

import tf2onnx
import tensorflow as tf

# 加载 GraphDef
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("model.pb", "rb") as f:
graph_def.ParseFromString(f.read())

# 转换
model_proto, _ = tf2onnx.convert.from_graph_def(
graph_def,
input_names=["input:0"],
output_names=["output:0"],
output_path="model.onnx"
)

TFLite 转换

TFLite 是 TensorFlow 的移动端模型格式:

python -m tf2onnx.convert --tflite model.tflite --output model.onnx

Python API

import tf2onnx

model_proto, _ = tf2onnx.convert.from_tflite(
"model.tflite",
output_path="model.onnx"
)

量化 TFLite 模型

对于量化的 TFLite 模型,可以使用 --dequantize 参数转换为 float32:

python -m tf2onnx.convert --tflite quantized.tflite --dequantize --output model.onnx

高级选项

数据格式转换 (NHWC vs NCHW)

TensorFlow 默认使用 NHWC 格式(batch, height, width, channels),而 ONNX 使用 NCHW。tf2onnx 可以自动处理转换:

# 输入保持 NHWC(与 TensorFlow 一致)
python -m tf2onnx.convert --saved-model ./model --output model.onnx

# 输入转换为 NCHW
python -m tf2onnx.convert --saved-model ./model --output model.onnx \
--inputs-as-nchw input:0
# Python API
model_proto, _ = tf2onnx.convert.from_saved_model(
"./saved_model",
inputs_as_nchw=["input"],
output_path="model.onnx"
)

大模型处理

超过 2GB 的模型需要使用外部张量存储:

python -m tf2onnx.convert --saved-model ./large_model --large_model --output model.onnx

这会生成两个文件:

  • model.onnx:模型结构
  • model.onnx.data:外部张量数据

目标平台特定优化

# 针对 Windows ML 优化
python -m tf2onnx.convert --saved-model ./model --target winml --output model.onnx

自定义算子

# 指定自定义算子域
python -m tf2onnx.convert --saved-model ./model \
--custom-ops MyCustomOp:my.domain \
--output model.onnx

转换后验证

import numpy as np
import tensorflow as tf
import onnxruntime as ort

def verify_conversion(tf_model_path, onnx_model_path, input_shape=(1, 224, 224, 3)):
# 加载 TensorFlow 模型
tf_model = tf.saved_model.load(tf_model_path)

# 创建 ONNX Runtime 会话
ort_session = ort.InferenceSession(onnx_model_path)

# 准备测试数据
input_data = np.random.randn(*input_shape).astype(np.float32)

# TensorFlow 推理
tf_output = tf_model(input_data)
if isinstance(tf_output, dict):
tf_output = list(tf_output.values())[0]
tf_output = tf_output.numpy()

# ONNX 推理
input_name = ort_session.get_inputs()[0].name
ort_output = ort_session.run(None, {input_name: input_data})[0]

# 比较结果
max_diff = np.abs(tf_output - ort_output).max()
print(f"最大差异: {max_diff}")

if max_diff < 1e-5:
print("验证通过")
else:
print("警告:存在显著差异")

return max_diff

verify_conversion("./saved_model", "model.onnx")

常见问题

问题一:找不到输入输出节点

症状

ValueError: No inputs specified

解决:使用 SavedModel 格式,或通过 saved_model_cli 工具查看:

saved_model_cli show --dir ./saved_model --all

问题二:不支持的算子

症状

Unsupported ops: MyCustomOp

解决

  1. 检查是否有等效的标准算子组合
  2. 使用 --custom-ops 标记自定义算子
  3. 在推理引擎中实现自定义算子

问题三:形状推断失败

症状

Shape inference failed

解决:明确指定输入形状:

python -m tf2onnx.convert --graphdef model.pb --output model.onnx \
--inputs "input:0[1,224,224,3]"

问题四:Opset 版本不兼容

症状:转换成功但在推理引擎中加载失败

解决:使用较低的 Opset 版本:

python -m tf2onnx.convert --saved-model ./model --opset 14 --output model.onnx

完整示例

以下是一个完整的 TensorFlow 模型训练、保存和转换的示例:

import tensorflow as tf
import tf2onnx
import numpy as np
import onnxruntime as ort

# 1. 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)

# 2. 训练模型(这里用随机数据示例)
x_train = np.random.randn(100, 224, 224, 3).astype(np.float32)
y_train = np.random.randint(0, 10, 100)

model.fit(x_train, y_train, epochs=1, verbose=0)

# 3. 保存为 SavedModel
tf.saved_model.save(model, "./my_model")

# 4. 转换为 ONNX
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="my_model.onnx"
)

# 5. 验证转换
test_input = np.random.randn(1, 224, 224, 3).astype(np.float32)

# TensorFlow 推理
tf_output = model(test_input).numpy()

# ONNX 推理
ort_session = ort.InferenceSession("my_model.onnx")
ort_output = ort_session.run(None, {"input": test_input})[0]

# 比较
diff = np.abs(tf_output - ort_output).max()
print(f"最大差异: {diff}")

if diff < 1e-5:
print("转换成功!")

下一步

TensorFlow 模型成功转换为 ONNX 后,后续步骤与 PyTorch 导出的模型相同:

  • 使用 onnx-simplifier 简化模型
  • 进行量化优化
  • 在 Python 或 C++ 中部署推理

详细内容请参考后续章节。