MLflow Tracking Serverを動かす: AppEngine FE + Cloud IAP ver.
About
GKE (+ Ingress) もしくは App Engine Flexible Environmentに加え Cloud IAPを利用すると、限定公開の MLflow Tracking Server を楽に構成できる。
この記事では、構成が容易な後者のAppEngine FEを利用した方法を紹介する。
目的と前提
- Cloud IAPを利用してMLflow Tracking を楽にそこそこ安全に(not 安価に)動かす。
- MLflow Tracking のバックエンドDBはCloudSQL、Artifact storeはGCSとする。
- MLflow 1.9.1 で確認。AppEngine FEのためのコンテナイメージのPythonは3.6.x。
今のところ自分の所属するプロダクトでは、DataflowのGKEクラスタしか動いていないので、今回はGKEではなくApp Engine FEで済ませた。
※ここでの「セキュア」という表現について
パブリックなインターネットに晒すにあたり、HTTPS対応とIAMによる認証ができることがとりあえずの条件とした。
ここでは扱わない内容と他の方法
- GCPを使わない方法
- GKE (Ingress) + Cloud IAPでやる方法
- MLflow自体に手を入れる方法
- Cloud Run + Cloud Endpointsを使う方法
- Cloud IAPは今のところCloudRunには対応していないので、Endpointsを利用することになりそう。
- Endpoints (Extensive Service Proxy beta 2 がCloudRunでは使える。EnvoyベースのProxy)
- OpenAPIの定義があれば、よしなにしてくれるらしい
- が、MLflowのREST APIに swagger.yaml はない様子で、ProcolBuffersで定義されているだけだった ( mlflow/protos)
本編
サービスアカウントを作成する。以下のようなロールがあれば十分かも。
# Cloud IAPを利用する際には必須
IAP-secured Web App User
# BackendをCloud SQL, Artifact StorageをGCSにしている場合
Cloud SQL Client
Storage Object Creator
Storage Object Viewer
# アプリで使うRoleは適宜追加する
BigQuery User
環境変数の設定
# Service Account `mlflow@the-project.iam.gserviceaccount.com`
GOOGLE_APPLICATION_CREDENTIALS=service_account_key.json
クライアントを利用するときに、Cloud IAPのOAuth2 Client IDとそれに対応したTokenを、Service Accountの権限で取得する。
import os, sys
from google.oauth2 import id_token
from google.auth.transport.requests import Request as AuthRequest
import mlflow
cid = "xxxxxxxxxxxxx.apps.googleusercontent.com"
os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(AuthRequest(), cid)
mlflow.set_tracking_uri("https://mlflow-dot-the-project.appspot.com/")
自分の場合、OptunaやLightGBMなどのCallback内で MLflowClient
を利用することが多いので、以下のような関数をMLflow Tracking APIをコールする前に実行して、MLFLOW_TRACKING_TOKEN
を更新するようにしている。
def authorize_mlflow(oauth2_client_id: str = None) -> None:
"""Set valid service-account path to 'GOOGLE_APPLICATION_CREDENTIALS' envvar """
try:
os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(
AuthRequest(), oauth2_client_id or os.environ.get("MLFLOW_OAUTH2_CLIENT_ID", "")
)
except GoogleAuthError as e:
logger.debug(e)
logger.warning("OAuth2 token authentication error")
except Exception as e:
logger.debug(e)
logger.warning("Continue without authentication")
仕組み
MLflow 1.9時点では、認証方式としてBASIC認証とBearer Tokenによる認証が利用できる。
公式ドキュメント にあるとおり、これらは MLFLOW_
Prefixの環境変数に与えることで利用できる。
上記の例で、MLflow Tracking ServerのAPIをコールするたびに自前でTokenを取得しているのは、MLflow側にTokenの更新処理が実装されていないため。
付録
app.yaml
の例
-
liveness_check
,readiness_check
には、MLflowの/health
エンドポイントが使える。 - バックエンドDBとしてCloudSQLのインスタンスを指定できる。
runtime: custom
env: flex
service: mlflow
skip_files:
- service_account.json
- ^.*\.venv
- ^.*\.env
- ^.*\.terraform
- ^.*tfvers.*
- ^.*\.tf
entrypoint: ./entrypoint.sh
liveness_check:
path: "/health"
check_interval_sec: 30
timeout_sec: 4
failure_threshold: 2
success_threshold: 2
readiness_check:
path: "/health"
check_interval_sec: 5
timeout_sec: 4
failure_threshold: 2
success_threshold: 2
app_start_timeout_sec: 60
beta_settings:
cloud_sql_instances: {INSTANCE_CONNECTION_NAME}
resources:
cpu: 2
memory_gb: 4
disk_size_gb: 10
manual_scaling:
instances: 1
env_variables:
DB_URI: mysql://{DB_USER}:{PASSWORD}/{DATABASE}?unix_socket=/cloudsql/{INSTANCE_CONNECTION_NAME}
ARTIFACT_ROOT: {GCS_BUCKET}
Dockerfile
- AppEngine FEでは、
app.yaml
と同じディレクトリにあるDockerfile
から、ランタイムイメージをCloud Buildでビルドして利用するので、これも必要。 - 下記の段階では、
python3
コマンドは Python 3.6 だった。
FROM gcr.io/google-appengine/python:2020-06-17-111334
RUN apt update && \
apt install -y --no-install-recommends mysql-client libmysqlclient-dev python3-dev
ENV PYTHONFAULTHANDLER=1 \
PYTHONUNBUFFERED=1 \
PYTHONHASHSEED=random \
# pip:
PIP_NO_CACHE_DIR=on \
PIP_DISABLE_PIP_VERSION_CHECK=on \
PIP_DEFAULT_TIMEOUT=100
ARG MLFLOW_VERSION=1.9.1
RUN echo "Installing MLFlow ${MLFLOW_VERSION}"
RUN pip3 install mlflow[extras]==${MLFLOW_VERSION} mysqlclient
WORKDIR /mlflow
COPY ./entrypoint.sh /mlflow/
RUN chmod +x entrypoint.sh
EXPOSE 80 5000 8080
ENTRYPOINT [ "./entrypoint.sh" ]
entrypoint.sh
- AppEngineで動かすので、8080番ポートを使用する。
#!/bin/bash
HOST=${MLFLOW_TRACKING_HOST:-0.0.0.0}
PORT=${PORT:-8080}
sleep 5s
mlflow db upgrade "${DB_URI}"
mlflow server \
--backend-store-uri "${DB_URI}" \
--default-artifact-root "${ARTIFACT_ROOT}" \
--host "${HOST}" --port "${PORT}"
Discussion