起因
裸 PyTorch 写一个像样的训练循环要处理:device 切换 / .train()/.eval() /
gradient zero / loss accumulation / lr scheduler step / checkpoint 保存
/ early stopping / 多 GPU DDP 启动 / 混合精度 / logging。
一个研究 notebook 反复抄这些代码很烦,还容易写错(忘 optimizer.zero_grad()
或者 .eval() 是经典的)。
PyTorch Lightning 把"工程脚手架"和"模型逻辑"分开:你只写 training_step
/ validation_step / configure_optimizers,其它由 framework 处理。
解决方案
uv add lightning torchvision
import lightning as L
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class MNIST(L.LightningModule):
def __init__(self, lr=1e-3):
super().__init__()
self.save_hyperparameters()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, padding=1), torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Conv2d(32, 64, 3, padding=1), torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Flatten(),
torch.nn.Linear(64*7*7, 128), torch.nn.ReLU(),
torch.nn.Linear(128, 10),
)
def forward(self, x):
return self.conv(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(1) == y).float().mean()
self.log_dict({'train/loss': loss, 'train/acc': acc}, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(1) == y).float().mean()
self.log_dict({'val/loss': loss, 'val/acc': acc}, prog_bar=True)
def configure_optimizers(self):
opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
return [opt], [sched]
def main():
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
train = DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),
batch_size=128, num_workers=4, shuffle=True)
val = DataLoader(datasets.MNIST('./data', train=False, transform=transform),
batch_size=512, num_workers=4)
trainer = L.Trainer(
max_epochs=5,
accelerator='auto', # cuda / mps / cpu 自动选
devices='auto', # 多卡时自动用全部
precision='16-mixed', # 混合精度
callbacks=[
L.pytorch.callbacks.EarlyStopping(monitor='val/loss', patience=3),
L.pytorch.callbacks.ModelCheckpoint(monitor='val/acc', mode='max',
save_top_k=2),
],
logger=L.pytorch.loggers.TensorBoardLogger('logs', name='mnist'),
)
trainer.fit(MNIST(), train, val)
if __name__ == '__main__':
main()
50 行包含训练 + 验证 + 多卡 + 混合精度 + 早停 + checkpoint + TensorBoard。
效果
- 单 GPU vs 4 GPU DDP 只改
devices=4,无需写 init_process_group 之类 - 混合精度只改
precision='16-mixed',速度 ~1.5-2x,显存减 30% - TensorBoard
tensorboard --logdir logs看 loss / metric 曲线 - ModelCheckpoint 自动保存 val/acc 最高的 2 个 checkpoint
- 中断后
Trainer(resume_from_checkpoint=...)恢复 - 切换 wandb logger 一行:
L.pytorch.loggers.WandbLogger(project=...)
踩过的坑
-
self.log不能放在forward里:只能在 step 方法里。否则
batch_size 信息不对。 -
DDP 时
validation_step写了 print → 多进程刷屏。用
self.log或rank_zero_only装饰。 -
num_workers > 0时 macOS / Windows 死锁:DataLoader 用 fork
策略。Mac 上设num_workers=0或persistent_workers=True。 -
混合精度 NaN:loss scale 不对。
precision='bf16-mixed'(A100+)
比16-mixed更稳,不需要 grad scaler。 -
保存的 checkpoint 太大:默认保存 optimizer state。要 inference-only:
ModelCheckpoint(save_weights_only=True)。
登录后参与评论。