MNIST 手写数字识别是 ML 的 "hello world"。下面用 PyTorch 写一个从
数据加载到训练 / 评估的完整 pipeline。60 行可读代码,能跑到 99%+ 准确率。
完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using {device}')
# 1. 数据
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=512, shuffle=False, num_workers=4)
# 2. 模型
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.drop = nn.Dropout(0.25)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2)) # 28x28 -> 14x14
x = F.relu(F.max_pool2d(self.conv2(x), 2)) # 14x14 -> 7x7
x = x.flatten(1)
x = F.relu(self.fc1(x))
x = self.drop(x)
return self.fc2(x)
model = Net().to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# 3. 训练
def train_epoch(epoch):
model.train()
total, correct, total_loss = 0, 0, 0.0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
opt.zero_grad()
logits = model(x)
loss = loss_fn(logits, y)
loss.backward()
opt.step()
total_loss += loss.item() * x.size(0)
correct += (logits.argmax(1) == y).sum().item()
total += x.size(0)
print(f'epoch {epoch} train loss {total_loss/total:.4f} acc {correct/total:.4f}')
# 4. 评估
def evaluate():
model.eval()
total, correct = 0, 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.to(device), y.to(device)
correct += (model(x).argmax(1) == y).sum().item()
total += x.size(0)
print(f'test acc {correct/total:.4f}')
for ep in range(5):
train_epoch(ep)
evaluate()
torch.save(model.state_dict(), 'mnist.pt')
5 个 epoch 后测试集准确率应 ≥ 99%。GPU 上每 epoch < 10 秒;
CPU 大约 30-60 秒。
解释几个关键点
Normalize((0.1307,), (0.3081,))
MNIST 训练集统计出的均值和标准差。归一化让输入分布更均匀,
训练更稳。
num_workers=4
DataLoader 用 4 个子进程并行 prefetch 数据。GPU 训练时 IO 是瓶颈,
设大点(4-8)能让 GPU 持续吃满。CPU 训练时设 0 反而更快(避免进程切换)。
opt.zero_grad()
PyTorch 默认 gradient 累加。每次 backward 前清零。
忘了清零 → loss 一直涨。
with torch.no_grad():
评估时不需要梯度,省内存 + 快。等价的还有装饰器 @torch.no_grad()
或 torch.inference_mode()(更激进,PyTorch 1.9+)。
.to(device)
每个 batch 显式搬到 GPU;model 一次性 .to(device)。
加几个常见优化
1. 学习率调度
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)
# 每个 epoch 结束后调用
sched.step()
cosine annealing 在前期保持高 lr 探索,后期降低做精调。
2. 混合精度(AMP)
scaler = torch.cuda.amp.GradScaler()
for x, y in train_loader:
x, y = x.to(device), y.to(device)
opt.zero_grad()
with torch.cuda.amp.autocast():
logits = model(x)
loss = loss_fn(logits, y)
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
混合精度让大部分计算用 fp16,需要精度的(loss / 权重更新)用 fp32。
RTX 30 系以上 ~2x 加速,显存降一半。MNIST 这种小模型差别不大,
但对 transformer / 大网络明显。
3. 用 torch.compile()(PyTorch 2.0+)
model = torch.compile(model)
编译期内联 + 算子融合,训练 / 推理 1.3-2x 加速。第一次 batch 慢
(编译),之后变快。
4. 早停
best_acc = 0
patience = 3
no_improve = 0
for ep in range(50):
train_epoch(ep)
acc = evaluate()
if acc > best_acc:
best_acc = acc
no_improve = 0
torch.save(model.state_dict(), 'best.pt')
else:
no_improve += 1
if no_improve >= patience:
print('early stop')
break
推理
model = Net().to(device)
model.load_state_dict(torch.load('mnist.pt'))
model.eval()
# 单张图片
from PIL import Image
img = Image.open('test.png').convert('L').resize((28, 28))
x = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
pred = model(x).argmax(1).item()
print(f'predicted: {pred}')
可视化训练曲线
import matplotlib.pyplot as plt
losses = []
accs = []
# ... 在 train_epoch / evaluate 里 append
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].plot(losses); ax[0].set_title('train loss')
ax[1].plot(accs); ax[1].set_title('test acc')
plt.savefig('curves.png')
或者用 tensorboard / wandb(参见专题文章)。
踩过的坑
- 第一次跑会下载 MNIST 数据(~10MB),需要网络。
download=True给
你这个能力。 - forward 里
x.view(-1, 784)vsx.flatten(1):view 要求内存连续,
flatten 不要求。前者偶尔报 "view size is not compatible..." 错。 - 不调用
model.train()/model.eval()→ Dropout / BatchNorm 行为错。
eval 期间 Dropout 应该关(输出固定);BatchNorm 应该用 running stats。 torch.cuda.empty_cache()是给 PyTorch 内部 caching allocator 用的,
通常没必要手动调;调了也不会真的还给 OS。担心 OOM 应该减小 batch。
登录后参与评论。