训练一个 MNIST 分类器(PyTorch from scratch,60 行)

MNIST 手写数字识别是 ML 的 "hello world"。下面用 PyTorch 写一个从
数据加载到训练 / 评估的完整 pipeline。60 行可读代码,能跑到 99%+ 准确率。

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using {device}')

# 1. 数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=512, shuffle=False, num_workers=4)

# 2. 模型
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.drop = nn.Dropout(0.25)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))   # 28x28 -> 14x14
        x = F.relu(F.max_pool2d(self.conv2(x), 2))   # 14x14 -> 7x7
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        return self.fc2(x)

model = Net().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# 3. 训练
def train_epoch(epoch):
    model.train()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)
    print(f'epoch {epoch}  train  loss {total_loss/total:.4f}  acc {correct/total:.4f}')

# 4. 评估
def evaluate():
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    print(f'test acc {correct/total:.4f}')

for ep in range(5):
    train_epoch(ep)
    evaluate()

torch.save(model.state_dict(), 'mnist.pt')

5 个 epoch 后测试集准确率应 ≥ 99%。GPU 上每 epoch < 10 秒;
CPU 大约 30-60 秒。

解释几个关键点

Normalize((0.1307,), (0.3081,))

MNIST 训练集统计出的均值和标准差。归一化让输入分布更均匀,
训练更稳。

num_workers=4

DataLoader 用 4 个子进程并行 prefetch 数据。GPU 训练时 IO 是瓶颈,
设大点(4-8)能让 GPU 持续吃满。CPU 训练时设 0 反而更快(避免进程切换)。

opt.zero_grad()

PyTorch 默认 gradient 累加。每次 backward 前清零。
忘了清零 → loss 一直涨。

with torch.no_grad():

评估时不需要梯度,省内存 + 快。等价的还有装饰器 @torch.no_grad()
torch.inference_mode()(更激进,PyTorch 1.9+)。

.to(device)

每个 batch 显式搬到 GPU;model 一次性 .to(device)

加几个常见优化

1. 学习率调度

sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)
# 每个 epoch 结束后调用
sched.step()

cosine annealing 在前期保持高 lr 探索,后期降低做精调。

2. 混合精度(AMP)

scaler = torch.cuda.amp.GradScaler()

for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        logits = model(x)
        loss = loss_fn(logits, y)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

混合精度让大部分计算用 fp16,需要精度的(loss / 权重更新)用 fp32。
RTX 30 系以上 ~2x 加速,显存降一半。MNIST 这种小模型差别不大,
但对 transformer / 大网络明显。

3. 用 torch.compile()(PyTorch 2.0+)

model = torch.compile(model)

编译期内联 + 算子融合,训练 / 推理 1.3-2x 加速。第一次 batch 慢
(编译),之后变快。

4. 早停

best_acc = 0
patience = 3
no_improve = 0
for ep in range(50):
    train_epoch(ep)
    acc = evaluate()
    if acc > best_acc:
        best_acc = acc
        no_improve = 0
        torch.save(model.state_dict(), 'best.pt')
    else:
        no_improve += 1
        if no_improve >= patience:
            print('early stop')
            break

推理

model = Net().to(device)
model.load_state_dict(torch.load('mnist.pt'))
model.eval()

# 单张图片
from PIL import Image
img = Image.open('test.png').convert('L').resize((28, 28))
x = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    pred = model(x).argmax(1).item()
print(f'predicted: {pred}')

可视化训练曲线

import matplotlib.pyplot as plt
losses = []
accs = []
# ... 在 train_epoch / evaluate 里 append

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].plot(losses); ax[0].set_title('train loss')
ax[1].plot(accs); ax[1].set_title('test acc')
plt.savefig('curves.png')

或者用 tensorboard / wandb(参见专题文章)。

踩过的坑

  • 第一次跑会下载 MNIST 数据(~10MB),需要网络。download=True
    你这个能力。
  • forward 里 x.view(-1, 784) vs x.flatten(1):view 要求内存连续,
    flatten 不要求。前者偶尔报 "view size is not compatible..." 错。
  • 不调用 model.train() / model.eval() → Dropout / BatchNorm 行为错。
    eval 期间 Dropout 应该关(输出固定);BatchNorm 应该用 running stats。
  • torch.cuda.empty_cache() 是给 PyTorch 内部 caching allocator 用的,
    通常没必要手动调;调了也不会真的还给 OS。担心 OOM 应该减小 batch。
精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。