🫣

FastAPIのDB接続とマイグレーション(DIコンテナも準備)

2024/05/17に公開

最終ディレクトリ構成


かなり急いで書いたのでミスなどあったらコメントください!
今回は前回に続きとしてDB接続(テスト用は後に作成しますが、下準備はします)とマイグレーションのup,downまでとなります。
前回: https://zenn.dev/momonga_g/articles/f131ea192b1184

今回の最終系のディレクトリ構成

project_root
├── _docker
│   ├── nginx
│   │   └── nginx.conf
│   └── python
│       └── Dockerfile
├── src
│   └── main.py
│   └── core
│        └── dependency.py
│   └── model
│        └── user.py
│   └── database // 追加 alembic initの生成物を格納
│        └── versions // マイグレーションファイルがここに生成される(生成物)
│        └── env.py (alembic生成物)
│        └── script.py.mako (alembic生成物)
│        └── database.py // DB設定ファイル
├── .env
├── pyproject.toml
├── poetry.lock
├── makefile
├── makefile.local
├── makefile.container
├── .gitignore // git使用しているなら追加しましょう
└── docker-compose.yml
└── alembic.ini (alembic生成物)

下記をインストールします

sqlalchemyでもいいですが、SQLModelならモデル管理が楽なのでこちらを採用します。

SQLAlchemyPydanticの両方の互換性を持つため、SQLAlchemyとPydanticの2重モデル管理を解消してくれます。(筆者も開発中にこれが嫌で途中からSQLModelに切り替えました。)

 poetry add alembic sqlmodel aiomysql python-dotenv injector

なんかPythonってドキュメントが読みにくいライブラリが多い気がします。。。

気のせいかな?

DIコンテナの準備をする

FastAPIにも依存性注入する機能が存在しますが、今回はInjectorをメインにして、FastAPIの依存性注入は限定的に使用することにします。(FastAPIのお作法とは違うのでそこはご留意ください。)

まずはDIコンテナをシングルトンパターンで作成します。

生成されたインスタンスはグローバル変数として定義しておきます。

update_injectorメソッドでDBを環境によって切り替えることが可能になります。

基本はget_classメソッドでクラスを取得することが可能です。

※_initializeメソッドのAppConfigクラスはDBエンジン作成時に作成します。

/src/core/dependency.py

from injector import Injector, inject
from src.database import AppConfig

class DependencyInjector:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DependencyInjector, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.di = Injector([AppConfig()])

    async def update_injector(self, _class):
        self.di = Injector([_class])

    @inject
    def get_class(self, _class):
        return self.di.get(_class)

di_injector = DependencyInjector()

DBエンジンの作成

/src/database/database.py

順に説明していきます。

import asyncio
import os
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from injector import Module, provider, singleton
from sqlalchemy.ext.asyncio import AsyncEngine, async_scoped_session, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession

load_dotenv()

class DatabaseConnection:
    @singleton
    def __init__(self, connection_url: str, migration_url: str, option: dict = {}):
        self.connection_url = connection_url
        self.migration_url = migration_url
        self.option = option
        self.engine = self.get_async_engine()
        self.session = self.get_session(self.engine)

    @asynccontextmanager
    async def get_db(self):
        async with self.session() as session:
            yield session

    async def close_engine(self):
        if self.engine:
            await self.engine.dispose()
            await self.session.close()
            self.engine = None
            self.session = None

    def get_url(self) -> str:
        return self.connection_url

    def get_migration_url(self) -> str:
        return self.migration_url

    def get_async_engine(self) -> AsyncEngine:
        return create_async_engine(self.connection_url, **self.option)

    def get_session(self, engine: AsyncEngine) -> AsyncSession:
        async_session_factory = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
            class_=AsyncSession,
            expire_on_commit=True,
        )
        # セッションのスコープ設定
        return async_scoped_session(async_session_factory, scopefunc=asyncio.current_task)

