🌊
pytest+sqlalchemy+mysqlでループ内で失敗させてロールバックのテストをする
やりたいこと
ループ途中でDB処理が失敗したときのロールバック挙動をテストする
準備
(my環境)
Python 3.9.13
sqlalchemy 2.0.17
mysql 8.0.33
MySQLのインストール、準備
インストール
sudo apt-get install mysql-server
sudo mysql_secure_installation
ログイン
sudo mysql -u root -p
dbname
というデータベース、 username
, password
でユーザを作成する。
CREATE DATABASE dbname;
CREATE USER 'username'@'localhost' IDENTIFIED BY 'password';
GRANT ALL PRIVILEGES ON dbname.* TO 'username'@'localhost';
FLUSH PRIVILEGES;
user
テーブルを作る。
USE dbname;
CREATE TABLE user (
id INT,
name VARCHAR(100),
PRIMARY KEY (id)
);
exit
で終了
Pythonモジュールをインストール
pip install sqlalchemy
pip install pymysql
pip install cryptography
(cryptographyはMYSQLの認証で使うため)
コード
rollback.py
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.orm import sessionmaker
class MyClass:
def print_user_select(message, user, session):
result = session.execute(user.select())
print(message, [row[0] for row in result])
@classmethod
def insert_id(cls, i, user, session):
"""Insert id into user table and print current ids."""
print(f"insert_id() i={i}")
insert_stmt = user.insert().values(id=i, name=f"user{i}")
session.execute(insert_stmt)
cls.print_user_select("After insert", user, session)
@classmethod
def trigger_fail(cls, i):
"""Dummy for trigger fail"""
print(f"trigger_fail() i={i}")
@classmethod
def main(cls):
"""Insert three ids into user table, rolling back on exception."""
engine = create_engine("mysql+pymysql://username:password@localhost/dbname")
Session = sessionmaker(bind=engine)
session = Session()
metadata = MetaData()
user = Table("user", metadata, autoload_with=engine)
for i in range(3):
print(f"--- i={i} try.")
try:
cls.insert_id(i, user, session)
cls.trigger_fail(i)
session.commit()
except Exception as e:
print("Exception message", e)
cls.print_user_select("Before rollback", user, session)
session.rollback()
print("Rollback!")
cls.print_user_select("After rollback", user, session)
if __name__ == "__main__":
MyClass.main()
test_rollback.py
from unittest.mock import patch
import pytest
from rollback import MyClass
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
engine = create_engine("mysql+pymysql://username:password@localhost/dbname")
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
metadata = MetaData()
user = Table("user", metadata, autoload_with=engine)
@pytest.fixture(scope="session", autouse=True)
def setup_db():
"""Set up database by deleting all records from the 'user' table."""
session = Session()
session.execute(user.delete())
session.commit()
Session.remove()
# Save original function
original_trigger_fail = MyClass.trigger_fail
def mock_for_fail(i):
"""Mock throwing an exception when i == 2."""
print(f"mock_insert_id called with i={i}")
if i == 2:
raise Exception("failed")
original_trigger_fail(i)
def test_insert():
"""Test that 3rd insertion fails and gets rolled back."""
with patch.object(MyClass, "trigger_fail", new=mock_for_fail):
session = Session()
myclass = MyClass()
myclass.main()
result = session.execute(user.select())
users = result.fetchall()
print("End of test", [user[0] for user in users])
assert len(users) == 2
テストを実行する
やっていること
-
trigger_fail()
はユニットテストで失敗させるためだけに作っている関数なので、実際はinsert_id()
を使ってもよいかもしれません。(ただし、そのときはロールバック前はinsertせずに落ちるのでidが増えているのは見えません) -
with patch.object(MyClass, "trigger_fail", new=mock_for_fail)
でtrigger_failをMockに差し替えます。 -
mock_for_fail(i)
で3回目(i=2)でraise Exceptionします。それ以外はオリジナルが使われます。
実行
pytest -s test_rollback.py
結果
ロールバック前後でIDが2つになっていて、最終的にも2つとなります。
--- i=0 try.
insert_id() i=0
After insert [0]
mock_insert_id called with i=0
trigger_fail() i=0
--- i=1 try.
insert_id() i=1
After insert [0, 1]
mock_insert_id called with i=1
trigger_fail() i=1
--- i=2 try.
insert_id() i=2
After insert [0, 1, 2]
mock_insert_id called with i=2
Exception message failed
Before rollback [0, 1, 2]
Rollback!
After rollback [0, 1]
test_insert: Current IDs in table: [0, 1]
まとめ
- pytestで、ループ途中でDBロールバックするときの挙動をテストしました
- いまいちこれがベストなのか分からないですが、所望の動作にはなっているのでOK
Discussion