TensorFlow 导出 ONNX
除了 PyTorch,TensorFlow/Keras 也是广泛使用的深度学习框架。tf2onnx 是将 TensorFlow 模型转换为 ONNX 格式的官方工具,支持 SavedModel、Checkpoint、GraphDef、TFLite 等多种格式。
安装
# 安装 TensorFlow 和 tf2onnx
pip install tensorflow tf2onnx onnx onnxruntime
转换方式概览
tf2onnx 支持多种 TensorFlow 模型格式的转换:
| 格式 | 命令参数 | 特点 |
|---|---|---|
| SavedModel | --saved-model | 推荐,包含完整模型信息 |
| Checkpoint | --checkpoint | 需要指定输入输出 |
| GraphDef | --graphdef | 需要指定输入输出 |
| TFLite | --tflite | 移动端模型转换 |
| TensorFlow.js | --tfjs | Web 端模型转换 |
SavedModel 转换
SavedModel 是 TensorFlow 推荐的模型保存格式,包含了完整的模型结构和权重,转换最为方便。
命令行转换
# 基础转换
python -m tf2onnx.convert --saved-model ./saved_model_dir --output model.onnx
# 指定 Opset 版本
python -m tf2onnx.convert --saved-model ./saved_model_dir --opset 17 --output model.onnx
# 大模型(超过 2GB)
python -m tf2onnx.convert --saved-model ./saved_model_dir --large_model --output model.onnx
Python API 转换
import tensorflow as tf
import tf2onnx
# 加载 SavedModel
model = tf.saved_model.load("./saved_model_dir")
# 转换为 ONNX
model_proto, external_tensor_storage = tf2onnx.convert.from_saved_model(
"./saved_model_dir",
opset=17,
output_path="model.onnx"
)
print(f"转换成功,输出文件: model.onnx")
Keras 模型转换
Keras 是 TensorFlow 的高级 API,模型转换更加简洁。
命令行转换
# 先保存为 SavedModel,再转换
# TensorFlow 2.x 推荐方式
python -c "
import tensorflow as tf
model = tf.keras.models.load_model('model.h5')
tf.saved_model.save(model, './keras_saved_model')
"
python -m tf2onnx.convert --saved-model ./keras_saved_model --output model.onnx
Python API 转换
import tensorflow as tf
import tf2onnx
import numpy as np
# 创建或加载 Keras 模型
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, activation='softmax')
])
# 编译模型(可选,用于生成输入签名)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
# 定义输入签名
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]
# 转换为 ONNX
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="model.onnx"
)
print("Keras 模型转换完成")
自定义 Keras 模型转换
import tensorflow as tf
import tf2onnx
class CustomModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
self.pool = tf.keras.layers.GlobalAveragePooling2D()
self.fc = tf.keras.layers.Dense(10)
def call(self, x):
x = self.conv1(x)
x = self.pool(x)
return self.fc(x)
model = CustomModel()
# 构建模型(确定输入形状)
model.build(input_shape=(None, 224, 224, 3))
# 转换
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
output_path="custom_model.onnx"
)
tf.function 转换
对于使用 @tf.function 装饰的函数,可以使用 from_function 方法:
import tensorflow as tf
import tf2onnx
class MyModel(tf.Module):
def __init__(self):
self.weights = tf.Variable(tf.random.normal([10, 5]))
@tf.function(input_signature=[tf.TensorSpec([None, 10], tf.float32)])
def __call__(self, x):
return tf.matmul(x, self.weights)
model = MyModel()
# 转换
model_proto, _ = tf2onnx.convert.from_function(
model.__call__,
input_signature=[tf.TensorSpec([None, 10], tf.float32, name="input")],
output_path="model.onnx"
)
Checkpoint 转换
Checkpoint 格式只保存权重,需要提供输入输出节点名称:
# 需要知道输入输出节点的名称
python -m tf2onnx.convert \
--checkpoint model.ckpt.meta \
--output model.onnx \
--inputs "input:0" \
--outputs "output:0"
使用 summarize_graph 工具
如果不知道输入输出节点名称,可以使用 TensorFlow 的 summarize_graph 工具:
# 构建工具(需要从源码编译 TensorFlow)
bazel build tensorflow/tools/graph_transforms:summarize_graph
# 分析模型
./summarize_graph --in_graph=model.pb
GraphDef 转换
GraphDef 是 TensorFlow 的图定义格式(.pb 文件):
python -m tf2onnx.convert \
--graphdef model.pb \
--output model.onnx \
--inputs "input:0[1,224,224,3]" \
--outputs "output:0"
注意输入名称后面的 [1,224,224,3] 用于指定形状(可选)。
Python API
import tf2onnx
import tensorflow as tf
# 加载 GraphDef
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("model.pb", "rb") as f:
graph_def.ParseFromString(f.read())
# 转换
model_proto, _ = tf2onnx.convert.from_graph_def(
graph_def,
input_names=["input:0"],
output_names=["output:0"],
output_path="model.onnx"
)
TFLite 转换
TFLite 是 TensorFlow 的移动端模型格式:
python -m tf2onnx.convert --tflite model.tflite --output model.onnx
Python API
import tf2onnx
model_proto, _ = tf2onnx.convert.from_tflite(
"model.tflite",
output_path="model.onnx"
)
量化 TFLite 模型
对于量化的 TFLite 模型,可以使用 --dequantize 参数转换为 float32:
python -m tf2onnx.convert --tflite quantized.tflite --dequantize --output model.onnx
高级选项
数据格式转换 (NHWC vs NCHW)
TensorFlow 默认使用 NHWC 格式(batch, height, width, channels),而 ONNX 使用 NCHW。tf2onnx 可以自动处理转换:
# 输入保持 NHWC(与 TensorFlow 一致)
python -m tf2onnx.convert --saved-model ./model --output model.onnx
# 输入转换为 NCHW
python -m tf2onnx.convert --saved-model ./model --output model.onnx \
--inputs-as-nchw input:0
# Python API
model_proto, _ = tf2onnx.convert.from_saved_model(
"./saved_model",
inputs_as_nchw=["input"],
output_path="model.onnx"
)
大模型处理
超过 2GB 的模型需要使用外部张量存储:
python -m tf2onnx.convert --saved-model ./large_model --large_model --output model.onnx
这会生成两个文件:
model.onnx:模型结构model.onnx.data:外部张量数据
目标平台特定优化
# 针对 Windows ML 优化
python -m tf2onnx.convert --saved-model ./model --target winml --output model.onnx
自定义算子
# 指定自定义算子域
python -m tf2onnx.convert --saved-model ./model \
--custom-ops MyCustomOp:my.domain \
--output model.onnx
转换后验证
import numpy as np
import tensorflow as tf
import onnxruntime as ort
def verify_conversion(tf_model_path, onnx_model_path, input_shape=(1, 224, 224, 3)):
# 加载 TensorFlow 模型
tf_model = tf.saved_model.load(tf_model_path)
# 创建 ONNX Runtime 会话
ort_session = ort.InferenceSession(onnx_model_path)
# 准备测试数据
input_data = np.random.randn(*input_shape).astype(np.float32)
# TensorFlow 推理
tf_output = tf_model(input_data)
if isinstance(tf_output, dict):
tf_output = list(tf_output.values())[0]
tf_output = tf_output.numpy()
# ONNX 推理
input_name = ort_session.get_inputs()[0].name
ort_output = ort_session.run(None, {input_name: input_data})[0]
# 比较结果
max_diff = np.abs(tf_output - ort_output).max()
print(f"最大差异: {max_diff}")
if max_diff < 1e-5:
print("验证通过")
else:
print("警告:存在显著差异")
return max_diff
verify_conversion("./saved_model", "model.onnx")
常见问题
问题一:找不到输入输出节点
症状:
ValueError: No inputs specified
解决:使用 SavedModel 格式,或通过 saved_model_cli 工具查看:
saved_model_cli show --dir ./saved_model --all
问题二:不支持的算子
症状:
Unsupported ops: MyCustomOp
解决:
- 检查是否有等效的标准算子组合
- 使用
--custom-ops标记自定义算子 - 在推理引擎中实现自定义算子
问题三:形状推断失败
症状:
Shape inference failed
解决:明确指定输入形状:
python -m tf2onnx.convert --graphdef model.pb --output model.onnx \
--inputs "input:0[1,224,224,3]"
问题四:Opset 版本不兼容
症状:转换成功但在推理引擎中加载失败
解决:使用较低的 Opset 版本:
python -m tf2onnx.convert --saved-model ./model --opset 14 --output model.onnx
完整示例
以下是一个完整的 TensorFlow 模型训练、保存和转换的示例:
import tensorflow as tf
import tf2onnx
import numpy as np
import onnxruntime as ort
# 1. 创建模型
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# 2. 训练模型(这里用随机数据示例)
x_train = np.random.randn(100, 224, 224, 3).astype(np.float32)
y_train = np.random.randint(0, 10, 100)
model.fit(x_train, y_train, epochs=1, verbose=0)
# 3. 保存为 SavedModel
tf.saved_model.save(model, "./my_model")
# 4. 转换为 ONNX
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name="input")]
model_proto, _ = tf2onnx.convert.from_keras(
model,
input_signature=input_signature,
opset=17,
output_path="my_model.onnx"
)
# 5. 验证转换
test_input = np.random.randn(1, 224, 224, 3).astype(np.float32)
# TensorFlow 推理
tf_output = model(test_input).numpy()
# ONNX 推理
ort_session = ort.InferenceSession("my_model.onnx")
ort_output = ort_session.run(None, {"input": test_input})[0]
# 比较
diff = np.abs(tf_output - ort_output).max()
print(f"最大差异: {diff}")
if diff < 1e-5:
print("转换成功!")
下一步
TensorFlow 模型成功转换为 ONNX 后,后续步骤与 PyTorch 导出的模型相同:
- 使用 onnx-simplifier 简化模型
- 进行量化优化
- 在 Python 或 C++ 中部署推理
详细内容请参考后续章节。