🔗

Snowflake への ELT ワークフローを AWS Step Functions で実装してみた

こんにちは!シンプルフォームの山岸です。

皆さん、ETL / ELT のワークフローエンジンには何を利用されているでしょうか?
Apache Airflow や、その AWS マネージドサービスである Amazon MWAA を利用されている方も多いかと思います。もしくは TROCCOFivetran のようなマネージドなワークフローツールを利用されているかもしれません。

当社では多くのワークフローで AWS Step Functions (SFN) を利用しています。今回、Snowflake への ELT ワークフローを SFN で実装してみたので、本記事ではその内容についてご紹介できればと思います。

背景・課題感

背景

少し前に、「Snowflake × dbt で構築する ELT アーキテクチャ」というタイトルの記事を投稿しました。詳細な説明はそちらに譲りますが、AWS 環境から Snowflake のステージング用スキーマに Load する部分のアーキテクチャは以下のようにしていました。

Glue ジョブは ソース DB から増分データフレーム出力処理を実行しており、処理が完了次第、SNS-SQS 経由で Snowpipe 実行用の Lambda 関数 [1] を呼び出しています。Lambda 関数実行によりリクエストされた Snowpipe でのロードが完了すると、dbt モデル更新などの後続処理を実行できるようになります。

課題感

前述の Lambda 関数による Snowpipe 実行は非同期的であり、関数呼び出しが完了した時点ではまだ Snowflake 環境へのロードは完了していません。しかし、dbt モデル更新などの後続処理を含む一連のデータ更新処理を、イベントドリブンな ELT ワークフローとして自動化するには、Snowflake 環境へのロード完了をチェックする機構が必要になります。

実装

ELT ワークフロー

上記の課題感を踏まえ、「Snowflake 環境へのファイル取り込み」と後続の「dbt モデル更新」を実行する ELT ワークフローとして、以下のような SFN ステートマシンを実装しました。


UpdateTable ワークフローグラフ

このワークフローでは、2つの SFN ワークフローをネストで呼び出しています。各 State の処理内容は以下のようになっています。

  • 前段の IngestFilesWorkflow では、後述の IngestFiles ワークフローを呼び出しています。内部的には Snowpipe を Lambda 関数から実行しており、Snowflake 環境へのロードが完了したら、SFN ステートマシン実行も同期的に完了するようになっています。
  • 前段のワークフローの実行結果からロードされた行数を確認し、0 行でなければ後段の UpdateDbtModels ワークフローを実行します。(0 行の場合はモデル更新も不要のため、そのまま終了します)
  • 後段の UpdateDbtModelsWorkflow では、対象の dbt モデルを更新します。当社ではソース DB の 1 テーブルに対して複数モデルを対応させているケースもあるため、SFN ステートマシンとして一段階抽象化しています。

では前段の IngestFiles ワークフローをどう実装するかについて、以降に解説していきたいと思います。(後段の UpdateDbtModels ワークフローについては割愛します)

IngestFiles ワークフロー

本来は非同期的な Snowpipe によるファイル取り込みの処理を同期的に呼び出すため、以下のような SFN ステートマシンを実装しました。


IngestFiles ワークフローグラフ

ワークフロー定義の実装は以下のようになっています。

ワークフロー定義
{
  "StartAt": "IngestFiles",
  "States": {
    "IngestFiles": {
        "Type": "Task",
        "Resource": "arn:aws:states:::lambda:invoke",
        "ResultSelector": {
          "Response.$": "$.Payload"
        },
        "ResultPath": "$.results.IngestFiles",
        "Parameters": {
          "FunctionName": "${function_arns.ingest_files}",
          "Payload": {
            "execution_name.$": "$$.Execution.Input.execution_name",
            "snowflake_account.$": "$$.Execution.Input.snowflake_account",
            "snowflake_database.$": "$$.Execution.Input.snowflake_database",
            "snowflake_schema.$": "$$.Execution.Input.snowflake_schema",
            "snowflake_table_layer.$": "$$.Execution.Input.snowflake_table_layer",
            "table_name.$": "$$.Execution.Input.table_name"
          }
        },
        "Retry": [
          {
            "ErrorEquals": ["States.ALL"],
            "IntervalSeconds": 10,
            "MaxAttempts": 3,
            "BackoffRate": 2
          }
        ],
        "Next": "GetCopyHistory"
    },
    "GetCopyHistory": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "ResultSelector": {
        "Response.$": "$.Payload"
      },
      "ResultPath": "$.results.GetCopyHistory",
      "Parameters": {
        "FunctionName": "${function_arns.get_copy_history}",
        "Payload": {
          "execution_name.$": "$$.Execution.Input.execution_name",
          "snowflake_account.$": "$$.Execution.Input.snowflake_account",
          "snowflake_database.$": "$$.Execution.Input.snowflake_database",
          "snowflake_schema.$": "$$.Execution.Input.snowflake_schema",
          "snowflake_table_layer.$": "$$.Execution.Input.snowflake_table_layer",
          "table_name.$": "$$.Execution.Input.table_name"
        }
      },
      "Retry": [
        {
          "ErrorEquals": ["States.ALL"],
          "IntervalSeconds": 10,
          "MaxAttempts": 3,
          "BackoffRate": 2
        }
      ],
      "Next": "CheckLoadedFiles"
    },
    "CheckLoadedFiles": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "ResultSelector": {
        "Response.$": "$.Payload"
      },
      "ResultPath": "$.results.CheckLoadedFiles",
      "Parameters": {
        "FunctionName": "${function_arns.check_loaded_files}",
        "Payload": {
          "StagedFilePathList.$": "$.results.IngestFiles.Response.body.StagedFilePathList",
          "CopyHistory.$": "$.results.GetCopyHistory.Response.body.CopyHistory"
        }
      },
      "Retry": [
        {
          "ErrorEquals": ["States.ALL"],
          "IntervalSeconds": 10,
          "MaxAttempts": 3,
          "BackoffRate": 2
        }
      ],
      "Next": "CheckLoadStatus"
    },
    "CheckLoadStatus": {
      "Type": "Choice",
      "Choices": [
        {
          "Variable": "$.results.CheckLoadedFiles.Response.statusCode",
          "NumericEquals": 200,
          "Next": "Success"
        },
        {
          "Variable": "$.results.CheckLoadedFiles.Response.statusCode",
          "NumericEquals": 500,
          "Next": "Fail"
        },
        {
          "Variable": "$.results.CheckLoadedFiles.Response.statusCode",
          "NumericEquals": 400,
          "Next": "Wait"
        }
      ],
      "Default": "Fail"
    },
    "Fail": {"Type": "Fail"},
    "Success": {"Type": "Succeed"},
    "Wait": {
      "Type": "Wait",
      "Seconds": 60,
      "Next": "GetCopyHistory"
    }
  }
}