# DB設定クラスのインターフェース
class ConfigInterface(ABC):
    @abstractmethod
    def db_url(self) -> str:
        pass

    @abstractmethod
    def migration_url(self) -> str:
        pass

    @abstractmethod
    def get_option(self) -> dict:
        pass

class AppConfig(Module, ConfigInterface):
    @singleton
    @provider
    def provide_database_connection(self) -> DatabaseConnection:
        return DatabaseConnection(self.db_url(), self.migration_url(), self.get_option())

    def db_url(self) -> str:
        dialect = os.getenv("DB_DIALECT")
        driver = os.getenv("DB_DRIVER")
        username = os.getenv("DB_USER")
        password = os.getenv("DB_PASS")
        host = os.getenv("DB_HOST")
        port = os.getenv("DB_PORT")
        db_name = os.getenv("DB_NAME")
        return f"{dialect}+{driver}://{username}:{password}@{host}:{port}/{db_name}?charset=utf8"

    def migration_url(self) -> str:
        dialect = os.getenv("DB_DIALECT")
        username = os.getenv("DB_USER")
        password = os.getenv("DB_PASS")
        host = os.getenv("DB_HOST")
        port = os.getenv("DB_PORT")
        db_name = os.getenv("DB_NAME")
        return f"{dialect}+pymysql://{username}:{password}@{host}:{port}/{db_name}?charset=utf8"

    def get_option(self):
        logging = bool(os.getenv("SQL_LOGGING"))
        pool_size = int(os.getenv("DB_POOL_SIZE"))
        pool_connection_timeout = int(os.getenv("POOL_CONN_TIMEOUT"))
        max_overflow = int(os.getenv("DB_MAX_OVERFLOW"))
        pool_recycle = int(os.getenv("POOL_RECYCLE"))
        return {
            "echo": logging,
            "echo_pool": logging,
            "pool_size": pool_size,
            "pool_timeout": pool_connection_timeout,
            "max_overflow": max_overflow,
            "pool_recycle": pool_recycle,
            "pool_pre_ping": True,
        }

ここで重要なのが、AppConfigです。

AppConfigは下記の部分でDIコンテナにDatabaseConnectionクラスを登録します。

def provide_database_connection(self) -> DatabaseConnection:
	return DatabaseConnection(self.db_url(), self.migration_url(), self.get_option())

DependencyInjectorクラスの_initializeメソッドでこのAppConfigを通して、DatabaseConnectionを登録しています。

class DependencyInjector:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(DependencyInjector, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.di = Injector([AppConfig()]) // ここでAppConfigからDatabaseConnectionクラスをDIコンテナに登録している

なぜ、設定を分けてわざわざDIコンテナに登録しているのかというと、環境によって接続先を変更したいケースがあるからです。

だからこそ、ConfigInterfaceというインターフェースを用意して、あるべきメソッドを束縛しています。

# DB設定クラスのインターフェース
class ConfigInterface(ABC):
    @abstractmethod
    def db_url(self) -> str:
        pass

    @abstractmethod
    def migration_url(self) -> str:
        pass

    @abstractmethod
    def get_option(self) -> dict:
        pass

例えば、テスト環境を分けたい場合はTestAppConfigを作成して、conftestでAppConfigからTestAppConfigへ差し替えなども可能になります。(テストの部分で記載予定)

class TestAppConfig(Module, ConfigInterface):
    @singleton
    @provider
    def provide_database_connection(self) -> DatabaseConnection:
        return DatabaseConnection(self.db_url(), self.migration_url(), self.get_option())

    def db_url(self) -> str:
        return f"sqlite+aiosqlite:///./test.db"

    def migration_url(self) -> str:
        return f"sqlite+aiosqlite:///./test.db"

    def get_option(self) -> dict:
        return {}

さて!ではDatabaseConnectionを見ていきます。(重要なところだけ抜粋します)

と言っても難しいことはしていなくて、コンストラクタでAppConfigから渡された引数を使用して、エンジンとセッションを作成しているだけです。

class DatabaseConnection:
    @singleton
    def __init__(self, connection_url: str, migration_url: str, option: dict = {}):
        self.connection_url = connection_url
        self.migration_url = migration_url
        self.option = option
        self.engine = self.get_async_engine()
        self.session = self.get_session(self.engine)
        
    @asynccontextmanager
    async def get_db(self):
        async with self.session() as session:
            yield sessio

    def get_async_engine(self) -> AsyncEngine:
        return create_async_engine(self.connection_url, **self.option)

    def get_session(self, engine: AsyncEngine) -> AsyncSession:
        async_session_factory = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
            class_=AsyncSession,
            expire_on_commit=True,
        )
        # セッションのスコープ設定
        return async_scoped_session(async_session_factory, scopefunc=asyncio.current_task)

