跑 ML 实验最容易混乱的是"我那个 lr=0.001 + dropout=0.3 的 run 是哪天哪份代码?"
wandb 自动记录:超参、loss / metric 曲线、代码 git hash、系统资源、模型权重。
免费层个人项目无限。
安装 + 注册
uv add wandb
wandb login # 浏览器打开,复制 API key 进来
# 或:export WANDB_API_KEY=...
最小集成(5 行代码)
import wandb
wandb.init(
project='mnist-cnn',
config={
'lr': 0.001,
'batch_size': 128,
'epochs': 5,
'dropout': 0.25,
'model': 'cnn-v1',
},
)
# 训练循环里
for epoch in range(5):
train_loss, train_acc = train_epoch()
val_loss, val_acc = evaluate()
wandb.log({
'train/loss': train_loss,
'train/acc': train_acc,
'val/loss': val_loss,
'val/acc': val_acc,
'epoch': epoch,
})
wandb.finish()
跑一下 python train.py,wandb 打印一个 URL,打开就能看到实时曲线。
config 优先级
# 1. 默认在代码里
wandb.init(config={'lr': 0.001})
# 2. 命令行覆盖(用 argparse / typer / hydra)
import argparse
p = argparse.ArgumentParser()
p.add_argument('--lr', type=float, default=0.001)
args = p.parse_args()
wandb.init(config=vars(args))
# 3. Sweep 时由 wandb 注入
config = wandb.config # 读,不能写
lr = config.lr
Sweep:自动超参搜索
# sweep.yaml
program: train.py
method: bayes
metric:
name: val/acc
goal: maximize
parameters:
lr:
distribution: log_uniform_values
min: 1e-5
max: 1e-2
dropout:
values: [0.1, 0.25, 0.5]
batch_size:
values: [64, 128, 256]
wandb sweep sweep.yaml
# Output: wandb agent yourname/project/sweep_id
wandb agent yourname/project/sweep_id --count 20
跑 20 个实验,自动按 Bayesian / grid / random 选超参。
多台机器并行:每台跑 wandb agent ... 共享同一个 sweep。
记录媒体
# 图片
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(losses)
wandb.log({'curve': wandb.Image(fig)})
# 表格
wandb.log({
'predictions': wandb.Table(
columns=['image', 'pred', 'truth'],
data=[[wandb.Image(x), p, y] for x, p, y in samples],
)
})
# 直方图
wandb.log({'grad_norm': wandb.Histogram(grad_norms)})
# 视频 / 音频
wandb.log({'video': wandb.Video('output.mp4')})
保存模型(Artifact)
artifact = wandb.Artifact(name='mnist-cnn', type='model')
artifact.add_file('checkpoints/best.pt')
wandb.log_artifact(artifact)
模型权重 + 元数据存在 wandb 服务端(免费层有配额)。后续加载:
api = wandb.Api()
artifact = api.artifact('yourname/project/mnist-cnn:latest')
artifact.download(root='./model')
监控系统资源
wandb 自动记录:
- GPU 利用率 / 显存
- CPU / 内存
- 磁盘 / 网络
- Python 进程
无需任何代码,看 dashboard 的 "System" 标签页。
代码 + git 状态
wandb 自动 capture:
- 当前 git commit hash
- 未提交的 diff(不要在没干净 commit 时跑实验!)
- 命令行参数
- Python 版本 + 包列表
回放某次 run:知道用了哪份代码 + 哪个数据。
离线模式
无网时:
wandb offline
# 或:export WANDB_MODE=offline
python train.py
# 数据存在本地 ./wandb/
# 之后有网时
wandb sync ./wandb/offline-run-*
CI / 内网集群里非常有用。
与 PyTorch Lightning / HuggingFace 集成
# PyTorch Lightning
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project='myproj')
trainer = Trainer(logger=logger)
# HuggingFace Transformers
training_args = TrainingArguments(
report_to='wandb',
run_name='bert-finetune-v3',
...
)
不用手写 log,框架自动调 wandb。
团队协作
- Project 是大目录(按代号 / 任务)
- Run 是单次实验
- Group 让你把同一个 sweep / 同一组对比放一起
- Tags 给 run 打标签(baseline / experiment / final)
wandb.init(
project='myproj',
group='ablation-dropout',
tags=['final', 'cnn-v2'],
notes='increased lr to 5e-3 to test convergence speed',
)
数据 dashboard
dashboard 默认按时间线显示。常用功能:
- Reports:把多个 run 拖到一份"报告"里,加文字 + 自动同步图表,
当作"项目周报"或 paper 草稿 - Parallel Coordinates:可视化超参 → metric 的关系,
找哪个超参影响最大 - Compare runs:勾几个 run 一起看曲线 / config diff
隐私 / 自托管
wandb cloud 免费个人无限,团队收费。如果数据敏感不能上 cloud:
# 自托管 wandb server(社区版免费)
docker run -p 8080:8080 wandb/local
Python 端:
wandb.init(project='...', host='https://wandb.your-company.com')
替代方案
- MLflow:开源,自托管简单。tracking + model registry,UI 朴素
- TensorBoard:本地用,无云端
- Comet / Neptune:商业产品类似
- Aim:开源极简,无云
个人项目 wandb 最快;公司里数据合规要求高用 MLflow。
踩过的坑
- 忘
wandb.finish()—— 长 run 退出后 wandb sync 一直挂着。
脚本最后必须finish()或者用with wandb.init(...) as run:。 - 每次
wandb.log都立刻发到云端 → 训练时大量请求可能拖慢。
设commit=False累积后批量发:wandb.log({...}, commit=False),
最后wandb.log({}, commit=True)。 - 在 jupyter 里 wandb.init 多次:会创建多个 run。每次重启 kernel
之前先wandb.finish()。 - artifact 配额:免费个人 100GB,超了不能再传。定期清老 artifact。
登录后参与评论。