🔖

【GCP】Airflow(on Cloud Composer)でAI PlatformのTrainingを実行する

3 min read

これは何?

GCPのAI Platform(以下AIPF)にあるトーレニングの機能をAirflow(on Cloud Composer)で実行管理をするという趣旨になります。

どういう課題を解決するのか

  • 機械学習におけるトレーニング(学習)工程を、クラウド上の計算リソースを用いる事で非同期かつ並列で実行できます。これにより、大量のデータを用いた学習など多大な計算リソースを必要とする要件を満たすことが出来ます。
  • また、Airflowを用いることでワークフローエンジンによる実行のタイミング制御や依存関係の構築(データマートを作成したあとにAIPFのジョブを起動させる、等)が可能となります。
    • 単に実行のタイミング制御ですとCloud Scheduler->Pub/Sub->cloud dunctionなどでも実現可能かと思います。
  • Cloud Composerを用いることでAirflowの構築やメンテナンスが容易になりアプリケーション開発に集中できます。

AI Plarformの概要

AIに関する機能群を包括したサービスの概念だと認識しています。

AI PlatformというGCPのサービスには以下の機能が提供されています。

  • AI Platfrom Training
  • AI Platfrom Prediction
  • AI Platfrom Vizier
  • AI Platfrom Notebooks
  • and more

この記事では AI Platfrom Training を使用しています。

Ref https://cloud.google.com/ai-platform/docs?hl=ja

Cloud Composerの概要

  • Airflowのマネージドサービスです。
    • MySQLを構築する時にGCEを作成して自前で管理するか、CloudSQL使うか、という話と抽象的には同じです。
  • ComposerでググってもPHP文脈の記事がでるので調べる時は「GCP」とかの単語を増やして検索します。

Airflowの概要

PythonでDAGを作成、スケジュール、管理ができるプラットフォームです。各taskを作成し、task間の依存関係を構築できます。

詳細

実行方法

Traning

実装としてはAIPFのサンプル(https://github.com/GoogleCloudPlatform/ai-platform-samples/tree/master/training/sklearn/structured/base)を使用しました。

Airflow

Airflow上ではBashOperatorを使用してshellを実行するのですが、以下のような実装になります。

dag.py
from composer.dags.dor_poc.create_datalake.dag_dwh import DAG_NAME
from datetime import datetime, timedelta, timezone

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.models import Variable
from airflow.operators.bash_operator import BashOperator

# config
DAG_NAME = "training_model"
PROJECT_ID = Variable.get("project_id")
BUCKET_NAME = Variable.get("bucket_name")
JST = timezone(timedelta(hours=+9), "JST")

default_args = {
    "start_date": days_ago(0),
}

with DAG(
    DAG_NAME,
    schedule_interval=None,
    catchup=False,
    default_args=default_args) as dag:

    job_id = f"{task_id}_{datetime.now(JST).strftime('%Y%m%d%H%M')}"

    task_training_model = BashOperator(
        task_id="task1",
        bash_command=f'''
            gcloud ai-platform jobs submit training "{job_id}" \
            --job-dir="gs://{BUCKET_NAME}" \
            --package-path=/home/airflow/gcs/dags/sample/trainer \
            --module-name=trainer.task \
            --region="asia-northeast1" \
            --runtime-version=2.2 \
            --python-version=3.7 \
            --scale-tier=BASIC \
            -- \
            --input="/home/airflow/gcs/dags/sample/datasets/downlaod-taxi.sh" \
            --n-estimators=20 \
            --max-depth=3
        '''
    )

参考リンク