起因
训练 ML model 时:
- 改 hyperparameter / feature → 跑一次
- 比较哪次效果好?凭记忆?csv 抄结果?
- model artifact 存哪?git LFS?S3 哪个 path?
- 模型 deploy 时哪个版本?
MLflow 是 Databricks 出的 open source,解决 ML 实验管理 4 件事:
- Tracking:每次 run 记 params / metrics / artifact
- Projects:reproducible run(conda env / docker)
- Models:标准化打包 / 注册 / 部署
- Registry:model version 管理 (staging / production)
装
pip install mlflow
mlflow server --host 0.0.0.0 --port 5000 \
--backend-store-uri sqlite:///mlflow.db \
--default-artifact-root file:./mlruns
UI:http://localhost:5000。
生产用 PG + S3 替代 SQLite + local file。
追踪 run
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment('user-churn-prediction')
with mlflow.start_run(run_name='rf-baseline'):
# log hyperparameter
n_estimators = 100
max_depth = 10
mlflow.log_param('n_estimators', n_estimators)
mlflow.log_param('max_depth', max_depth)
# 训练
model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
model.fit(X_train, y_train)
# log metric
pred = model.predict(X_val)
mlflow.log_metric('val_accuracy', accuracy_score(y_val, pred))
mlflow.log_metric('val_f1', f1_score(y_val, pred))
# log model artifact
mlflow.sklearn.log_model(model, 'model')
# log 任意 artifact
mlflow.log_artifact('feature_importance.png')
UI 里看:每次 run 一行,columns 是 params + metrics,能排序 / 过滤。
auto-log
不想手动 log 每个 param?
mlflow.sklearn.autolog() # sklearn 自动 log
mlflow.pytorch.autolog() # pytorch
mlflow.xgboost.autolog()
mlflow.tensorflow.autolog()
model = LogisticRegression(C=0.1)
model.fit(X, y) # 自动 log C / penalty / accuracy / model
90% 用例 autolog 够。
比较 run
UI 选多个 run → Compare → 表格 + parallel coordinates 看 hyperparam ↔ metric。
Programmatic:
from mlflow.tracking import MlflowClient
client = MlflowClient()
exp = client.get_experiment_by_name('user-churn-prediction')
# 找最佳 run
runs = client.search_runs(
experiment_ids=[exp.experiment_id],
order_by=['metrics.val_f1 DESC'],
max_results=1,
)
best = runs[0]
print(best.data.params, best.data.metrics)
model registry
train 完想正式 deploy:
result = mlflow.register_model(
f"runs:/{best.info.run_id}/model",
"user-churn-model", # 注册名
)
# version 自动 +1
UI 里看 model registry → "user-churn-model" v1, v2, v3 ...
mark stage:
client.transition_model_version_stage(
name='user-churn-model',
version=3,
stage='Production',
)
production code 加载:
model = mlflow.pyfunc.load_model('models:/user-churn-model/Production')
pred = model.predict(X_test)
切换版本只改 stage 标记,code 不动。
hyperparameter sweep
from itertools import product
for n, d in product([100, 200, 500], [5, 10, 20]):
with mlflow.start_run():
mlflow.log_params({'n_estimators': n, 'max_depth': d})
model = RandomForestClassifier(n_estimators=n, max_depth=d)
model.fit(X_train, y_train)
score = model.score(X_val, y_val)
mlflow.log_metric('val_score', score)
9 run 自动跑 + 全部对比。
配 Optuna / Hyperopt 自动化 search 更强。
跟 git 集成
mlflow 自动 log:
- git commit hash
- branch name
- diff(未 commit 改动)
每个 run 知道是哪个 code 版本产生的 → 复现性。
与替代品对比
| MLflow | Weights & Biases | Neptune | TensorBoard | |
|---|---|---|---|---|
| 自托管 | ✅ | ❌(cloud only OSS有限) | ❌ | ✅(无 server) |
| metric tracking | ✅ | ✅+ | ✅+ | ✅ |
| model registry | ✅ | ✅ | ✅ | ❌ |
| collab | 弱 | 强 | 中 | 弱 |
| 成本 | 0 | 团队收费 | 团队收费 | 0 |
| 生态 | 大 | 大 | 中 | 大 |
我个人 / 小团队 → MLflow(自托管 + 免费 + 标准)。
大团队 / 跨公司 collab → W&B。
部署:自托管设置
# docker-compose.yml
services:
mlflow:
image: ghcr.io/mlflow/mlflow:v2.13
ports:
- 5000:5000
environment:
- AWS_ACCESS_KEY_ID=...
- AWS_SECRET_ACCESS_KEY=...
command: >
mlflow server
--host 0.0.0.0
--backend-store-uri postgresql://user:pass@db/mlflow
--default-artifact-root s3://my-bucket/mlflow
depends_on:
- db
db:
image: postgres:16
environment:
POSTGRES_USER: user
POSTGRES_PASSWORD: pass
POSTGRES_DB: mlflow
volumes:
- mlflow-pg:/var/lib/postgresql/data
PG 存 metadata + S3 存 artifact → 可 scale。
实战 lessons
我们一个客户 churn 项目用 MLflow:
- 200+ experiment run(不同 feature set / model type / hyperparam)
- 8 个最终 model 在 registry
- Production model 每月更新(registry stage transition)
- 任何时候能 diff 当前 production 跟 candidate
少了 mlflow 之前:
- experiment 结果存团队 Notion / Slack message
- model artifact 各种 S3 path 散
- 谁也不知道 production 现在跑的是 train_v3_final_FINAL2.pkl 还是 v4
接入 mlflow 后:
- experiment 透明 / 可追溯
- model 版本明确
- 切回老版本 1 行命令
跟 Airflow / Prefect 集成
ML pipeline DAG 每天跑:
@task
def train_and_register():
with mlflow.start_run():
model = train(...)
mlflow.sklearn.log_model(model, 'model')
if score > threshold:
mlflow.register_model(...)
定期 retrain → log → 满足条件 promote 到 staging → 人手批准 →
production。CD for ML。
踩过的坑
-
artifact path 写错:local file mode 测试 OK,部署 server 后
./mlruns在 server 找不到 → 必须 S3 / blob storage。 -
大 model artifact:log 大 model 几 GB 慢。考虑只 log 必要部分。
-
run 不 close:
mlflow.start_run()不在 with block + 没
end_run()→ status="RUNNING" 一直挂。with习惯保命。 -
registry 名字冲突:team 多人用同一 mlflow server,注册名建议
带 prefix 或 namespace。 -
MLflow autolog 与 keras 不全兼容:某些 callback / model
subclassing 不 capture。complex case 手动 log。
登录后参与评论。