get_async_engineメソッドでで非同期エンジンを作成します。

get_sessionメソッドで

async_session_factoryを使って新しい非同期セッションを生成します。

async_scoped_sessionを使用して、セッションファクトリーに基づいたスコープセッションを作成します。scopefunc=asyncio.current_taskを指定することで、現在の非同期タスクごとに個別のセッションを管理します。

要は非同期タスクごとにセッションを管理できるようになります。

.envを定義する

実は/src/database/database.pyではload_dotenvを使用して、.envを読み込んでいました。

なので、.envを記載していきましょう。

/.env

APP_ENV=development
BASE_URL=http://localhost:8000
# 開発DB
DB_DIALECT=mysql # DBの種類
DB_DRIVER=aiomysql # DBのドライバー
DB_NAME=fastapi_db # docker-compose.ymlに合わせる
DB_USER=user # docker-compose.ymlに合わせる
DB_PASS=password # docker-compose.ymlに合わせる
DB_HOST=mysql # docker-compose.ymlに合わせる(サービス名)
DB_PORT=3306
DB_POOL_SIZE=10
DB_MAX_OVERFLOW=-1 # poolのmax_overflowを設定しない場合は-1を設定
POOL_CONN_TIMEOUT=10 #コネクションプールから接続を取得する際のタイムアウト(秒)を設定
POOL_RECYCLE=3600 #接続をプールに戻す前に再利用する最大時間(秒)を設定。これは、データベース接続が古くなるのを防ぐために使用
SQL_LOGGING=True

これで一旦はDBの設定はOKです。

マイグレーションの準備

今、ファイルはこんな感じになってるかと思います。

project_root
├── _docker
│   ├── nginx
│   │   └── nginx.conf
│   └── python
│       └── Dockerfile
├── src
│   └── main.py
│   └── core
│        └── dependency.py
│   └── database
│        └── database.py // DB設定ファイル
├── .env
├── pyproject.toml
├── poetry.lock
├── makefile
├── makefile.local
├── makefile.container
├── .gitignore // git使用しているなら追加しましょう
└── docker-compose.yml

databaseディレクトリに移って下記を実行してください。

alembic init

下記が生成されるかと思います。

  • versionsディレクトリ
    • マイグレーションファイルが格納されます。(まだからのはず)
  • alembic.ini
    • 設定ファイル
  • env.py
    • マイグレーションが実こうされた時に必ず実行されるスクリプト
  • script.py.mako
    • マイグレーションのテンプレート
  • readme(無視でOKです)

alembic.iniだけルートに移動しておいてください。makeコマンドを実行する上で、ルートにいないと使用できないため。

/alembic.ini

# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = src/database/ <-書き換えてください

env.py

マイグレーション用のURLを設定

from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlmodel import SQLModel

from alembic import context

from src.database.database import DatabaseConnection
from src.core.dependency import di_injector
from src.model.user import Users // モデル作成後に必ずここに追記してください。

