MLflow:本地自托管的实验跟踪 + 模型注册 + 部署 4-in-1

起因

公司数据不能上 wandb / Comet 这种 cloud SaaS。要本地自托管的实验
跟踪 + 模型版本控制 + 一键部署。
MLflow 是 Databricks 出的开源套件,4 个组件覆盖 ML lifecycle:

  • Tracking:记录实验(params / metrics / artifacts)
  • Projects:可复现的 ML 包格式
  • Models:模型版本 + 多 framework 统一接口
  • Registry:模型生命周期(staging / production / archived)

装 + 启服务

uv add mlflow
# 启动 tracking server(默认 SQLite 后端 + 本地文件 artifact)
mlflow server \
  --host 0.0.0.0 --port 5000 \
  --backend-store-uri sqlite:///mlflow.db \
  --default-artifact-root ./mlruns

或更"生产"配置(PostgreSQL + S3):

mlflow server \
  --host 0.0.0.0 --port 5000 \
  --backend-store-uri postgresql://user:pass@db/mlflow \
  --default-artifact-root s3://my-bucket/mlruns \
  --workers 4

systemd unit + nginx 套一下就是企业级服务。

Tracking:训练时记录

import mlflow
import mlflow.pytorch

mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment('churn-prediction')

with mlflow.start_run(run_name='lgbm-baseline'):
    mlflow.log_params({
        'model': 'lgbm',
        'lr': 0.05,
        'n_estimators': 200,
        'max_depth': 7,
    })
    mlflow.set_tag('dataset_version', 'v2024-05-01')

    # 训练
    model = train(...)
    eval_metrics = evaluate(model, X_val, y_val)

    mlflow.log_metrics(eval_metrics)
    # {'auc': 0.84, 'precision': 0.71, 'recall': 0.66}

    # 多 step:每 epoch log
    for epoch in range(20):
        train_one_epoch()
        mlflow.log_metric('train/loss', loss, step=epoch)
        mlflow.log_metric('val/auc', auc, step=epoch)

    # 保存模型(mlflow 自动 capture 依赖 env)
    mlflow.lightgbm.log_model(model, 'model')

    # 任意 artifact
    mlflow.log_artifact('confusion_matrix.png')
    mlflow.log_artifact('feature_importance.csv')

跑完 → MLflow UI 看到 run,有 params / metrics 表 + 曲线 + artifacts
下载。

实验对比

UI 选多个 run → Compare → 表格 + parallel coordinates + scatter plot。
一眼看出"哪几个超参组合 auc 高"。

Model Registry

# 训练完成后注册到 registry
mlflow.lightgbm.log_model(
    lgb_model=model,
    artifact_path='model',
    registered_model_name='ChurnPredictor',
)

UI 里看到 ChurnPredictor v1

版本管理 + 状态机:

client = mlflow.MlflowClient()

# 升级到 staging
client.transition_model_version_stage(
    name='ChurnPredictor', version=1, stage='Staging',
)

# 验证后升级到 production
client.transition_model_version_stage(
    name='ChurnPredictor', version=1, stage='Production',
    archive_existing_versions=True,   # 老 production 自动 archive
)

线上代码总是拿 production 版本:

model = mlflow.pyfunc.load_model(
    model_uri='models:/ChurnPredictor/Production'
)
prediction = model.predict(X)

回滚?把老版本 transition 回 Production 即可。

Models:统一 framework 接口

# sklearn
mlflow.sklearn.log_model(model, 'm')

# PyTorch
mlflow.pytorch.log_model(model, 'm')

# Transformers
mlflow.transformers.log_model(pipeline, 'm')

# 自定义(Pyfunc)
class MyModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        self.artifact = load(context.artifacts['my_file'])
    def predict(self, context, input_df):
        return my_inference(input_df, self.artifact)

mlflow.pyfunc.log_model('m', python_model=MyModel(),
                        artifacts={'my_file': 'data.pkl'})

加载时不需要知道 framework

model = mlflow.pyfunc.load_model('models:/X/Production')
model.predict(df)

业务代码不再 hardcode "import torch / sklearn / xgboost"。

部署:一行 serve

mlflow models serve -m models:/ChurnPredictor/Production -p 5001
# 起一个 HTTP API on :5001
curl -X POST http://localhost:5001/invocations \
  -H 'Content-Type: application/json' \
  -d '{"dataframe_records": [{"age": 30, "balance": 1000}]}'
# {"predictions": [0.72]}

