🍺

pytestのmark.parametrizeでテストのロジックとデータを分離しよう

2024/04/30に公開

pytest.mark.parametrizeとは

pytestのparametrizeデコレータは、pythonのpytestで、データをパラメータ化してくれるツールです。かなり使えるツールなのですが、テストのロジックとデータが一緒になりがちでチョット使いにくい感じがしています。

ネットでよく見る使い方

例えばこんな関数があったとします。

target.py
def calculate_numbers(
    num1: int,
    num2: int,
    operation: str = 'add'
) -> int:
    """二つの数値と操作を指定し、加算または乗算を行う"""
    if not isinstance(num1, int) or not isinstance(num2, int):
        raise TypeError("Both num1 and num2 must be int")
    if operation == 'add':
        return num1 + num2
    elif operation == 'multiply':
        return num1 * num2
    else:
        raise ValueError("Unsupported operation")

この関数をテストするとき、ネットでよく見る @pytest.mark.parametrizeを使ったテストの書き方って、こんな感じでしょうか。

test.py
import pytest
from target import calculate_numbers  # 上記の calculate_numbers 関数を含むモジュールをインポート

@pytest.mark.parametrize("num1, num2, operation, expected", [
    [10, 5, "add", 15],
    [3, 7, "multiply", 21],
    [-1, -1, "add", -2],
])
def test_calculate_numbers(num1, num2, operation, expected):
    """ 正常系のテスト """
    resp = calculate_numbers(num1, num2, operation=operation)
    assert resp == expected

@pytest.mark.parametrize("num1, num2, operation", [
    ["ten", 5, "add"],
    [3, "seven", "multiply"],
    [1, 2, "divide"],
])
def test_calculate_numbers_errors2(num1, num2, operation):
    """ 異常系のテスト """
    with pytest.raises((TypeError, ValueError)):
        calculate_numbers(num1, num2, operation=operation)

問題点

とてもシンプルだとは思うのですが、何か物足りない...

私の独断と偏見にもとづけば、不満点はこんな感じ

  • デコレータ内にデータを記述するので、テスト関数のロジックとデータを完全に分離できない
  • データの説明がないので、なんのテストか分かりにくい
  • 引数がどのデータに対応しているか分かりにくい
  • デコレータの第1引数とテスト関数の引数を合わせなければならない点など、微妙に修正する箇所があるので使いまわしにくい
  • 正常系と異常系も一つのテストにまとめたい
  • キーワード引数なしのテストは正常系とは別にしないといけない(正常系の3番目のデータをNoneとかにしてもエラーになります)

解決策

んなことを色々考えてゴニョゴニョ修正していくと、こんな形に落ち着きました。

def get_params_func():
    _data = [
        {
            "description": "..."        # データの説明を入れると管理しやすくなる
            "args": {                   # 引数をargsにまとめる
                "num1": ...,            # 第1引数のテスト値
                "num2": ...,            # 第2引数のテスト値
            },
            "kwargs": {                 # キーワード引数をkwargsにまとめる
                "operation": "...",     # 使わない場合は、"kwargs": {}とする
            },
            "expected": {
                "resp": ...,            # テスト対象関数の返り値の期待値
                "error": ...,           # エラーが発生する場合は、ロジックでif文を入れる
            },
        }
    ]
    return list(_data[0].keys()), [
        (
            d["description"],
            d["args"].values(),         # ここだけkeyを使用しないので注意
            d["kwargs"],
            d["expected"],
        )
        for d in _data
    ]

@pytest.mark.parametrize(*get_params_func())    # アンパックで対応できるので簡単
def test_func(description, args, kwargs, expected):
    # テスト実行のロジック...
    print(f"-- {description} --")               # 折角なので、テスト結果にも表示
    resp = func(*args, **kwargs)                # アンパックで対応できるので簡単
    assert resp == expected["resp"]

ポイント

  • データは、get_params_func関数の_dataにlist[dict]型のデータとしてまとめて見やすくした
  • get_params_func関数のOutputは、アンパック演算子(*)でpytest.mark.parametrizeの引数にそのまま渡せるようにした(#)
  • 引数をargs,kwargsでまとめて、アンパック演算子(*, **)でtest関数に簡単に引き渡せるようにした
  • expectedのデータを工夫すれば正常系と異常系のテストを一つにすることができる(サンプル参照)

補足説明

(#)については、補足説明が必要だと思いますので、ちょいと説明

@pytest.mark.parametrizeの第1引数は文字列になっていますが、これは配列でも可能です。

@pytest.mark.parametrize("description, args, kwargs, expected", [...])

このように書くこともできます。

@pytest.mark.parametrize(["description", "args", "kwargs", "expected"], [...])

なので、get_params_func関数の返り値の第1引数をlist[_data[0].keys()]とすることで、["description", "args", "kwargs", "expected"]を出力しています。

もちろん、",".join(_data[0].keys())で、"description,args,kwargs,expected"を出力しても問題ありません。

コメント

こんな感じで、テストのデータとロジックが分離できると、使い回しが楽になるので、使い勝手は良いかと思いっています。反面、データ用の関数が長くなるので見にくいと感じる人も多いかもですね。

もし気に入ったら使って下さいませ。

サンプル

ちなみに、サンプルはこんな感じになります。

def get_params_calculate_numbers():
    _data = [
        {
            "description": "正常系のテスト operation=add",
            "args": {
                "num1": 10,
                "num2": 5,
            },
            "kwargs": {
                "operation": "add",
            },
            "expected": {
                "resp": 15
            },
        },
        {
            "description": "正常系のテスト operationなし",
            "args": {
                "num1": -1,
                "num2": -1,
            },
            "kwargs": {},
            "expected": {
                "resp": -2
            },
        },
        {
            "description": "異常系のテスト TypeError",
            "args": {
                "num1": "ten",
                "num2": 5,
            },
            "kwargs": {
                "operation": "add",
            },
            "expected": {
                "errorType": TypeError,
                "errorMessage": "Both num1 and num2 must be int"
            },
        },
        {
            "description": "異常系のテスト ValueError",
            "args": {
                "num1": 1,
                "num2": 2,
            },
            "kwargs": {
                "operation": "devide",
            },
            "expected": {
                "errorType": ValueError,
                "errorMessage": "Unsupported operation"
            },
        },
    ]
    return list(_data[0].keys()), [
        (
            d["description"],
            d["args"].values(),
            d["kwargs"],
            d["expected"],
        )
        for d in _data
    ]


@pytest.mark.parametrize(*get_params_calculate_numbers())
def test_calculate_numbers_(description, args, kwargs, expected):
    """ 異常系のテスト """
    print(f"\n\n=== {sys._getframe().f_code.co_name} start ===")
    print(f"-- {description} --")

    if "errorType" in expected:
        with pytest.raises(expected["errorType"]) as e:
            calculate_numbers(*args, **kwargs)
        assert str(e.value) == expected["errorMessage"]
    else:
        resp = calculate_numbers(*args, **kwargs)
        assert resp == expected["resp"]

Discussion