db_connection = di_injector.get_class(DatabaseConnection) // 追加
ASYNC_DB_URL = db_connection.get_migration_url() // 追加

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# alembic.iniの設定を上書き
config.set_main_option('sqlalchemy.url', ASYNC_DB_URL) // 追加

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = SQLModel.metadata // 変更

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

モデル作成

/src/model/user

SQLModelを継承します。かつtable=Trueにしてテーブル対象のモデルにします。(これ忘れるとマイグレーションに含まれないのでお気をつけください。)

from datetime import datetime, timezone
from typing import Optional
from uuid import uuid4
from sqlalchemy import TIMESTAMP, Column, Integer, String
from sqlmodel import Field, SQLModel

class Users(SQLModel, table=True):
    id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, comment="ID"))
    uuid: str = Field(
        default_factory=lambda: str(uuid4()), sa_column=Column(String(36), nullable=False, unique=True, comment="UUID")
    )
    account_name: str = Field(sa_column=Column(String(100), nullable=False, comment="アカウント名"))
    email: str = Field(sa_column=Column(String(100), nullable=False, unique=True, comment="メールアドレス"))
    hashed_password: str = Field(sa_column=Column(String(100), nullable=False, comment="パスワード"))
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column=Column(
            TIMESTAMP(True),
            nullable=True,
            default=datetime.now(timezone.utc)
        )
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        sa_column=Column(
            TIMESTAMP(True),
            nullable=True,
            onupdate=datetime.now(timezone.utc)
        )
    )

makeファイルにもコマンドを追加しておきます。

src/makefile.container

(docker compose exec -it fastapi alembic ~でmakefile.localに追加してもいいかもですね)

.PHONY: help print
.DEFAULT_GOAL := help

print: ## 分岐テスト用
	echo "コンテナ"

migration: ## Run alembic migration
	alembic upgrade head

migration-rollback: ## Run alembic migration
	alembic downgrade -1

migration-refresh: ## Run alembic migration
	alembic downgrade base
	alembic upgrade head

create-migration: ## Create alembic migration
ifndef name
	$(error name is not set)
endif
	alembic revision --autogenerate -m "$(name)"

ここまでできたらコマンドを実行します。

make create-migration name=create_user
alembic revision --autogenerate -m "create_user"
INFO  [alembic.runtime.migration] Context impl MySQLImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.autogenerate.compare] Detected added table 'users'
  Generating /src/src/database/versions/d8181ecb6ec8_create_user.py ...  done

src/database/versionにマイグレーションファイルが生成されれば成功です!

"""create_user

Revision ID: d8181ecb6ec8
Revises: 
Create Date: 2024-05-16 15:27:32.123787

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = 'd8181ecb6ec8'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
    sa.Column('id', sa.Integer(), nullable=False, comment='ID'),
    sa.Column('uuid', sa.String(length=36), nullable=False, comment='UUID'),
    sa.Column('account_name', sa.String(length=100), nullable=False, comment='アカウント名'),
    sa.Column('email', sa.String(length=100), nullable=False, comment='メールアドレス'),
    sa.Column('hashed_password', sa.String(length=100), nullable=False, comment='パスワード'),
    sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=True),
    sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=True),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('email'),
    sa.UniqueConstraint('uuid')
    )
    # ### end Alembic commands ###

def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('users')
    # ### end Alembic commands ###

ではマイグレーションしてみましょう。

make migration
alembic upgrade head
INFO  [alembic.runtime.migration] Context impl MySQLImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade  -> d8181ecb6ec8, create_user

テーブルができているはずです。

ロールバックも行ってみましょう

make migration-rollback
alembic downgrade -1
INFO  [alembic.runtime.migration] Context impl MySQLImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Running downgrade d8181ecb6ec8 -> , create_user

テーブルが消えているはずです!

お疲れ様でした!

とりあえず、ここまででマイグレーションは終わりたいと思います。

次回はログあたりをやろうと考えています。

Discussion