或 build Docker image:

mlflow models build-docker -m models:/X/Production -n my-model:latest
docker run -p 5001:8080 my-model:latest

或部署到 Sagemaker / Azure ML / K8s:

mlflow sagemaker deploy ...
mlflow azureml deploy ...

framework 抽象一直延伸到部署。

Autologging

mlflow.sklearn.autolog()
# 之后所有 sklearn fit() 自动 log model + params + metrics

model = RandomForestClassifier(n_estimators=100)
model.fit(X, y)
# 自动 log: n_estimators / max_depth / mean_score / training_time / ...

支持 sklearn / PyTorch / Lightning / TensorFlow / XGBoost / LightGBM。

适合"快速 baseline 跑 N 个 algorithm 选最好的"。

Projects:可复现 ML 包

MLproject 文件:

name: churn-prediction

python_env: python_env.yaml

entry_points:
  main:
    parameters:
      data_path: {type: string, default: 'data/train.parquet'}
      lr: {type: float, default: 0.05}
      n_estimators: {type: int, default: 200}
    command: 'python train.py --data {data_path} --lr {lr} --n {n_estimators}'

跑:

mlflow run . -P lr=0.1 -P n_estimators=500
# 或 git URL:
mlflow run https://github.com/me/churn-prediction.git -P lr=0.1

自动建 virtualenv + 装依赖 + 跑训练。同事任何人能复现。

与替代品对比

MLflow Weights & Biases Neptune DVC
自托管 ✅ 简单 企业版
实验跟踪 ✅ 最强 较弱
模型注册
部署 ✅ 内置
Pipeline ❌(外部 Airflow) weave
价格 免费 付费(个人免费) 付费 免费

数据合规要本地的 → MLflow。
体验最好 + 不担心数据外发 → wandb。

我们的实战配置

docker-compose.yml

services:
  mlflow:
    image: ghcr.io/mlflow/mlflow:v2.16.2
    ports: ["5000:5000"]
    environment:
      MLFLOW_S3_ENDPOINT_URL: http://minio:9000
      AWS_ACCESS_KEY_ID: minio
      AWS_SECRET_ACCESS_KEY: ...
    command: >
      mlflow server
      --host 0.0.0.0 --port 5000
      --backend-store-uri postgresql://mlflow:pw@pg/mlflow
      --default-artifact-root s3://mlflow-artifacts/

  pg:
    image: postgres:16
    environment:
      POSTGRES_DB: mlflow
      POSTGRES_USER: mlflow
      POSTGRES_PASSWORD: pw

  minio:
    image: minio/minio
    command: server /data --console-address ":9001"
    ports: ["9000:9000", "9001:9001"]
    environment:
      MINIO_ROOT_USER: minio
      MINIO_ROOT_PASSWORD: ...

三个 service:MLflow + Postgres (metadata) + Minio (S3 兼容 artifact)。
全本地,数据不出公司。

效果

  • 30+ 个实验跑下来"哪个 lr × dataset 最优"清晰
  • 上线模型有版本号 + git commit 关联 + 训练数据版本 traceable
  • 回滚到上版本:UI 点 button + 30 秒 redeploy
  • 数据科学家 / ML 工程师 / 部署运维各看 UI 不同部分

踩过的坑

  1. artifact 上传慢:local file backend OK;S3 时大模型几 GB 上传
    几分钟。mlflow.log_artifact 阻塞 train script。后台 thread 异步。

  2. autolog 误 log 整张 dataframe:默认 sklearn autolog 会 log
    X_train shape + 部分 sample。私密数据可能进 MLflow → 安全隐患。
    mlflow.sklearn.autolog(log_input_examples=False)

  3. Model registry 没强制 stage gating:任何人能把 dev model 推
    Production。生产建 ACL + reviewer 流程。

  4. 跨 Python 版本 model 加载失败:在 3.10 训练的 sklearn model
    在 3.12 load 时 unpickle 错。MLflow log model 时 capture 了
    python_env.yaml,确认部署机器装对版本。

  5. UI 慢:实验数量 > 几万后 list 慢。定期 archive 老 experiment
    s3://archived/

精确评价 共 0 人评价
可复现性
可复现 · 0 不可复现 · 0
文风
文风流畅 · 0 文风晦涩 · 0
立场
支持 · 0 反对 · 0

登录后即可对本帖作出评价。

评论区 0 条 · 所有人可在此交流

登录后参与评论。

还没有评论,来说两句。