😊

Airflowの単体テストを書きましょう

2023/04/06に公開

データ基盤は下流の分析・可視化・モデリングの「基盤」となるので、品質の担保は言うまでもなく重要ですね。品質を確保するには、ワークフローの監視・検証、ワークフローのテスト、そして加工用クエリのテストがいずれも欠かせません。この記事では、ワークフロー(Airflow)の単体テスト方法について紹介します。また、ワークフローの監視・検証に関しては、過去の記事も合わせてご覧いただけると幸いです。

ワークフローの監視
https://buildersbox.corp-sansan.com/entry/2022/08/18/110000

ワークフローの検証
https://zenn.dev/jcc/articles/279a16eb36ef72

DAGの単体テスト

まずは、DAGの単体テストについて説明します。厳密に言えば、DAGの実行ではなく、DAGが正確に構築されたかどうかのテストを行います。

https://airflow.apache.org/docs/apache-airflow/stable/best-practices.html#unit-tests

Airflowの公式ベストプラクティスでは簡潔に紹介されていますが、具体例を挙げてさらに詳しく説明しましょう。

importのテスト

  • importが正常にできることを確認する(importが失敗するとWeb UIからも確認できるが、単体テストする時点で確認するともっと便利)
  • import時間を制限する。(冗長なDAGがあると解析するのに時間がかかるので、import時間を制限することで、事前に冗長なDAGを発見できる)
  • 最低でも1つのタスクが含まれていることを確認する。
import unittest
from datetime import timedelta

from airflow.models import DagBag


class TestImportDags(unittest.TestCase):
    IMPORT_TIMEOUT = 120

    @classmethod
    def setUpClass(cls) -> None:
        cls.dagbag = DagBag()
        cls.stats = cls.dagbag.dagbag_stats

    def test_import_dags(self):
        self.assertFalse(
            len(self.dagbag.import_errors),
            f"DAG import failures. Errors: {self.dagbag.import_errors}",
        )

    def test_import_dags_time(self):
        duration = sum((o.duration for o in self.stats), timedelta()).total_seconds()
        self.assertLess(duration, self.IMPORT_TIMEOUT)

    def test_dags_have_at_least_one_task(self):
        for key, dag in self.dagbag.dags.items():
            self.assertTrue(dag, f"DAG {key} not exsit")
            self.assertGreater(len(dag.tasks), 0, f"DAG {key} hasn't any tasks")

DAGの設定

  • DAGタイムアウト設定が必ず行われていることを確認する
  • 失敗時のアラート通知用のcallback関数が必ず行われていることを確認する
  • DAGにタグが付与されていることを確認する(タグはDAGの分類や検索に役立つ)
  • catchupを無効になっているかを確認する(場合によって有効でも構わない)
