起因
公司数据不能上 wandb / Comet 这种 cloud SaaS。要本地自托管的实验
跟踪 + 模型版本控制 + 一键部署。
MLflow 是 Databricks 出的开源套件,4 个组件覆盖 ML lifecycle:
- Tracking:记录实验(params / metrics / artifacts)
- Projects:可复现的 ML 包格式
- Models:模型版本 + 多 framework 统一接口
- Registry:模型生命周期(staging / production / archived)
装 + 启服务
uv add mlflow
# 启动 tracking server(默认 SQLite 后端 + 本地文件 artifact)
mlflow server \
--host 0.0.0.0 --port 5000 \
--backend-store-uri sqlite:///mlflow.db \
--default-artifact-root ./mlruns
或更"生产"配置(PostgreSQL + S3):
mlflow server \
--host 0.0.0.0 --port 5000 \
--backend-store-uri postgresql://user:pass@db/mlflow \
--default-artifact-root s3://my-bucket/mlruns \
--workers 4
systemd unit + nginx 套一下就是企业级服务。
Tracking:训练时记录
import mlflow
import mlflow.pytorch
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment('churn-prediction')
with mlflow.start_run(run_name='lgbm-baseline'):
mlflow.log_params({
'model': 'lgbm',
'lr': 0.05,
'n_estimators': 200,
'max_depth': 7,
})
mlflow.set_tag('dataset_version', 'v2024-05-01')
# 训练
model = train(...)
eval_metrics = evaluate(model, X_val, y_val)
mlflow.log_metrics(eval_metrics)
# {'auc': 0.84, 'precision': 0.71, 'recall': 0.66}
# 多 step:每 epoch log
for epoch in range(20):
train_one_epoch()
mlflow.log_metric('train/loss', loss, step=epoch)
mlflow.log_metric('val/auc', auc, step=epoch)
# 保存模型(mlflow 自动 capture 依赖 env)
mlflow.lightgbm.log_model(model, 'model')
# 任意 artifact
mlflow.log_artifact('confusion_matrix.png')
mlflow.log_artifact('feature_importance.csv')
跑完 → MLflow UI 看到 run,有 params / metrics 表 + 曲线 + artifacts
下载。
实验对比
UI 选多个 run → Compare → 表格 + parallel coordinates + scatter plot。
一眼看出"哪几个超参组合 auc 高"。
Model Registry
# 训练完成后注册到 registry
mlflow.lightgbm.log_model(
lgb_model=model,
artifact_path='model',
registered_model_name='ChurnPredictor',
)
UI 里看到 ChurnPredictor v1。
版本管理 + 状态机:
client = mlflow.MlflowClient()
# 升级到 staging
client.transition_model_version_stage(
name='ChurnPredictor', version=1, stage='Staging',
)
# 验证后升级到 production
client.transition_model_version_stage(
name='ChurnPredictor', version=1, stage='Production',
archive_existing_versions=True, # 老 production 自动 archive
)
线上代码总是拿 production 版本:
model = mlflow.pyfunc.load_model(
model_uri='models:/ChurnPredictor/Production'
)
prediction = model.predict(X)
回滚?把老版本 transition 回 Production 即可。
Models:统一 framework 接口
# sklearn
mlflow.sklearn.log_model(model, 'm')
# PyTorch
mlflow.pytorch.log_model(model, 'm')
# Transformers
mlflow.transformers.log_model(pipeline, 'm')
# 自定义(Pyfunc)
class MyModel(mlflow.pyfunc.PythonModel):
def load_context(self, context):
self.artifact = load(context.artifacts['my_file'])
def predict(self, context, input_df):
return my_inference(input_df, self.artifact)
mlflow.pyfunc.log_model('m', python_model=MyModel(),
artifacts={'my_file': 'data.pkl'})
加载时不需要知道 framework:
model = mlflow.pyfunc.load_model('models:/X/Production')
model.predict(df)
业务代码不再 hardcode "import torch / sklearn / xgboost"。
部署:一行 serve
mlflow models serve -m models:/ChurnPredictor/Production -p 5001
# 起一个 HTTP API on :5001
curl -X POST http://localhost:5001/invocations \
-H 'Content-Type: application/json' \
-d '{"dataframe_records": [{"age": 30, "balance": 1000}]}'
# {"predictions": [0.72]}
或 build Docker image:
mlflow models build-docker -m models:/X/Production -n my-model:latest
docker run -p 5001:8080 my-model:latest
或部署到 Sagemaker / Azure ML / K8s:
mlflow sagemaker deploy ...
mlflow azureml deploy ...
framework 抽象一直延伸到部署。
Autologging
mlflow.sklearn.autolog()
# 之后所有 sklearn fit() 自动 log model + params + metrics
model = RandomForestClassifier(n_estimators=100)
model.fit(X, y)
# 自动 log: n_estimators / max_depth / mean_score / training_time / ...
支持 sklearn / PyTorch / Lightning / TensorFlow / XGBoost / LightGBM。
适合"快速 baseline 跑 N 个 algorithm 选最好的"。
Projects:可复现 ML 包
MLproject 文件:
name: churn-prediction
python_env: python_env.yaml
entry_points:
main:
parameters:
data_path: {type: string, default: 'data/train.parquet'}
lr: {type: float, default: 0.05}
n_estimators: {type: int, default: 200}
command: 'python train.py --data {data_path} --lr {lr} --n {n_estimators}'
跑:
mlflow run . -P lr=0.1 -P n_estimators=500
# 或 git URL:
mlflow run https://github.com/me/churn-prediction.git -P lr=0.1
自动建 virtualenv + 装依赖 + 跑训练。同事任何人能复现。
与替代品对比
| MLflow | Weights & Biases | Neptune | DVC | |
|---|---|---|---|---|
| 自托管 | ✅ 简单 | 企业版 | ✅ | ✅ |
| 实验跟踪 | ✅ | ✅ 最强 | ✅ | 较弱 |
| 模型注册 | ✅ | ✅ | ✅ | ❌ |
| 部署 | ✅ 内置 | ❌ | ❌ | ❌ |
| Pipeline | ❌(外部 Airflow) | weave | ❌ | ✅ |
| 价格 | 免费 | 付费(个人免费) | 付费 | 免费 |
数据合规要本地的 → MLflow。
体验最好 + 不担心数据外发 → wandb。
我们的实战配置
docker-compose.yml:
services:
mlflow:
image: ghcr.io/mlflow/mlflow:v2.16.2
ports: ["5000:5000"]
environment:
MLFLOW_S3_ENDPOINT_URL: http://minio:9000
AWS_ACCESS_KEY_ID: minio
AWS_SECRET_ACCESS_KEY: ...
command: >
mlflow server
--host 0.0.0.0 --port 5000
--backend-store-uri postgresql://mlflow:pw@pg/mlflow
--default-artifact-root s3://mlflow-artifacts/
pg:
image: postgres:16
environment:
POSTGRES_DB: mlflow
POSTGRES_USER: mlflow
POSTGRES_PASSWORD: pw
minio:
image: minio/minio
command: server /data --console-address ":9001"
ports: ["9000:9000", "9001:9001"]
environment:
MINIO_ROOT_USER: minio
MINIO_ROOT_PASSWORD: ...
三个 service:MLflow + Postgres (metadata) + Minio (S3 兼容 artifact)。
全本地,数据不出公司。
效果
- 30+ 个实验跑下来"哪个 lr × dataset 最优"清晰
- 上线模型有版本号 + git commit 关联 + 训练数据版本 traceable
- 回滚到上版本:UI 点 button + 30 秒 redeploy
- 数据科学家 / ML 工程师 / 部署运维各看 UI 不同部分
踩过的坑
-
artifact 上传慢:local file backend OK;S3 时大模型几 GB 上传
几分钟。mlflow.log_artifact阻塞 train script。后台 thread 异步。 -
autolog 误 log 整张 dataframe:默认 sklearn autolog 会 log
X_train shape + 部分 sample。私密数据可能进 MLflow → 安全隐患。
mlflow.sklearn.autolog(log_input_examples=False)。 -
Model registry 没强制 stage gating:任何人能把 dev model 推
Production。生产建 ACL + reviewer 流程。 -
跨 Python 版本 model 加载失败:在 3.10 训练的 sklearn model
在 3.12 load 时 unpickle 错。MLflow log model 时 capture 了
python_env.yaml,确认部署机器装对版本。 -
UI 慢:实验数量 > 几万后 list 慢。定期 archive 老 experiment
到s3://archived/。
登录后参与评论。