😸

SQLAlchemy で MySQL の UPSERT を実装する

2023/11/16に公開

概要

Rails だと以前から activerecord-import っていう gem で Upsert が実現できたり、6.x や 7.x からは ActiveRecord が upsert_all を実装している。

Python が採用されているあるプロジェクトでは ORM に SQLAlchemy を採用しているが、これが Rails の ActiveRecord のように抽象化された UPSERT は行えず、各RDB毎に用意されたインターフェースを用いて適当に実装する必要があった。

ここでは MySQL を利用した実装例を記載する。

動作確認環境

ライブラリ バージョン
sqlalchemy 2.0.23
pymysql 1.1.0
ミドルウェア バージョン
MySQL 8.0.23

実装例

appname/orm/database.py のように名前空間を切って以下のDBとのインターフェースの初期化などを行うモジュールを配置する。

import os

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

engine = create_engine(
    os.environ.get(
        "DATABASE_URL", "mysql+pymysql://root:password@127.0.0.1:33306/appname?charset=utf8mb4"
    )
)

Base = declarative_base()

session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)

appname/orm/strategies/upsert.py では実際の UPSERT を実装していく。

import os
from typing import Generic, TypeVar

import sqlalchemy.dialects.mysql as mysql

from appname.orm.database import Base, Session

T = TypeVar("T", bound=Base)


def create_upsert_strategy():
    if os.environ.get("DB", "mysql") == "mysql":
        return _mysql_upsert_strategy
    else:
        raise NotImplementedError


def _mysql_upsert_strategy(
    session: Session, model: Generic[T], values: list[dict], on_duplicate_key_update: list[str]
) -> bool:
    upsert_statement = mysql.insert(model).values(values)

    update_dict = {field: upsert_statement.inserted[field] for field in on_duplicate_key_update}

    on_duplicate_key_statement = upsert_statement.on_duplicate_key_update(**update_dict)

    session.execute(on_duplicate_key_statement)
    session.commit()

    return True

※ エラーハンドリングは呼び出し元で実装することを想定

補足

  • ストラテジーパターンを使用し、MySQL 以外のリレーショナルデータベース(例えば PostgreSQL)にも対応できる柔軟性を持たせている
  • 開放閉鎖原則に基づき、新しいデータベースに対応する際はストラテジを追加するだけで済み、拡張性を高めている
  • 型制約の使用により Base を継承したクラスを処理対象とし、安全性を高めている

まとめ

まとめというよりは所感になるが、各RDB毎の実装自体は SQLAlchemy にあるので対応はしやすかった。ほとんど実装コストもかからず実現できた。

追記

PostgreSQL のストラテジを追加してみた。

import os
from typing import Any, Callable, Type

import sqlalchemy.dialects.mysql as mysql
import sqlalchemy.dialects.postgresql as postgresql
from sqlalchemy import inspect
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session as SQLAlchemySession


def create_upsert_strategy() -> (
    Callable[
        [
            SQLAlchemySession,
            Type[DeclarativeMeta],
            list[dict[str, Any]],
            list[str],
        ],
        bool,
    ]
):
    if os.environ.get("DB", "postgresql") == "mysql":
        return _mysql_upsert_strategy
    else:
        return _postgresql_upsert_strategy


def _mysql_upsert_strategy(
    session: SQLAlchemySession,
    model: Type[DeclarativeMeta],
    values: list[dict[str, Any]],
    on_duplicate_key_update: list[str],
) -> bool:
    try:
        upsert_statement = mysql.insert(model).values(values)
        update_dict = {field: upsert_statement.inserted[field] for field in on_duplicate_key_update}
        on_duplicate_key_statement = upsert_statement.on_duplicate_key_update(**update_dict)
        session.execute(on_duplicate_key_statement)
        session.commit()

        return True
    except Exception as e:
        session.rollback()
        raise e


def _postgresql_upsert_strategy(
    session: SQLAlchemySession,
    model: Type[DeclarativeMeta],
    values: list[dict[str, Any]],
    on_duplicate_key_update: list[str],
) -> bool:
    try:
        index_elements = _get_primary_keys(model)
        for value in values:
            upsert_statement = postgresql.insert(model).values(value)
            update_dict = {
                field: upsert_statement.excluded[field] for field in on_duplicate_key_update
            }
            on_conflict_statement = upsert_statement.on_conflict_do_update(
                index_elements=index_elements, set_=update_dict
            )
            session.execute(on_conflict_statement)
        session.commit()

        return True
    except Exception as e:
        session.rollback()
        raise e


def _get_primary_keys(model: Type[DeclarativeMeta]) -> list[str]:
    """
    指定された SQLAlchemy モデルから主キーのカラム名のリストを返す。
    :param model: SQLAlchemy モデルクラス
    :return: 主キーのカラム名のリスト
    """
    mapper = inspect(model)
    if mapper is None:
        raise ValueError("Model has no mapper or is not a valid SQLAlchemy model.")
    return [key.name for key in mapper.primary_key]

※ このケースではエラーハンドリングも関数内で対応

  • PostgreSQL のストラテジで使ってる on_conflict_do_update メソッドでは、 MySQL のストラテジで使っている on_duplicate_key_update メソッドと違って自動的にユニークキーまたはプライマリーキーを基準にして衝突を回避する機能が提供ない
  • 前項の理由から index_elements パラメータを明示的に指定する必要がある
  • とはいえ、インターフェースを統一したいので inspect で SQLAlchemy のモデルから動的に主キーを取得して index_elements として利用する値を導出

参考

https://github.com/sqlalchemy/sqlalchemy/discussions/9328

Discussion