🌊

pytest+sqlalchemy+mysqlでループ内で失敗させてロールバックのテストをする

2023/07/17に公開

やりたいこと

ループ途中でDB処理が失敗したときのロールバック挙動をテストする

準備

(my環境)
Python 3.9.13
sqlalchemy 2.0.17
mysql 8.0.33

MySQLのインストール、準備

インストール

sudo apt-get install mysql-server
sudo mysql_secure_installation

ログイン

sudo mysql -u root -p

dbnameというデータベース、 username, passwordでユーザを作成する。

CREATE DATABASE dbname;

CREATE USER 'username'@'localhost' IDENTIFIED BY 'password';
GRANT ALL PRIVILEGES ON dbname.* TO 'username'@'localhost';
FLUSH PRIVILEGES;

userテーブルを作る。

USE dbname;
CREATE TABLE user (
      id INT,
      name VARCHAR(100),
      PRIMARY KEY (id)
);

exit で終了

Pythonモジュールをインストール

pip install sqlalchemy
pip install pymysql
pip install cryptography

(cryptographyはMYSQLの認証で使うため)

コード

rollback.py
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.orm import sessionmaker


class MyClass:
    def print_user_select(message, user, session):
        result = session.execute(user.select())
        print(message, [row[0] for row in result])

    @classmethod
    def insert_id(cls, i, user, session):
        """Insert id into user table and print current ids."""
        print(f"insert_id() i={i}")
        insert_stmt = user.insert().values(id=i, name=f"user{i}")
        session.execute(insert_stmt)
        cls.print_user_select("After insert", user, session)

    @classmethod
    def trigger_fail(cls, i):
        """Dummy for trigger fail"""
        print(f"trigger_fail() i={i}")

    @classmethod
    def main(cls):
        """Insert three ids into user table, rolling back on exception."""
        engine = create_engine("mysql+pymysql://username:password@localhost/dbname")
        Session = sessionmaker(bind=engine)
        session = Session()

        metadata = MetaData()
        user = Table("user", metadata, autoload_with=engine)

        for i in range(3):
            print(f"--- i={i} try.")
            try:
                cls.insert_id(i, user, session)
                cls.trigger_fail(i)
                session.commit()
            except Exception as e:
                print("Exception message", e)
                cls.print_user_select("Before rollback", user, session)
                session.rollback()
                print("Rollback!")
                cls.print_user_select("After rollback", user, session)


if __name__ == "__main__":
    MyClass.main()
test_rollback.py
from unittest.mock import patch

import pytest
from rollback import MyClass
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("mysql+pymysql://username:password@localhost/dbname")
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)

metadata = MetaData()
user = Table("user", metadata, autoload_with=engine)


@pytest.fixture(scope="session", autouse=True)
def setup_db():
    """Set up database by deleting all records from the 'user' table."""
    session = Session()
    session.execute(user.delete())
    session.commit()
    Session.remove()


# Save original function
original_trigger_fail = MyClass.trigger_fail


def mock_for_fail(i):
    """Mock throwing an exception when i == 2."""
    print(f"mock_insert_id called with i={i}")
    if i == 2:
        raise Exception("failed")
    original_trigger_fail(i)


def test_insert():
    """Test that 3rd insertion fails and gets rolled back."""
    with patch.object(MyClass, "trigger_fail", new=mock_for_fail):
        session = Session()
        myclass = MyClass()
        myclass.main()

        result = session.execute(user.select())
        users = result.fetchall()
        print("End of test", [user[0] for user in users])
        assert len(users) == 2

テストを実行する

やっていること

  • trigger_fail()ユニットテストで失敗させるためだけに作っている関数なので、実際はinsert_id()を使ってもよいかもしれません。(ただし、そのときはロールバック前はinsertせずに落ちるのでidが増えているのは見えません)
  • with patch.object(MyClass, "trigger_fail", new=mock_for_fail)でtrigger_failをMockに差し替えます。
  • mock_for_fail(i)で3回目(i=2)でraise Exceptionします。それ以外はオリジナルが使われます。

実行

pytest -s test_rollback.py

結果
ロールバック前後でIDが2つになっていて、最終的にも2つとなります。

--- i=0 try.
insert_id() i=0
After insert [0]
mock_insert_id called with i=0
trigger_fail() i=0
--- i=1 try.
insert_id() i=1
After insert [0, 1]
mock_insert_id called with i=1
trigger_fail() i=1
--- i=2 try.
insert_id() i=2
After insert [0, 1, 2]
mock_insert_id called with i=2
Exception message failed
Before rollback [0, 1, 2]
Rollback!
After rollback [0, 1]
test_insert: Current IDs in table:  [0, 1]

まとめ

  • pytestで、ループ途中でDBロールバックするときの挙動をテストしました
    • いまいちこれがベストなのか分からないですが、所望の動作にはなっているのでOK

Discussion