🔖
【GCP】Airflow(on Cloud Composer)でAI PlatformのTrainingを実行する
これは何?
GCPのAI Platform(以下AIPF)にあるトーレニングの機能をAirflow(on Cloud Composer)で実行管理をするという趣旨になります。
どういう課題を解決するのか
- 機械学習におけるトレーニング(学習)工程を、クラウド上の計算リソースを用いる事で非同期かつ並列で実行できます。これにより、大量のデータを用いた学習など多大な計算リソースを必要とする要件を満たすことが出来ます。
- また、Airflowを用いることでワークフローエンジンによる実行のタイミング制御や依存関係の構築(データマートを作成したあとにAIPFのジョブを起動させる、等)が可能となります。
- 単に実行のタイミング制御ですとCloud Scheduler->Pub/Sub->cloud dunctionなどでも実現可能かと思います。
- Cloud Composerを用いることでAirflowの構築やメンテナンスが容易になりアプリケーション開発に集中できます。
AI Plarformの概要
AIに関する機能群を包括したサービスの概念だと認識しています。
AI PlatformというGCPのサービスには以下の機能が提供されています。
- AI Platfrom Training
- AI Platfrom Prediction
- AI Platfrom Vizier
- AI Platfrom Notebooks
- and more
この記事では AI Platfrom Training
を使用しています。
Ref https://cloud.google.com/ai-platform/docs?hl=ja
Cloud Composerの概要
- Airflowのマネージドサービスです。
- MySQLを構築する時にGCEを作成して自前で管理するか、CloudSQL使うか、という話と抽象的には同じです。
- ComposerでググってもPHP文脈の記事がでるので調べる時は「GCP」とかの単語を増やして検索します。
Airflowの概要
PythonでDAGを作成、スケジュール、管理ができるプラットフォームです。各taskを作成し、task間の依存関係を構築できます。
詳細
実行方法
Traning
実装としてはAIPFのサンプル(https://github.com/GoogleCloudPlatform/ai-platform-samples/tree/master/training/sklearn/structured/base)を使用しました。
Airflow
Airflow上ではBashOperatorを使用してshellを実行するのですが、以下のような実装になります。
dag.py
from composer.dags.dor_poc.create_datalake.dag_dwh import DAG_NAME
from datetime import datetime, timedelta, timezone
from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.models import Variable
from airflow.operators.bash_operator import BashOperator
# config
DAG_NAME = "training_model"
PROJECT_ID = Variable.get("project_id")
BUCKET_NAME = Variable.get("bucket_name")
JST = timezone(timedelta(hours=+9), "JST")
default_args = {
"start_date": days_ago(0),
}
with DAG(
DAG_NAME,
schedule_interval=None,
catchup=False,
default_args=default_args) as dag:
job_id = f"{task_id}_{datetime.now(JST).strftime('%Y%m%d%H%M')}"
task_training_model = BashOperator(
task_id="task1",
bash_command=f'''
gcloud ai-platform jobs submit training "{job_id}" \
--job-dir="gs://{BUCKET_NAME}" \
--package-path=/home/airflow/gcs/dags/sample/trainer \
--module-name=trainer.task \
--region="asia-northeast1" \
--runtime-version=2.2 \
--python-version=3.7 \
--scale-tier=BASIC \
-- \
--input="/home/airflow/gcs/dags/sample/datasets/downlaod-taxi.sh" \
--n-estimators=20 \
--max-depth=3
'''
)
参考リンク
- https://github.com/GoogleCloudPlatform/ai-platform-samples/tree/master/training/sklearn/structured/base
- https://cloud.google.com/ai-platform/training/docs/training-jobs
- https://cloud.google.com/ai-platform/training/docs/training-scikit-learn
- https://cloud.google.com/ai-platform/training/docs/packaging-trainer
Discussion