PyTorch Lightning 把训练循环写成 50 行(多卡 / 混合精度 / checkpoint 全免费)

起因

裸 PyTorch 写一个像样的训练循环要处理:device 切换 / .train()/.eval() /
gradient zero / loss accumulation / lr scheduler step / checkpoint 保存
/ early stopping / 多 GPU DDP 启动 / 混合精度 / logging。
一个研究 notebook 反复抄这些代码很烦,还容易写错(忘 optimizer.zero_grad()
或者 .eval() 是经典的)。

PyTorch Lightning 把"工程脚手架"和"模型逻辑"分开:你只写 training_step
/ validation_step / configure_optimizers,其它由 framework 处理。

解决方案

uv add lightning torchvision
import lightning as L
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNIST(L.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, padding=1), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(32, 64, 3, padding=1), torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Flatten(),
            torch.nn.Linear(64*7*7, 128), torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.conv(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log_dict({'train/loss': loss, 'train/acc': acc}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log_dict({'val/loss': loss, 'val/acc': acc}, prog_bar=True)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sched]


def main():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    train = DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),
                       batch_size=128, num_workers=4, shuffle=True)
    val = DataLoader(datasets.MNIST('./data', train=False, transform=transform),
                     batch_size=512, num_workers=4)

    trainer = L.Trainer(
        max_epochs=5,
        accelerator='auto',           # cuda / mps / cpu 自动选
        devices='auto',                # 多卡时自动用全部
        precision='16-mixed',          # 混合精度
        callbacks=[
            L.pytorch.callbacks.EarlyStopping(monitor='val/loss', patience=3),
            L.pytorch.callbacks.ModelCheckpoint(monitor='val/acc', mode='max',
                                                 save_top_k=2),
        ],
        logger=L.pytorch.loggers.TensorBoardLogger('logs', name='mnist'),
    )
    trainer.fit(MNIST(), train, val)


if __name__ == '__main__':
    main()

50 行包含训练 + 验证 + 多卡 + 混合精度 + 早停 + checkpoint + TensorBoard。

效果

  • 单 GPU vs 4 GPU DDP 只改 devices=4,无需写 init_process_group 之类
  • 混合精度只改 precision='16-mixed',速度 ~1.5-2x,显存减 30%
  • TensorBoard tensorboard --logdir logs 看 loss / metric 曲线
  • ModelCheckpoint 自动保存 val/acc 最高的 2 个 checkpoint
  • 中断后 Trainer(resume_from_checkpoint=...) 恢复
  • 切换 wandb logger 一行:L.pytorch.loggers.WandbLogger(project=...)

踩过的坑

  1. self.log 不能放在 forward:只能在 step 方法里。否则
    batch_size 信息不对。

  2. DDP 时 validation_step 写了 print → 多进程刷屏。用
    self.logrank_zero_only 装饰。

  3. num_workers > 0 时 macOS / Windows 死锁:DataLoader 用 fork
    策略。Mac 上设 num_workers=0persistent_workers=True

  4. 混合精度 NaN:loss scale 不对。precision='bf16-mixed'(A100+)
    16-mixed 更稳,不需要 grad scaler。

  5. 保存的 checkpoint 太大:默认保存 optimizer state。要 inference-only:
    ModelCheckpoint(save_weights_only=True)

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。