PyTorch 模型 INT8 量化:模型小 4x、推理快 2-4x、精度损失 < 1%

起因

要把一个 30M 参数的 ResNet 部署到 ARM 手机上。FP32 模型 120 MB +
推理慢得卡顿。INT8 量化把模型缩到 30 MB + 推理快 3 倍,精度只掉 0.5%。
深度学习模型 INT8 量化已成熟,几行代码搞定。

三种量化策略

A. Dynamic quantization:仅 weight 量化,最简单

import torch
from torchvision.models import resnet50

model = resnet50(pretrained=True).eval()

# 一行:把所有 Linear 层 weight 量化到 INT8
quantized = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear, torch.nn.LSTM, torch.nn.RNN},
    dtype=torch.qint8,
)

torch.save(quantized.state_dict(), 'resnet50_int8.pt')

适合:BERT / transformer / RNN 类(Linear 主导)。
不适合:CNN(Conv2d 占比大,dynamic 不量化它)。

B. Static quantization:weight + activation 全量化

需要 calibration(用代表性数据跑一遍找 activation 范围):

import torch
import torch.ao.quantization as Q

model = MyModel().eval()

# 1. 准备:插入 observer 收 activation 统计
model.qconfig = Q.get_default_qconfig('fbgemm')   # x86; 'qnnpack' for ARM
model_prepared = Q.prepare(model, inplace=False)

# 2. Calibration: 跑 ~100-1000 张代表性图片
with torch.no_grad():
    for img in calibration_loader:
        model_prepared(img)

# 3. Convert: observer → 实际 quant op
model_int8 = Q.convert(model_prepared, inplace=False)

torch.save(model_int8.state_dict(), 'resnet50_static_int8.pt')

效果通常 比 dynamic 更激进,全模型 INT8
代价:要 calibration data + 模型架构必须支持(含 Conv-BN fuse 等)。

C. Quantization-aware training (QAT):训练时模拟量化

精度损失最小(< 0.5%)但要重训:

model.qconfig = Q.get_default_qat_qconfig('fbgemm')
model_prepared = Q.prepare_qat(model, inplace=False)

# 训练(模型在前向时模拟 INT8 round 噪声)
for epoch in range(5):
    for x, y in train_loader:
        loss = criterion(model_prepared(x), y)
        loss.backward()
        optimizer.step()

model_int8 = Q.convert(model_prepared.eval())

QAT 适合:精度对生产关键的模型,可承受 5-20 epoch 重训。

实测对比(ResNet50 + ImageNet val)

大小 CPU 推理(ms/img) top-1 acc
FP32 98 MB 65 76.13%
Dynamic INT8 ~95 MB 65 76.13% (Linear 没主导)
Static INT8 25 MB 28 75.84%
QAT INT8 25 MB 28 76.02%

CNN 类 static / QAT 显著有效。BERT 类 dynamic 也能 4x 小 + 2-3x 快。

ONNX Runtime + INT8(生产推荐)

PyTorch 量化导出 ONNX 后用 ONNX Runtime 跑,性能 / 跨平台都更好:

import torch
from torch.ao.quantization import quantize_dynamic

q_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(q_model, dummy, 'model_int8.onnx', opset_version=13)

或者直接 ONNX Runtime 的量化工具(更稳):

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input='model_fp32.onnx',
    model_output='model_int8.onnx',
    weight_type=QuantType.QInt8,
)

ONNX Runtime 在 ARM / x86 / Apple Silicon 都有 INT8 优化 kernel。

bitsandbytes:LLM 用 4-bit / 8-bit quantization

uv add bitsandbytes accelerate
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-3.1-70B-Instruct',
    quantization_config=bnb_config,
    device_map='auto',
)

70B 模型从 140 GB → 35 GB。单 A100 80GB 或 4090 + offload 跑得动。

精度损失:相对 FP16 通常 < 1% benchmark(NF4 比 INT4 更稳)。

GPTQ / AWQ:post-training quantization for LLM

针对 LLM 优化的 4-bit 量化算法(比 bitsandbytes NF4 更好):

# GPTQ
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
model = AutoGPTQForCausalLM.from_pretrained(
    'meta-llama/Llama-3.1-8B',
    quantize_config=quantize_config,
)
model.quantize(calibration_dataset)
model.save_quantized('llama-3.1-8b-4bit-gptq')
# AWQ
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_pretrained('llama-3.1-8b')
model.quantize(tokenizer, quant_config={...})

社区已有大量预量化的 GPTQ / AWQ model 在 HuggingFace(搜
-GPTQ-4bit / -AWQ 后缀)。直接下载用,省得自己量化。

部署侧:vLLM / llama.cpp 用量化模型

# vLLM
vllm serve TheBloke/Llama-3.1-70B-AWQ --quantization awq

# llama.cpp(CPU / Apple Silicon Metal 极快)
llama-cli -m llama-3.1-8b.Q4_K_M.gguf -p 'hello'

llama.cpp GGUF 格式包含量化(Q4_K_M / Q5_K_M / Q8_0 等),
Mac M 系列上 8B 模型 30+ tokens/s。

效果

我们的几个生产模型量化后:

模型 之前 量化后 精度损失
ResNet50 (移动 app) 98 MB / 65ms 25 MB / 22ms -0.3%
BERT-base (后台) 440 MB / 80ms 110 MB / 30ms -0.5%
Llama 7B (RAG) 14 GB / 100 token/s 4 GB / 230 token/s < 1%

移动 / 边缘 / CPU 推理场景量化几乎是必做。

几个陷阱

  1. 量化前 fuse 模块
    python torch.ao.quantization.fuse_modules( model, [['conv', 'bn', 'relu']], inplace=True)
    Conv-BN-ReLU 合成一个 op 后量化效果更好。漏 fuse 精度可能掉 2-5%。

  2. observer 范围错:calibration 数据不代表 production → activation
    范围估计错 → 量化 clip 严重。calibration 一定用真实分布数据。

  3. 某些 layer 不能量化:softmax / layernorm 等保留 FP32。
    model.qconfig = ... 全局设后,对这些 layer 显式 qconfig=None

  4. 不同硬件 backendfbgemm 是 x86 优化,qnnpack 是 ARM 优化。
    部署目标错了性能差 2-5 倍。

  5. 量化后调试难:bug 是模型本身的还是量化引入的?保留 FP32
    reference 模型对比每层 activation 找漂移最大的 layer。

总结

场景 推荐
BERT / transformer post-hoc dynamic INT8
CNN 上 ARM / edge static INT8 + QAT
LLM 推理 bitsandbytes NF4 / AWQ / GGUF
跨平台部署 ONNX Runtime + INT8
极致精度要求 QAT
不想自己折腾 用社区预量化模型

量化是 ML 生产工程的标准动作,不做白扔 70% 推理性能。

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。