💡

MLflowのモデル管理を学ぼう section1: 基本のコード

2025/01/28に公開

本記事の立ち位置

MLflowによるモデル管理のベストプラクティスを探るための一つ目の記事にしたいと考えている。
MLflowには学習の記録であるExperimentsと、Modelsが存在する。


Experimentsは文字通り訓練の記録であるが、 ModelsはModel自体を管理することができる。
このコンセプトについては別記事にまとめることにし、本記事ではあくまでもざっくりとしたコードの説明と使い方に関しての記載に止める。

まずモデルを管理する上で以下のプロセスを踏んでいくことが考えられる。

  1. 訓練を記録する
  2. 訓練の結果をモデルに登録する
  3. モデルを選定し productionに登録する(切り替える)

しかし、この際にどう運用するべきか考えていきたい。
例えば3は本当に必要になのだろうか? 小規模かつ複雑にモデルを更新しないようなプロジェクトにおいては 2の stepまで行い、最新のversionを利用すれば良いのではないか?
また3まで至る場合には、2はどのようなタイミングで行うべきだろうか?基本的に全て登録しておけばよいのか?それとも 評価指標が一定値を超えたものだけ登録したら良いのか、このあたりを考察しながら複数記事に分けて考えたいと思う。

本記事はその第一歩として、まず当たり前のコードを紹介したい。

Use cases

Use cases1: r2 score > 0.8以上の場合にのみモデルを更新させたいケースがあると考える。

import mlflow
import mlflow.sklearn
from sklearn.datasets import load_iris
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error

mlflow.set_tracking_uri("http://localhost:8080")

data = load_iris()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=42
)

experiment_id = "962018762253323550"

with mlflow.start_run(experiment_id=experiment_id) as run:
    model = LinearRegression()
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)

    r2 = r2_score(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)

    mlflow.log_metric("r2_score", r2)
    mlflow.log_metric("mse", mse)

    mlflow.sklearn.log_model(model, artifact_path="model")

    model_uri = f"runs:/{run.info.run_id}/model"

    print(f"Logged data and model in run {run.info.run_id}, {r2=}, {mse=}")
    if r2 > 0.95:
        registered_model = mlflow.register_model(
            model_uri=model_uri,
            name="IrisLinearRegressionModel"
        )
        print("Registered model:", registered_model.name, 
            "version:", registered_model.version)

さらに最新のversionwを呼び出す時は以下のように書ける。

import mlflow
import mlflow.pyfunc
from mlflow.tracking import MlflowClient

mlflow.set_tracking_uri("http://localhost:8080")

client = MlflowClient()
model_name = "IrisLinearRegressionModel"

latest_version_info = client.get_latest_versions(model_name)[0]

latest_version = latest_version_info.version

model_uri = f"models:/{model_name}/{latest_version}"
print(model_uri)
model = mlflow.pyfunc.load_model(model_uri)

Use cases2: 最もスコアの高いモデルを production stageにあげる

run全ての中から最もスコアの高いモデルを productionにあげる

import mlflow
from mlflow.tracking import MlflowClient


mlflow.set_tracking_uri("http://localhost:8080")

client = MlflowClient()

experiment_id = "962018762253323550"
model_name = "IrisLinearRegressionModel"

run = mlflow.search_runs(
  # filter_string = "metrics.r2_score > 0.9",
  experiment_ids = experiment_id,
  order_by = ["metrics.r2_score DESC"]
).iloc[0]

model_uri = f"runs:/{run.run_id}/model"
registered_model = mlflow.register_model(
    model_uri=model_uri,
    name=model_name
)

client.transition_model_version_stage(
    name=model_name,
    version=registered_model.version,
    stage="Production"
)

また、すでに登録されている全てのモデルから最も高いスコアのモデルを登録するには以下のコードで実現可能。

import mlflow
from mlflow.tracking import MlflowClient


mlflow.set_tracking_uri("http://localhost:8080")

client = MlflowClient()

model_name = "IrisLinearRegressionModel"

versions = client.search_model_versions(f"name='{model_name}'")
print(versions)

best_r2 = float("-inf")
best_version = None

for v in versions:
    run_id = v.run_id
    run = client.get_run(run_id)
    metrics = run.data.metrics

    r2 = metrics.get("r2_score", None)
    if r2 is not None and r2 > best_r2:
        best_r2 = r2
        best_version = v.version

if best_version is not None:
    client.transition_model_version_stage(
        name=model_name,
        version=best_version,
        stage="Production"
    )
    print(f"Version {best_version} with r2_score={best_r2} has been promoted to Production.")
else:
    print("No model versions found with r2_score metric.")

次回:これらのベストプラクティスを考える

上記のメソッドが用意されているが、実際にはどのように利用し運用していくべきか、これを実体験と公式ドキュメントをベースに考えてみようと思う。
また stagingという機能を利用しているが、こちらも今後は使われなくなるようだ。 https://mlflow.org/docs/latest/model-registry.html#migrating-from-stages
さらに MLflow registoryのconceptも再確認してみよう。https://mlflow.org/docs/latest/model-registry.html#concepts

Discussion