起因
ML model train 在 PyTorch / TF / sklearn,部署面对:
- 想跑在 CPU 而不是 GPU
- 想 deploy 到 mobile / web / 嵌入式
- 想要更小 / 更快的 runtime(PyTorch ~ 1 GB;ONNX Runtime ~ 50 MB)
- 不想生产环境扛 PyTorch 依赖
ONNX (Open Neural Network Exchange) 是模型表示标准。
ONNX Runtime 是跑 ONNX 模型的 runtime(C++ 写,多 backend)。
train 用任何框架 → export ONNX → 用 ONNX Runtime 跑。
export
PyTorch:
import torch
model = MyModel()
model.load_state_dict(torch.load('model.pt'))
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model, dummy_input, 'model.onnx',
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=18,
)
dynamic_axes 让 batch dim 运行时可变。
sklearn:
from skl2onnx import to_onnx
onnx_model = to_onnx(sk_model, X_sample[:1])
with open('model.onnx', 'wb') as f:
f.write(onnx_model.SerializeToString())
HuggingFace transformers:
optimum-cli export onnx --model bert-base-uncased ./onnx_out/
验证
import onnx
onnx.checker.check_model(onnx.load('model.onnx'))
import onnxruntime as ort
sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
result = sess.run(None, {input_name: dummy_input.numpy()})
print(result[0].shape)
跑通 → ONNX 模型 ready。
跟 PyTorch 原 model 输出做 numerical 比较(tolerance 1e-5):
torch_out = model(dummy_input).detach().numpy()
onnx_out = sess.run(None, {input_name: dummy_input.numpy()})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
跑 ONNX Runtime
CPU:
sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
GPU (CUDA):
sess = ort.InferenceSession('model.onnx',
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
CoreML (mac M 系列):
providers=['CoreMLExecutionProvider', 'CPUExecutionProvider']
WebAssembly (浏览器):
// onnxruntime-web
import * as ort from 'onnxruntime-web';
const session = await ort.InferenceSession.create('model.onnx');
const results = await session.run({ input: tensor });
同一 .onnx 文件,多 platform 复用。
性能 vs PyTorch
ResNet50 inference / batch=1 / single thread:
| Runtime | latency |
|---|---|
| PyTorch CPU | 80 ms |
| ONNX Runtime CPU | 45 ms |
| PyTorch (TorchScript) | 60 ms |
| ONNX + OpenVINO backend | 28 ms |
ONNX Runtime 普遍比 PyTorch CPU 快 1.5-2x(graph 优化 + inference 专用)。
GPU 上差距小,PyTorch 也很快。
量化
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic('model.onnx', 'model_int8.onnx', weight_type=QuantType.QInt8)
INT8 量化 → model size 4x 小 + CPU 2x 快(精度损 1-3%)。
mobile / edge 部署关键。
部署在 server
FROM python:3.12-slim
RUN pip install onnxruntime fastapi uvicorn numpy
COPY model.onnx app.py /app/
WORKDIR /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0"]
# app.py
from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
app = FastAPI()
sess = ort.InferenceSession('model.onnx')
@app.post('/predict')
async def predict(data: dict):
x = np.array(data['input'], dtype=np.float32)
result = sess.run(None, {'input': x})[0]
return {'output': result.tolist()}
image 大小:
- PyTorch image: ~3 GB
- ONNX Runtime image: ~150 MB
container 启动快 + 部署成本低。
浏览器跑 model(onnxruntime-web)
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
const session = await ort.InferenceSession.create('/static/model.onnx');
const feeds = { input: new ort.Tensor('float32', data, [1, 3, 224, 224]) };
const results = await session.run(feeds);
console.log(results.output.data);
</script>
不需要 server,model 跑在用户浏览器。
- 数据不出端(隐私)
- 0 server cost
- 适合:图片分类 / 文本嵌入 / 小型模型
mobile 浏览器 4-bit 量化后跑 MobileNet 50 ms。
与 TFLite / CoreML 对比
| ONNX Runtime | TFLite | CoreML | TorchScript | |
|---|---|---|---|---|
| 跨平台 | 强 | 强 | iOS only | 中 |
| Train 框架 | 全 | TF | 全 | PyTorch |
| 性能 | 高 | 高(mobile) | 极高(apple) | 中高 |
| 工具链 | 复杂但灵 | 简单 | 简单 | PyTorch 内置 |
iOS 跑:CoreML 最快。
Android:TFLite 或 ONNX Runtime。
跨平台 / server:ONNX。
transformers 适配
HuggingFace optimum 让 transformers 一键转 ONNX:
from optimum.onnxruntime import ORTModelForSequenceClassification
model = ORTModelForSequenceClassification.from_pretrained(
'distilbert-base-uncased-finetuned-sst-2-english',
export=True,
)
# 内部自动 export ONNX + 用 ORT 跑
API 跟 transformers 一样,performance 是 ORT 加成。
真实 case:BERT 文本分类部署
train: HF transformers + GPU。
原计划: PyTorch serve 在 t3.medium (2vCPU, 4GB)。
PyTorch model: 440 MB
推理 latency: 350 ms / request
RAM: 1.2 GB
转 ONNX + quantize INT8:
ONNX INT8 model: 110 MB
推理 latency: 80 ms / request
RAM: 350 MB
成本 / 性能都改善。同 instance 能跑 4x QPS。
不适合 ONNX 的场景
- dynamic graph 复杂(控制流多):ONNX op 不全 cover,export 失败
- custom op:必须写 ONNX custom op(C++)
- 需要 train:ONNX Runtime 主要 inference(有 training 但弱)
- frequent model update:ORT runtime load 慢,热更新麻烦
train 阶段不动 PyTorch;deploy 阶段转 ONNX。
踩过的坑
-
opset version:老 export 用 opset 11,新 ORT 默认要 17+。
不匹配 unsupported op。统一 opset 18+。 -
dynamic shape:忘
dynamic_axes→ batch=1 hardcoded。生产
variable batch 报错。 -
数值不一致:FP16 / FP32 mix 后 train 和 ONNX 差 1%。生产
numerical 严格场景小心。 -
custom op:用了 PyTorch 特有 op (如 grid_sample 某些 mode)
→ ONNX export 报错。改 model 或者手写 ONNX op。 -
runtime version mismatch:onnx 库版本 vs onnxruntime 版本不
匹配 → load model 报错。pip 同一时间装。
登录后参与评论。