class TestDagsSetting(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.dagbag = DagBag()

    def test_dag_has_necessary_setting(self):
        for key, dag in self.dagbag.dags.items():
            self.assertIsNotNone(dag.dagrun_timeout, f"DAG {key} hasn't dagrun timeout")
            self.assertIsNotNone(
                dag.on_failure_callback, f"DAG {key} hasn't on_failure_callback"
            )
            self.assertIsNotNone(dag.tags, f"DAG {key} hasn't tags")
            self.assertFalse(dag.catchup, f"catchup of DAG {key} was not False")

これらのテストを通すことで、DAGの基本的な構成が適切であることが確認でき、DAGの運用上の問題を未然に防ぐことができます。

Operatorの単体テスト

次に、自作Operatorの単体テストについて説明します。公式のprovidersがない、または使いづらい場合、Extract-Load処理のためにOperatorを自作するのは一般的です。以下の例ではSalesforceHookGCSHookを利用して実装したSalesforceBulkToGcsOperatorの単体テストです。

手順は以下の通りです。

  1. 必要なモックオブジェクトを作成する
  2. テストパラメータを設定する(Connectionを利用したテストケースの場合、tearDownメソッドで適切に削除しましょう。)
  3. executeメソッドを呼び出し、期待する動作が行われるかどうかを検証する

実際にどのように記述すべきか分からない場合は、Airflowのリポジトリに多数のOperatorをテストする例が存在するので、それらを参考にすることができます。

from __future__ import annotations

import unittest
from collections import OrderedDict
from unittest import mock

from airflow import settings
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.salesforce.hooks.salesforce import SalesforceHook
from dags.sfdc2bq.operators import SalesforceBulkToGcsOperator

TASK_ID = "test-task-id"
CHUNK_START = "2021-12-20T17:49:43Z"
CHUNK_END = "2022-06-04T10:43:20Z"
OBJ_SFDC = "Lead"
GCS_BUCKET = "test-bucket"
GCS_OBJECT_PATH = "path/to/test-file-path"
EXPECTED_GCS_URI = f"gs://{GCS_BUCKET}/{GCS_OBJECT_PATH}"
SALESFORCE_CONNECTION_ID = "test_salesforce_connection"
GCP_CONNECTION_ID = "test_google_cloud"
SALESFORCE_RESPONSE = {
    "records": [
        OrderedDict(
            [
                (
                    "attributes",
                    OrderedDict(
                        [
                            ("type", "Lead"),
                            (
                                "url",
                                "/services/data/v42.0/sobjects/Lead/00Q3t00001eJ7AnEAK",
                            ),
                        ]
                    ),
                ),
                ("Id", "00Q3t00001eJ7AnEAK"),
                ("Company", "Hello World Inc"),
            ]
        )
    ],
    "totalSize": 1,
    "done": True,
}
SALESFORCE_RESPONSE_NO_RECORDS = {
    "records": "",
    "totalSize": 0,
    "done": True,
}
SCHEMA = [
    {"mode": "NULLABLE", "name": "id", "type": "STRING"},
    {"mode": "NULLABLE", "name": "company", "type": "STRING"},
]

class TestSalesforceBulkToGcsOperator(unittest.TestCase):
    def setUp(self):
        mock_google_cloud_default_conn = Connection(
            conn_id=GCP_CONNECTION_ID,
            conn_type="google_cloud_platform",
            extra="{}",
        )

        mock_test_salesforce_conn = Connection(
            conn_id=SALESFORCE_CONNECTION_ID, conn_type="http", extra="{}"
        )

        session = settings.Session()
        session.add(mock_google_cloud_default_conn)
        session.add(mock_test_salesforce_conn)
        session.commit()

    def tearDown(self):
        session = settings.Session()
        for conn_id in [SALESFORCE_CONNECTION_ID, GCP_CONNECTION_ID]:
            connection = (
                session.query(Connection)
                .filter(Connection.conn_id == conn_id)
                .one_or_none()
            )
            if connection:
                session.delete(connection)
        session.commit()

    @mock.patch.object(GCSHook, "upload")
    @mock.patch.object(SalesforceHook, "object_to_df")
    @mock.patch.object(SalesforceHook, "make_query")
    def test_execute(self, mock_make_query, mock_object_to_df, mock_upload):
        mock_make_query.return_value = SALESFORCE_RESPONSE

        operator = SalesforceBulkToGcsOperator(
            obj=OBJ_SFDC,
            chunk_start=CHUNK_START,
            chunk_end=CHUNK_END,
            bucket_name=GCS_BUCKET,
            object_name=GCS_OBJECT_PATH,
            salesforce_conn_id=SALESFORCE_CONNECTION_ID,
            schema=SCHEMA,
            gcp_conn_id=GCP_CONNECTION_ID,
            export_format="csv",
            task_id=TASK_ID,
        )
        result = operator.execute({})

        mock_make_query.assert_called_once_with(
            query="SELECT id,company FROM Lead WHERE CreatedDate>2021-12-20T17:49:43Z and CreatedDate<=2022-06-04T10:43:20Z"  # NOQA E501
        )

        mock_object_to_df.assert_called_once_with(
            query_results=SALESFORCE_RESPONSE["records"],
        )

        mock_upload.assert_called_once_with(
            bucket_name=GCS_BUCKET,
            object_name=GCS_OBJECT_PATH,
            filename=mock.ANY,
            gzip=False,
        )

        assert EXPECTED_GCS_URI == result

    @mock.patch.object(GCSHook, "upload")
    @mock.patch.object(SalesforceHook, "object_to_df")
    @mock.patch.object(SalesforceHook, "make_query")
    def test_execute_non_response(
        self, mock_make_query, mock_object_to_df, mock_upload
    ):
        mock_make_query.return_value = SALESFORCE_RESPONSE_NO_RECORDS

        operator = SalesforceBulkToGcsOperator(
            obj=OBJ_SFDC,
            chunk_start=CHUNK_START,
            chunk_end=CHUNK_END,
            bucket_name=GCS_BUCKET,
            object_name=GCS_OBJECT_PATH,
            salesforce_conn_id=SALESFORCE_CONNECTION_ID,
            schema=SCHEMA,
            gcp_conn_id=GCP_CONNECTION_ID,
            export_format="csv",
            task_id=TASK_ID,
        )
        result = operator.execute({})

        mock_make_query.assert_called_once_with(
            query="SELECT id,company FROM Lead WHERE CreatedDate>2021-12-20T17:49:43Z and CreatedDate<=2022-06-04T10:43:20Z"
        )
        mock_object_to_df.assert_not_called()
        mock_upload.assert_not_called()

        assert result is None

単体テストをCIに組み込む

GitHub Actionsを使ってAirflowの単体テストを実行する方法として、公式のdocker-compose.yamlを利用して環境を構築し、workflow-schedulerコンテナ内で単体テストを実行します。ワークフローの例は以下の通りです。

name: airflow unit test

on:
  pull_request:
    paths:
      - "workflows/**"
      - ".github/workflows/airflow.yml"

jobs:
  airflow-unit-test:
    runs-on: ubuntu-latest
    steps:
      - name: Check out source repository
        uses: actions/checkout@v3
      - name: Build Container
        run: echo "AIRFLOW_UID=$(shell id -u)\nAIRFLOW_GID=0" > .env && docker-compose up -d
        working-directory: workflows
      - name: Run Test
        run: sleep 30 && docker exec workflow-scheduler python -m unittest -v tests
        working-directory: workflows

コンテナを立ち上げた直後にdocker exec workflow-scheduler python -m unittest -v testsを実行すると、ModuleNotFoundError: No module named xxxエラーが発生することがあるため、実行前sleep 30で30秒待機してからテストを実行します。これにより、環境が十分に立ち上がり、モジュールが読み込まれるまでの時間を確保できます。

DAG実行のテスト

DAGが正確に構築されているかどうかのテストするための具体例を示しましたが、実際にDAGの実行をテストをしたい場合結合テストを実施することをお勧めします。ベストプラクティスに記載されているように、実際にDAGを実行することで、タスク間の依存関係やデータのやり取りが正しく機能しているかを確認できます。

以上、Airflowの単体テスト方法について紹介しました。ご参考になれば幸いです。

Discussion