ONNX:跨框架部署 ML 模型(不绑 PyTorch / TensorFlow)

起因

ML model train 在 PyTorch / TF / sklearn,部署面对:

  • 想跑在 CPU 而不是 GPU
  • 想 deploy 到 mobile / web / 嵌入式
  • 想要更小 / 更快的 runtime(PyTorch ~ 1 GB;ONNX Runtime ~ 50 MB)
  • 不想生产环境扛 PyTorch 依赖

ONNX (Open Neural Network Exchange) 是模型表示标准。
ONNX Runtime 是跑 ONNX 模型的 runtime(C++ 写,多 backend)。

train 用任何框架 → export ONNX → 用 ONNX Runtime 跑。

export

PyTorch:

import torch

model = MyModel()
model.load_state_dict(torch.load('model.pt'))
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy_input, 'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=18,
)

dynamic_axes 让 batch dim 运行时可变。

sklearn:

from skl2onnx import to_onnx
onnx_model = to_onnx(sk_model, X_sample[:1])
with open('model.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())

HuggingFace transformers:

optimum-cli export onnx --model bert-base-uncased ./onnx_out/

验证

import onnx
onnx.checker.check_model(onnx.load('model.onnx'))

import onnxruntime as ort
sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])

input_name = sess.get_inputs()[0].name
result = sess.run(None, {input_name: dummy_input.numpy()})
print(result[0].shape)

跑通 → ONNX 模型 ready。

跟 PyTorch 原 model 输出做 numerical 比较(tolerance 1e-5):

torch_out = model(dummy_input).detach().numpy()
onnx_out = sess.run(None, {input_name: dummy_input.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)

跑 ONNX Runtime

CPU:

sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])

GPU (CUDA):

sess = ort.InferenceSession('model.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

CoreML (mac M 系列):

providers=['CoreMLExecutionProvider', 'CPUExecutionProvider']

WebAssembly (浏览器):

// onnxruntime-web
import * as ort from 'onnxruntime-web';
const session = await ort.InferenceSession.create('model.onnx');
const results = await session.run({ input: tensor });

同一 .onnx 文件,多 platform 复用。

性能 vs PyTorch

ResNet50 inference / batch=1 / single thread:

Runtime latency
PyTorch CPU 80 ms
ONNX Runtime CPU 45 ms
PyTorch (TorchScript) 60 ms
ONNX + OpenVINO backend 28 ms

ONNX Runtime 普遍比 PyTorch CPU 快 1.5-2x(graph 优化 + inference 专用)。

GPU 上差距小,PyTorch 也很快。

量化

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic('model.onnx', 'model_int8.onnx', weight_type=QuantType.QInt8)

INT8 量化 → model size 4x 小 + CPU 2x 快(精度损 1-3%)。
mobile / edge 部署关键。

部署在 server

FROM python:3.12-slim
RUN pip install onnxruntime fastapi uvicorn numpy

COPY model.onnx app.py /app/
WORKDIR /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0"]
# app.py
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np

app = FastAPI()
sess = ort.InferenceSession('model.onnx')

@app.post('/predict')
async def predict(data: dict):
    x = np.array(data['input'], dtype=np.float32)
    result = sess.run(None, {'input': x})[0]
    return {'output': result.tolist()}

image 大小:

  • PyTorch image: ~3 GB
  • ONNX Runtime image: ~150 MB

container 启动快 + 部署成本低。

浏览器跑 model(onnxruntime-web)

<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
const session = await ort.InferenceSession.create('/static/model.onnx');
const feeds = { input: new ort.Tensor('float32', data, [1, 3, 224, 224]) };
const results = await session.run(feeds);
console.log(results.output.data);
</script>

不需要 server,model 跑在用户浏览器。
- 数据不出端(隐私)
- 0 server cost
- 适合:图片分类 / 文本嵌入 / 小型模型

mobile 浏览器 4-bit 量化后跑 MobileNet 50 ms。

与 TFLite / CoreML 对比

ONNX Runtime TFLite CoreML TorchScript
跨平台 iOS only
Train 框架 TF PyTorch
性能 高(mobile) 极高(apple) 中高
工具链 复杂但灵 简单 简单 PyTorch 内置

iOS 跑:CoreML 最快。
Android:TFLite 或 ONNX Runtime。
跨平台 / server:ONNX。

transformers 适配

HuggingFace optimum 让 transformers 一键转 ONNX:

from optimum.onnxruntime import ORTModelForSequenceClassification

model = ORTModelForSequenceClassification.from_pretrained(
    'distilbert-base-uncased-finetuned-sst-2-english',
    export=True,
)
# 内部自动 export ONNX + 用 ORT 跑

API 跟 transformers 一样,performance 是 ORT 加成。

真实 case:BERT 文本分类部署

train: HF transformers + GPU。
原计划: PyTorch serve 在 t3.medium (2vCPU, 4GB)。

PyTorch model: 440 MB
推理 latency: 350 ms / request
RAM: 1.2 GB

转 ONNX + quantize INT8:

ONNX INT8 model: 110 MB
推理 latency: 80 ms / request
RAM: 350 MB

成本 / 性能都改善。同 instance 能跑 4x QPS。

不适合 ONNX 的场景

  • dynamic graph 复杂(控制流多):ONNX op 不全 cover,export 失败
  • custom op:必须写 ONNX custom op(C++)
  • 需要 train:ONNX Runtime 主要 inference(有 training 但弱)
  • frequent model update:ORT runtime load 慢,热更新麻烦

train 阶段不动 PyTorch;deploy 阶段转 ONNX。

踩过的坑

  1. opset version:老 export 用 opset 11,新 ORT 默认要 17+。
    不匹配 unsupported op。统一 opset 18+。

  2. dynamic shape:忘 dynamic_axes → batch=1 hardcoded。生产
    variable batch 报错。

  3. 数值不一致:FP16 / FP32 mix 后 train 和 ONNX 差 1%。生产
    numerical 严格场景小心。

  4. custom op:用了 PyTorch 特有 op (如 grid_sample 某些 mode)
    → ONNX export 报错。改 model 或者手写 ONNX op。

  5. runtime version mismatch:onnx 库版本 vs onnxruntime 版本不
    匹配 → load model 报错。pip 同一时间装。

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。