PyTorch 训练 OOM 排查:activation checkpoint / 梯度累积 / offload

起因

要 fine-tune 一个 7B 模型,A100 40GB 显存,跑起来直接 CUDA OOM。
"换大卡"是简单解决但贵。理解几个技术能在同样显存里训更大模型 /
更大 batch。

各项的显存占用拆解

训练时显存 ≈ 模型权重 + 梯度 + optimizer state + activations + 临时
buffer。以 7B FP16 模型 + AdamW 为例:

公式 7B 模型
权重 params × 2 bytes (fp16) 14 GB
梯度 params × 2 bytes 14 GB
optimizer state(AdamW) params × 8 bytes (FP32 m+v) 56 GB
activations 依 batch / seq 几 GB-几十 GB

总 = 84 GB + activations。一张 A100 40GB 远不够。

解决方案逐个上

1. 混合精度(FP16/BF16)— 必选

# pure PyTorch
scaler = torch.cuda.amp.GradScaler()
for batch in loader:
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = model(batch).loss
    scaler.scale(loss).backward()
    scaler.step(opt); scaler.update()

权重 / 梯度从 FP32 4 bytes → FP16/BF16 2 bytes,对半省。
A100+ 推荐 BF16(无需 grad scaler,数值更稳)。

2. Gradient Checkpointing — 用计算换显存

normal 前向把所有 activations 都存着(反向用)。checkpointing 只保存
某几层,其它 layer 反向时重新算前向:

model.gradient_checkpointing_enable()   # transformers 一行

省 activations 50-80%,代价是训练慢 ~20-30%。LLM fine-tune 默认开。

3. Gradient Accumulation — 模拟更大 batch

显存装不下 batch=32?跑 batch=8 累 4 次 = batch=32 等效:

accum_steps = 4
for i, batch in enumerate(loader):
    loss = model(batch).loss / accum_steps
    loss.backward()
    if (i + 1) % accum_steps == 0:
        opt.step(); opt.zero_grad()

显存等同 batch=8,效果近似 batch=32。

4. CPU offload(DeepSpeed / accelerate)

把 optimizer state 卸到 CPU 内存,反正它不参与每步前反向:

from accelerate import Accelerator

acc = Accelerator(
    mixed_precision='bf16',
    gradient_accumulation_steps=4,
)
model, opt, loader = acc.prepare(model, opt, loader)

或用 DeepSpeed ZeRO-2 / ZeRO-3:

# accelerate config 选 DeepSpeed
# 跑:
accelerate launch --num_processes=1 \
  --mixed_precision=bf16 \
  --deepspeed_stage=2 \
  train.py

ZeRO-2 把 optimizer state 分片(多卡时)/ offload 到 CPU(单卡时),
省 56 GB → 0 GB(cpu 接管)。代价:每 step 数据传输延迟。

5. LoRA / QLoRA — 只训一小部分参数

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=['q_proj', 'v_proj'],
    task_type='CAUSAL_LM',
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
# trainable: 4.2M / 7B = 0.06%

只有 LoRA 的小矩阵需要梯度 + optimizer state。7B 模型变成"7B 冻结
+ 4M 可训",显存暴跌。

QLoRA 进一步把 base model 也量化到 4-bit:

from transformers import BitsAndBytesConfig

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(name, quantization_config=bnb)
model = get_peft_model(model, config)

7B QLoRA 单 A100 40GB 训 batch=4 可以跑得动。

效果(我的 case:A100 40GB fine-tune Qwen2 7B)

配置 显存 训练速度 效果损失
FP32 full ft OOM
BF16 full ft OOM (~80 GB)
BF16 + grad checkpoint OOM (~50 GB)
BF16 + checkpoint + ZeRO-2 cpu offload 32 GB 1x 0
BF16 + LoRA 24 GB 1.3x 微小
BF16 + QLoRA 14 GB 1.2x 1-2%

最终 QLoRA 跑通 fine-tune,loss 收敛、benchmark 比 base 提升 8%。

调试技巧

# 看每层显存
print(torch.cuda.memory_summary())

# 最大峰值
print(f'peak: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB')
torch.cuda.reset_peak_memory_stats()

# 实时监控
nvidia-smi -l 1
# 或更细:
nvtop

跑 OOM 时立刻 nvidia-smi 看到底是 model load 时挂了还是 forward 时挂了,
对症下药。

踩过的坑

  1. del var 不立刻释放:PyTorch caching allocator 不还给 OS。
    torch.cuda.empty_cache() 也只是把 cached block 让出来,不会
    实际减少 OS 看到的进程显存。

  2. DataLoader pin_memory + num_workers 大:每个 worker 一份 GPU
    显存映射。OOM 时先减 num_workers

  3. eval 不开 no_grad:评估时没 with torch.no_grad():,accidentally
    build 完整 computation graph,显存翻倍。

  4. 多个模型同时 load:base model + LoRA + reward model 一起在 GPU
    上时,DPO / RLHF 训练显存压力极大。把 reward model 量化或 freeze
    后丢 CPU。

  5. 使用 compute_dtype=torch.float16 + Adam:fp16 + Adam 数值
    不稳定。一律 bf16 或者 fp32 master weight(mixed precision)。

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。