各 State での処理は以下のようになっています。

State 名 State 型 処理内容
IngestFiles Task Snowpipe を実行するための Lambda 関数を実行し、ステージされたファイル一覧を返す。
GetCopyHistory Task Snowpipe の COPY_HISTORY を取得する Lambda 関数を実行し、ロード済みファイル一覧やロードされた合計行数などを返す。
CheckLoadedFiles Task IngestFiles 結果に含まれるステージされたファイル一覧と、GetCopyHistory 結果のロード済みファイル一覧を比較し、Snowflake 環境へのロード完了を判定する。
CheckLoadStatus Choice ロードが完了していなければ、一定時間の待機後に再び GetCopyHistory に遷移し、ロード完了まれこれを繰り返す。

各 State で呼び出される Lambda 関数の実装について、以下に少し補足したいと思います。

IngestFiles 関数

IngestFiles 関数は、Snowpipe 実行を明示的に実行するための Lambda 関数です。snowflake-ingest-python というファイル取り込み用 SDK を利用し、ステージ対象ファイルを指定して Snowpipe によるロードをリクエストします。

実装の詳細については以下のエントリで扱っているので、良ければ併せてご覧ください。

https://zenn.dev/simpleform_blog/articles/20240716-snowpipe-ingestion-with-aws-lambda

GetCopyHistory 関数

GetCopyHistory 関数では、Snowpipe に対応する COPY_HISTORY を取得し、Snowflake 環境へのリアルタイムなロード状況を確認します。ハンドラスクリプトの実装は、例えば以下のようになります。

ハンドラスクリプト実装例
main.py
import json
import os
import pandas as pd

from aws_lambda_powertools import Logger
from snowflake.connector.connection import SnowflakeConnection

from .modules.models import EventMessage
from .modules.utils import get_snowflake_connection
from .modules.enum import AuthMethod

logger = Logger()

AUTH_METHOD = os.environ["AUTH_METHOD"]
LOADED_STATE = "Loaded"


def get_copy_history(conn: SnowflakeConnection, message: EventMessage) -> pd.DataFrame:
    COLUMNS = [
        "file_name",
        "row_count",
        "row_parsed",
        "error_count",
        "error_limit",
        "status",
    ]
    evaluation_period = {
        "value": os.environ.get("EVALUATION_PERIOD_VALUE", 24),
        "unit": os.environ.get("EVALUATION_PERIOD_UNIT", "hours"),
    }
    qualified_table_name = ".".join([
        message.snowflake_database,
        message.snowflake_schema,
        f"{message.snowflake_table_layer}_{message.table_name}",
    ])

    with conn.cursor() as cur:
        query = f"""
        SELECT
            {','.join(COLUMNS)}
        FROM table(
            information_schema.copy_history(
                TABLE_NAME => '{qualified_table_name}',
                START_TIME => DATEADD(
                    {evaluation_period['unit']},
                    -{evaluation_period['value']},
                    CURRENT_TIMESTAMP()
                )
            )
        )
        LIMIT 10;
        """
        logger.info(query)
        cur.execute(query)
        results = cur.fetchall()
        df = pd.DataFrame(results, columns=COLUMNS)
        df = df[df["file_name"].str.startswith(f"execution_name={message.execution_name}")]

    return df


