用 Weights & Biases (wandb) 跟踪 ML 实验(替代手抄表格)

跑 ML 实验最容易混乱的是"我那个 lr=0.001 + dropout=0.3 的 run 是哪天哪份代码?"
wandb 自动记录:超参、loss / metric 曲线、代码 git hash、系统资源、模型权重。
免费层个人项目无限。

安装 + 注册

uv add wandb
wandb login   # 浏览器打开,复制 API key 进来
# 或:export WANDB_API_KEY=...

最小集成(5 行代码)

import wandb

wandb.init(
    project='mnist-cnn',
    config={
        'lr': 0.001,
        'batch_size': 128,
        'epochs': 5,
        'dropout': 0.25,
        'model': 'cnn-v1',
    },
)

# 训练循环里
for epoch in range(5):
    train_loss, train_acc = train_epoch()
    val_loss, val_acc = evaluate()
    wandb.log({
        'train/loss': train_loss,
        'train/acc': train_acc,
        'val/loss': val_loss,
        'val/acc': val_acc,
        'epoch': epoch,
    })

wandb.finish()

跑一下 python train.py,wandb 打印一个 URL,打开就能看到实时曲线。

config 优先级

# 1. 默认在代码里
wandb.init(config={'lr': 0.001})

# 2. 命令行覆盖(用 argparse / typer / hydra)
import argparse
p = argparse.ArgumentParser()
p.add_argument('--lr', type=float, default=0.001)
args = p.parse_args()
wandb.init(config=vars(args))

# 3. Sweep 时由 wandb 注入
config = wandb.config   # 读,不能写
lr = config.lr

Sweep:自动超参搜索

# sweep.yaml
program: train.py
method: bayes
metric:
  name: val/acc
  goal: maximize
parameters:
  lr:
    distribution: log_uniform_values
    min: 1e-5
    max: 1e-2
  dropout:
    values: [0.1, 0.25, 0.5]
  batch_size:
    values: [64, 128, 256]
wandb sweep sweep.yaml
# Output: wandb agent yourname/project/sweep_id
wandb agent yourname/project/sweep_id --count 20

跑 20 个实验,自动按 Bayesian / grid / random 选超参。
多台机器并行:每台跑 wandb agent ... 共享同一个 sweep。

记录媒体

# 图片
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(losses)
wandb.log({'curve': wandb.Image(fig)})

# 表格
wandb.log({
    'predictions': wandb.Table(
        columns=['image', 'pred', 'truth'],
        data=[[wandb.Image(x), p, y] for x, p, y in samples],
    )
})

# 直方图
wandb.log({'grad_norm': wandb.Histogram(grad_norms)})

# 视频 / 音频
wandb.log({'video': wandb.Video('output.mp4')})

保存模型(Artifact)

artifact = wandb.Artifact(name='mnist-cnn', type='model')
artifact.add_file('checkpoints/best.pt')
wandb.log_artifact(artifact)

模型权重 + 元数据存在 wandb 服务端(免费层有配额)。后续加载:

api = wandb.Api()
artifact = api.artifact('yourname/project/mnist-cnn:latest')
artifact.download(root='./model')

监控系统资源

wandb 自动记录:

  • GPU 利用率 / 显存
  • CPU / 内存
  • 磁盘 / 网络
  • Python 进程

无需任何代码,看 dashboard 的 "System" 标签页。

代码 + git 状态

wandb 自动 capture:

  • 当前 git commit hash
  • 未提交的 diff(不要在没干净 commit 时跑实验!)
  • 命令行参数
  • Python 版本 + 包列表

回放某次 run:知道用了哪份代码 + 哪个数据。

离线模式

无网时:

wandb offline
# 或:export WANDB_MODE=offline
python train.py
# 数据存在本地 ./wandb/

# 之后有网时
wandb sync ./wandb/offline-run-*

CI / 内网集群里非常有用。

与 PyTorch Lightning / HuggingFace 集成

# PyTorch Lightning
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project='myproj')
trainer = Trainer(logger=logger)

# HuggingFace Transformers
training_args = TrainingArguments(
    report_to='wandb',
    run_name='bert-finetune-v3',
    ...
)

不用手写 log,框架自动调 wandb。

团队协作

  • Project 是大目录(按代号 / 任务)
  • Run 是单次实验
  • Group 让你把同一个 sweep / 同一组对比放一起
  • Tags 给 run 打标签(baseline / experiment / final)
wandb.init(
    project='myproj',
    group='ablation-dropout',
    tags=['final', 'cnn-v2'],
    notes='increased lr to 5e-3 to test convergence speed',
)

数据 dashboard

dashboard 默认按时间线显示。常用功能:

  • Reports:把多个 run 拖到一份"报告"里,加文字 + 自动同步图表,
    当作"项目周报"或 paper 草稿
  • Parallel Coordinates:可视化超参 → metric 的关系,
    找哪个超参影响最大
  • Compare runs:勾几个 run 一起看曲线 / config diff

隐私 / 自托管

wandb cloud 免费个人无限,团队收费。如果数据敏感不能上 cloud:

# 自托管 wandb server(社区版免费)
docker run -p 8080:8080 wandb/local

Python 端:

wandb.init(project='...', host='https://wandb.your-company.com')

替代方案

  • MLflow:开源,自托管简单。tracking + model registry,UI 朴素
  • TensorBoard:本地用,无云端
  • Comet / Neptune:商业产品类似
  • Aim:开源极简,无云

个人项目 wandb 最快;公司里数据合规要求高用 MLflow。

踩过的坑

  • wandb.finish() —— 长 run 退出后 wandb sync 一直挂着。
    脚本最后必须 finish() 或者用 with wandb.init(...) as run:
  • 每次 wandb.log 都立刻发到云端 → 训练时大量请求可能拖慢。
    commit=False 累积后批量发:wandb.log({...}, commit=False)
    最后 wandb.log({}, commit=True)
  • 在 jupyter 里 wandb.init 多次:会创建多个 run。每次重启 kernel
    之前先 wandb.finish()
  • artifact 配额:免费个人 100GB,超了不能再传。定期清老 artifact。
精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。