用 ONNX Runtime 部署 PyTorch 模型(CPU / GPU 通用、跨语言)

训练用 PyTorch 灵活,部署到生产时通常希望:

  • 没有 PyTorch 100+ MB 依赖
  • 跨语言(C++ / Go / JS / Java 都能加载)
  • CPU / GPU 都能跑
  • 性能更好(融合算子)

ONNX 是开放神经网络交换格式,ONNX Runtime 是 Microsoft 的高性能推理引擎。
PyTorch 训练完导出 ONNX,运行时用 ORT 加载。

uv add torch onnx onnxruntime
# GPU 推理:
uv add onnxruntime-gpu

1. 导出 PyTorch 模型为 ONNX

import torch
from your_model import Net

model = Net()
model.load_state_dict(torch.load('mnist.pt'))
model.eval()

# 一个 dummy 输入用于 trace
dummy = torch.randn(1, 1, 28, 28)

torch.onnx.export(
    model,
    dummy,
    'mnist.onnx',
    input_names=['input'],
    output_names=['logits'],
    dynamic_axes={
        'input': {0: 'batch'},   # batch 维度可变
        'logits': {0: 'batch'},
    },
    opset_version=17,
)

dynamic_axes 让导出的模型支持任意 batch size,否则固定为 dummy 的形状。

2. 校验导出正确

import onnx
m = onnx.load('mnist.onnx')
onnx.checker.check_model(m)
print(onnx.helper.printable_graph(m.graph))

check_model 不报错就 OK。

3. 推理(Python)

import onnxruntime as ort
import numpy as np

sess = ort.InferenceSession('mnist.onnx', providers=['CPUExecutionProvider'])
# GPU: providers=['CUDAExecutionProvider']

# 看输入输出
for i in sess.get_inputs():
    print(f'input  {i.name}: {i.shape} {i.type}')
for o in sess.get_outputs():
    print(f'output {o.name}: {o.shape} {o.type}')

# 跑推理
x = np.random.rand(4, 1, 28, 28).astype(np.float32)
logits = sess.run(['logits'], {'input': x})[0]
pred = logits.argmax(axis=1)
print(pred)

4. 性能基准

import time, numpy as np
x = np.random.rand(1, 1, 28, 28).astype(np.float32)

# warm up
for _ in range(10): sess.run(['logits'], {'input': x})

t0 = time.time()
for _ in range(1000): sess.run(['logits'], {'input': x})
print(f'avg latency: {(time.time()-t0)/1000*1000:.2f} ms')

对比 PyTorch:

import torch
model.eval()
x = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    for _ in range(10): model(x)
    t0 = time.time()
    for _ in range(1000): model(x)
print(f'pytorch avg: {(time.time()-t0)/1000*1000:.2f} ms')

CPU 上 ONNX Runtime 通常比 PyTorch 快 1.5-3x(算子融合 + 简化 graph)。

5. 性能优化

sess = ort.InferenceSession(
    'mnist.onnx',
    providers=['CPUExecutionProvider'],
    sess_options=ort.SessionOptions(),
)

opts = ort.SessionOptions()
opts.intra_op_num_threads = 4
opts.inter_op_num_threads = 1
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

sess = ort.InferenceSession('mnist.onnx', sess_options=opts,
                            providers=['CPUExecutionProvider'])

对于 transformer / 大模型,enable_profiling=True 让 ORT 输出每个算子的耗时
帮助找瓶颈。

6. 多 provider

# 优先 GPU,没有就 CPU
sess = ort.InferenceSession(
    'mnist.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
print(sess.get_providers())

NVIDIA:CUDAExecutionProvider / TensorrtExecutionProvider
Apple:CoreMLExecutionProvider
AMD:ROCMExecutionProvider
Intel:OpenVINOExecutionProvider

7. 量化(减小模型 + 加速 CPU 推理)

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic('mnist.onnx', 'mnist.int8.onnx', weight_type=QuantType.QInt8)

INT8 量化通常 2-4x 推理加速 + 模型大小 1/4。精度损失对 ResNet / 简单 CNN
很小(< 1% 准确率),对 BERT 类需要 calibration 复杂些。

8. C++ 推理

#include "onnxruntime_cxx_api.h"

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions opts;
Ort::Session session(env, "mnist.onnx", opts);

// ... 准备输入 tensor ...
auto output = session.Run(...);

完整 C++ 例子在 ONNX Runtime 仓库。集成到 C++ 服务里完全摆脱 Python 依赖。

9. 浏览器推理(onnxruntime-web)

<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
  const session = await ort.InferenceSession.create('mnist.onnx');
  const input = new ort.Tensor('float32', data, [1, 1, 28, 28]);
  const results = await session.run({ input: input });
  console.log(results.logits.data);
</script>

WebGPU / WebGL / wasm 多后端自动选。小模型直接跑浏览器,数据不离开
用户设备。

10. 部署模式比较

工具 适合
ONNX Runtime 跨语言、跨平台、单进程
TorchServe 多 PyTorch 模型微服务
Triton Inference Server NVIDIA GPU 多模型高并发
BentoML Python 服务封装(含监控 / 队列)
vLLM LLM 专用(PagedAttention)

ONNX Runtime 是"最通用最简单"那档;要更高级 ops(动态 batching、
gpu 调度)上 Triton。

踩过的坑

  • 导出失败 "Unsupported ONNX opset version":升 opset_version 或者
    降 PyTorch 中用的 op(替换 custom op)。
  • 导出后形状对但数值差:训练时 BatchNorm 等 running stats 没保存好,
    确保 model.eval() 后再 export。
  • Dynamic axes 没写:onnx 模型固定 batch=1,部署时 batch=N 直接报错。
  • ONNX 模型文件很大(含权重):考虑用 onnx.save_model(m, ..., save_as_external_data=True)
    把权重分离存储,便于 CDN / 分发。
精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。