@logger.inject_lambda_context(log_event=True)
def handler(event, context):

    if "Records" in event:
        assert len(event["Records"]) == 1, "Only one record is expected"
        message = json.loads(event["Records"][0]["body"])
        message = EventMessage.model_validate(message)
    else:
        message = EventMessage.model_validate(event)

    conn = get_snowflake_connection(AUTH_METHOD)(message)
    df = get_copy_history(conn, message)

    df_loaded = df[df["status"] == LOADED_STATE]
    loaded_file_path_list = df_loaded["file_name"].tolist()

    total_row_count = int(df["row_count"].sum())
    total_error_count = int(df["error_count"].sum())

    if total_error_count == 0:
        return {
            "statusCode": 200,
            "body": {
                "TotalRowCount": total_row_count,
                "CopyHistory": df.to_dict(orient="records"),
                "LoadedFilePathList": loaded_file_path_list,
            }
        }

    else:
        return {
            "statusCode": 500,
            "body": {
                "TotalRowCount": total_row_count,
                "TotalErrorCount": total_error_count,
                "CopyHistory": df.to_dict(orient="records"),
                "LoadedFilePathList": loaded_file_path_list,
            }
        }

このうち、COPY_HISTORY を取得しているのは以下の部分になります。この実装では、対象期間、および該当の ELT ワークフロー実行名で取得する行を絞っています。

main.py - get_copy_history()
def get_copy_history(conn: SnowflakeConnection, message: EventMessage) -> pd.DataFrame:
    COLUMNS = ["file_name", "row_count", "row_parsed", "error_count", "error_limit", "status", ]
    evaluation_period = {
        "value": os.environ.get("EVALUATION_PERIOD_VALUE", 24),
        "unit": os.environ.get("EVALUATION_PERIOD_UNIT", "hours"),
    }
    qualified_table_name = ".".join([
        message.snowflake_database,
        message.snowflake_schema,
        f"{message.snowflake_table_layer}_{message.table_name}",
    ])

    with conn.cursor() as cur:
        query = f"""
        SELECT
            {','.join(COLUMNS)}
        FROM table(
            information_schema.copy_history(
                TABLE_NAME => '{qualified_table_name}',
                START_TIME => DATEADD({evaluation_period['unit']}, -{evaluation_period['value']}, CURRENT_TIMESTAMP())
            )
        );
        """
        cur.execute(query)
        results = cur.fetchall()
        df = pd.DataFrame(results, columns=COLUMNS)
        df = df[df["file_name"].str.startswith(f"execution_name={message.execution_name}")]

    return df

CheckLoadedFiles 関数

ワークフロー定義を見て頂けると分かる通り、IngestFiles 関数と GetCopyHistory 関数のレスポンスは、SFN ステートマシン実行の ResultPath に書き込まれるようにしています。

それぞれのレスポンスに含まれる「ステージされたファイル一覧」と「ロード済みのファイル一覧」の情報を抽出し、CheckLoadedFiles 関数への入力とします。両者の結果を比較し、リクエストされたロードが完了しているかどうかを判定します。

細部の説明は割愛しますが、ハンドラスクリプトの実装例を以下に付しておきます。

ハンドラスクリプト実装例
from aws_lambda_powertools import Logger

logger = Logger()

LOADED_STATE = "Loaded"


@logger.inject_lambda_context(log_event=True)
def handler(event, context):

    staged_file_path_list = event["StagedFilePathList"]
    logger.info(staged_file_path_list)

    copy_history = event["CopyHistory"]
    logger.info(copy_history)

    errors = []
    for record in copy_history:
        logger.info(record)

        if record["error_count"] == 0:
            if record["status"] == LOADED_STATE:
                loaded_file_path = record["file_name"]
                logger.info(loaded_file_path)
                staged_file_path_list.remove(loaded_file_path)

        elif record["error_count"] > 0:
            errors.append(record)

    if len(errors) > 0:
        msg = "Some files failed to load"
        return {
            "statusCode": 500,
            "body": {
                "Message": msg,
                "Errors": errors,
            }
        }

    else:
        if len(staged_file_path_list) == 0:
            msg = "All files loaded successfully"
            return {
                "statusCode": 200,
                "body": {
                    "Message": msg,
                }
            }

        else:
            msg = "Some files are not loaded"
            return {
                "statusCode": 400,
                "body": {
                    "Message": msg,
                    "UnloadedFilePathList": staged_file_path_list,
                }
            }

実装に関する説明は以上です。

さいごに

Snowflake 環境への ELT ワークフロー実装について書いてみました。

ワークフローエンジンとして SFN を利用している事例もいくつか見かけつつ、Snowflake との組合せという意味では、調べている限りでは Airflow の方が主流なのかなという印象を受けました。もちろん Airflow も優れたツールだと思いますが、個人的にはサーバー不要でコスト効率の高い SFN が好みで、これからもどんどん活用していきたいと思っています。

本記事が読者の皆さまの技術選定や実装の参考になれば幸いです。
最後まで読んで頂き、ありがとうございました。

参考

脚注
  1. Snowpipe によるファイル取り込みを AWS Lambda から動かしてみる - Zenn ↩︎

Snowflake Data Heroes

Discussion