训练用 PyTorch 灵活,部署到生产时通常希望:
- 没有 PyTorch 100+ MB 依赖
- 跨语言(C++ / Go / JS / Java 都能加载)
- CPU / GPU 都能跑
- 性能更好(融合算子)
ONNX 是开放神经网络交换格式,ONNX Runtime 是 Microsoft 的高性能推理引擎。
PyTorch 训练完导出 ONNX,运行时用 ORT 加载。
装
uv add torch onnx onnxruntime
# GPU 推理:
uv add onnxruntime-gpu
1. 导出 PyTorch 模型为 ONNX
import torch
from your_model import Net
model = Net()
model.load_state_dict(torch.load('mnist.pt'))
model.eval()
# 一个 dummy 输入用于 trace
dummy = torch.randn(1, 1, 28, 28)
torch.onnx.export(
model,
dummy,
'mnist.onnx',
input_names=['input'],
output_names=['logits'],
dynamic_axes={
'input': {0: 'batch'}, # batch 维度可变
'logits': {0: 'batch'},
},
opset_version=17,
)
dynamic_axes 让导出的模型支持任意 batch size,否则固定为 dummy 的形状。
2. 校验导出正确
import onnx
m = onnx.load('mnist.onnx')
onnx.checker.check_model(m)
print(onnx.helper.printable_graph(m.graph))
check_model 不报错就 OK。
3. 推理(Python)
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession('mnist.onnx', providers=['CPUExecutionProvider'])
# GPU: providers=['CUDAExecutionProvider']
# 看输入输出
for i in sess.get_inputs():
print(f'input {i.name}: {i.shape} {i.type}')
for o in sess.get_outputs():
print(f'output {o.name}: {o.shape} {o.type}')
# 跑推理
x = np.random.rand(4, 1, 28, 28).astype(np.float32)
logits = sess.run(['logits'], {'input': x})[0]
pred = logits.argmax(axis=1)
print(pred)
4. 性能基准
import time, numpy as np
x = np.random.rand(1, 1, 28, 28).astype(np.float32)
# warm up
for _ in range(10): sess.run(['logits'], {'input': x})
t0 = time.time()
for _ in range(1000): sess.run(['logits'], {'input': x})
print(f'avg latency: {(time.time()-t0)/1000*1000:.2f} ms')
对比 PyTorch:
import torch
model.eval()
x = torch.randn(1, 1, 28, 28)
with torch.no_grad():
for _ in range(10): model(x)
t0 = time.time()
for _ in range(1000): model(x)
print(f'pytorch avg: {(time.time()-t0)/1000*1000:.2f} ms')
CPU 上 ONNX Runtime 通常比 PyTorch 快 1.5-3x(算子融合 + 简化 graph)。
5. 性能优化
sess = ort.InferenceSession(
'mnist.onnx',
providers=['CPUExecutionProvider'],
sess_options=ort.SessionOptions(),
)
opts = ort.SessionOptions()
opts.intra_op_num_threads = 4
opts.inter_op_num_threads = 1
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
sess = ort.InferenceSession('mnist.onnx', sess_options=opts,
providers=['CPUExecutionProvider'])
对于 transformer / 大模型,enable_profiling=True 让 ORT 输出每个算子的耗时
帮助找瓶颈。
6. 多 provider
# 优先 GPU,没有就 CPU
sess = ort.InferenceSession(
'mnist.onnx',
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
print(sess.get_providers())
NVIDIA:CUDAExecutionProvider / TensorrtExecutionProvider
Apple:CoreMLExecutionProvider
AMD:ROCMExecutionProvider
Intel:OpenVINOExecutionProvider
7. 量化(减小模型 + 加速 CPU 推理)
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic('mnist.onnx', 'mnist.int8.onnx', weight_type=QuantType.QInt8)
INT8 量化通常 2-4x 推理加速 + 模型大小 1/4。精度损失对 ResNet / 简单 CNN
很小(< 1% 准确率),对 BERT 类需要 calibration 复杂些。
8. C++ 推理
#include "onnxruntime_cxx_api.h"
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions opts;
Ort::Session session(env, "mnist.onnx", opts);
// ... 准备输入 tensor ...
auto output = session.Run(...);
完整 C++ 例子在 ONNX Runtime 仓库。集成到 C++ 服务里完全摆脱 Python 依赖。
9. 浏览器推理(onnxruntime-web)
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
const session = await ort.InferenceSession.create('mnist.onnx');
const input = new ort.Tensor('float32', data, [1, 1, 28, 28]);
const results = await session.run({ input: input });
console.log(results.logits.data);
</script>
WebGPU / WebGL / wasm 多后端自动选。小模型直接跑浏览器,数据不离开
用户设备。
10. 部署模式比较
| 工具 | 适合 |
|---|---|
| ONNX Runtime | 跨语言、跨平台、单进程 |
| TorchServe | 多 PyTorch 模型微服务 |
| Triton Inference Server | NVIDIA GPU 多模型高并发 |
| BentoML | Python 服务封装(含监控 / 队列) |
| vLLM | LLM 专用(PagedAttention) |
ONNX Runtime 是"最通用最简单"那档;要更高级 ops(动态 batching、
gpu 调度)上 Triton。
踩过的坑
- 导出失败 "Unsupported ONNX opset version":升
opset_version或者
降 PyTorch 中用的 op(替换 custom op)。 - 导出后形状对但数值差:训练时 BatchNorm 等 running stats 没保存好,
确保model.eval()后再 export。 - Dynamic axes 没写:onnx 模型固定 batch=1,部署时 batch=N 直接报错。
- ONNX 模型文件很大(含权重):考虑用
onnx.save_model(m, ..., save_as_external_data=True)
把权重分离存储,便于 CDN / 分发。
登录后参与评论。