🧪

SQLAlchemy+aiomysql+pytest-asyncioを使ったデータベースのテストの書き方

2024/02/13に公開

pytest-asyncio も SQLAlchemy も優秀なのでそんなに困ることはないのですが、いくつかはまりどころがあったのでメモしておきます。

aiomysql では sync_engine が使えない

aiomysql は asyncio でしか使えないようになっているので、sync_engine が利用できません。あまり困ることはないですが、 click などを使う時でも async を強制されるので少し面倒です。 asyncio.run() で囲みましょう
データベース全体の操作は SQLAlchemy の少し外側に出ることが多いので、 async を対応させるのはひと手間かかります。具体的には SQLAlchemy-Utils が使えません。当初は強引に mysql+aiomysqlmysql+pymysql に書き換えて create_engine していたのですが、さすがにどうかと思ったのでちゃんと書きました。(結局 engine を2つ作ることになるので、それでもいいような気もします)

スキーマの作成は以下のコードで行うことができます。

    async with engine.connect() as conn:
        await conn.run_sync(Base.metadata.create_all)

DB 自体の作成削除はもしかしたら他に方法があるかもしれませんが、 SQLAlchemy-Utils を模倣して直接書いてしまうのがいいと思います。

from sqlalchemy import text
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool

# 以下どこかの関数内

    database_url = "mysql+aiomysql://root:password@localhost:3306/aiodev"
    parsed_url = make_url(database_url)
    database_name = str(parsed_url.database)
    database_engine = create_async_engine(parsed_url._replace(database=None), poolclass=NullPool)

    # create
    async with database_engine.begin() as conn:
        sql = f"CREATE DATABASE IF NOT EXISTS {database_name} CHARACTER SET = 'utf8mb4'"
        await conn.execute(text(sql))

    # drop
    async with database_engine.begin() as conn:
        sql = f"DROP DATABASE IF EXISTS {database_name}"
        await conn.execute(text(sql))

ちなみに psycopg だと sync_engine は使えたはずです。

session スコープの fixture でこの処理を行うと、テストの初期化と終了時に DB の用意ができます。

コネクションプーリングを切る

SQLAlchemy + aiomysql で pytest-asyncio を使ったテストを何も考えずに書くといろいろエラーが出ます。コネクションプーリングしてしまったコネクションが抱えているイベントループが再利用された時に別のイベントループで実行しようとしてしまうからです。(多分 aiomysql じゃなくても)
先程のコードにも出てきていますが、 poolclass=NullPool でコネクションプーリングを無効にしましょう。

from sqlalchemy.pool import NullPool

create_async_engine(database_url, echo=True, poolclass=NullPool)

テスト以外ではコネクションプーリングを使うことになると思いますので、テスト時のエンジン作成処理だけ poolclass を別途指定するようにするのが良いでしょう。

pytest-asyncio 0.23 で "DeprecationWarning: There is no current event loop" が出る

https://github.com/pytest-dev/pytest-asyncio/issues/706 このあたりで議論されています。経緯は追ってないのでよくわかりません。

DB の初期化などは session スコープの fixture で行われると思いますが、その場合に event_loop を session スコープにしないとイベントループが途中で変わるせいなのか、最後に表題のエラーが発生します。

テストに対して @pytest.mark.asyncio(scope="session") とするとイベントループのスコープを指定でき、 https://pytest-asyncio.readthedocs.io/en/latest/how-to-guides/run_session_tests_in_same_loop.html に書いてある方法で全てのテストケースに強制できるはずなのですが、 fixture には適用されません。

event_loop fixture を上書きするのは非推奨になったので、現時点での正しい方法がいまいちよくわからないのですが、以下の方法で強引に回避できることは確認しました。

import asyncio
from collections.abc import Iterator

from pytest import FixtureRequest


@pytest.fixture(scope="session")
def event_loop(request: FixtureRequest) -> Iterator[asyncio.AbstractEventLoop]:
    loop = asyncio.get_event_loop_policy().new_event_loop()
    loop.__original_fixture_loop = True  # type: ignore[attr-defined]
    yield loop
    loop.close()

@pytest.fixture(scope="session")
def engine(database_url: str, event_loop: asyncio.AbstractEventLoop) -> AsyncEngine:
    """
    SQLAlchemy Engineの作成処理
    ここの引数に event_loop を入れておくと、確実に事前に初期化される
    """
    ...

ただしメッセージを無視すればいいだけなので、強引に回避しなくてもいいと思います。

# pyproject.toml

[tool.pytest.ini_options]
filterwarnings = '''
  ignore:There is no current event loop
'''

実際のコード

conftest.py は以下のようになります。

from collections.abc import AsyncGenerator

import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.pool import NullPool

from model.base import Base


@pytest.fixture(scope="session")
def database_url() -> str:
    return "mysql+aiomysql://root:password@localhost:3306/aiodev"


@pytest.fixture(scope="session")
def engine(database_url: str) -> AsyncEngine:
    return create_async_engine(database_url, echo=True, poolclass=NullPool)


@pytest_asyncio.fixture(autouse=True, scope="session")
async def db_create_drop(engine: AsyncEngine, database_url: str) -> AsyncGenerator[AsyncEngine, None]:
    """DBの作成と削除"""
    parsed_url = make_url(database_url)
    database_name = str(parsed_url.database)
    database_engine = create_async_engine(parsed_url._replace(database=None), poolclass=NullPool)
    async with database_engine.begin() as conn:
        sql = f"CREATE DATABASE IF NOT EXISTS {database_name} CHARACTER SET = 'utf8mb4'" # 文字コードもURLからちゃんと取るようにすると偉い
        await conn.execute(text(sql))

    async with engine.connect() as conn:
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    # DROPするのでここは正直必要ない
    async with engine.connect() as conn:
        await conn.run_sync(Base.metadata.drop_all)

    async with database_engine.begin() as conn:
        sql = f"DROP DATABASE IF EXISTS {database_name}"
        await conn.execute(text(sql))


async def _truncate_all_tables(engine: AsyncEngine) -> None:
    """全テーブルTRUNCATEする"""
    async with engine.connect() as conn:
        await conn.execute(text("SET FOREIGN_KEY_CHECKS = 0;"))
        for table in Base.metadata.sorted_tables:
            await conn.execute(text(f"TRUNCATE TABLE {table.name};"))
        await conn.execute(text("SET FOREIGN_KEY_CHECKS = 1;"))


@pytest_asyncio.fixture()
async def session(engine: AsyncEngine) -> AsyncSession:
    """DBセッションを返す"""
    await _truncate_all_tables(engine) # テストケース毎に全テーブルTRUNCATEしたい
    return AsyncSession(engine)

これで session fixture を通してDBにアクセスするようにテストコードを書けば、 aiomysql を使って毎回テストケース毎に TRUNCATE するテストを書くことができます。

まとめ

他にこうしているよとか、このほうがいいよとあれば是非コメントで教えてください。

Discussion