🍮

unittestで特定の関数呼び出しをテスト全体でアサーションする

2023/10/03に公開

小ネタです🍣

Pythonでプログラミングをするにあたって print はとても便利です
ビルトイン関数のためimportは不要ですし、実装やテストの最中にちょっとだけ現在の状態を確認したいときに、誰しもお世話になったことがあるのでないかと思います

https://docs.python.org/ja/3.11/library/functions.html#print

しかしながら、Pythonにはロギングの仕組みとしてlogging モジュールが用意されており、プロダクションコードにおいてはこちらを利用すべきです

https://docs.python.org/ja/3/howto/logging.html

print 関数にはログレベルのような機構が存在しないため、うっかり本番コードに書き残してしまうと、プログラムの実行コンテナの標準出力に予期せぬ出力をおこなうこととなり、監視やトラブルシュートのノイズになってしまいます

テストにおいても、以下のように実行中の内容に挟まってデバッグ文字列が出てしまっており邪魔…というのは、一定規模の開発ではあるあるなのでないかと思います

$ python -m unittest assert_print.py
.hello, logger wip
.hello, print
.
----------------------------------------------------------------------
Ran 3 tests in 0.000s

OK

こうした状況を予防するための方法として、flake8-print のような追加のチェックルールを導入したり、 reviewdogのようなツールをCIに導入するのもよいですが、
テストの実行中に特定の関数(今回はprint)が呼び出された時にエラーとする ということができると、柔軟にハンドリングができそうです

https://github.com/reviewdog/reviewdog

ということで作ってみました🍨

作ったもの

以下に、今回実装した CustomTestCase と、それを利用するサンプルコードを示します

assert_print.py
import logging
from unittest import TestCase
from unittest.mock import patch

logger = logging.getLogger(__name__)


# テスト共通基底クラス
class CustomTestCase(TestCase):
    # テスト関数内でprintが呼ばれた時にエラーとするかどうか
    # 特定のテストでエラーを抑止したい場合は、このフラグをFalseにする
    assert_builtins_print = True

    def setUp(self) -> None:
        super().setUp()

        def _create_assert_patch(target: str):
            def _assert(*args, **kwargs):
                raise AssertionError(f'{target} is not allowed')

            # patchをnewで生成すると、テスト関数にmockオブジェクトが渡されない
            return patch(target, new=_assert)

        if self.assert_builtins_print:
            # builtins.printが呼び出された時にエラーとする
            self._patcher_builtins_print = _create_assert_patch('builtins.print')
            self._patcher_builtins_print.start()

    def tearDown(self) -> None:
        super().tearDown()

        if self.assert_builtins_print:
            self._patcher_builtins_print.stop()


# テスト対象の関数
def greet_with_logger(name: str) -> str:
    """ロガー経由で挨拶する関数"""
    message = f'hello, {name}'
    logger.info(message)
    return message


def greet_with_print(name: str) -> str:
    """print経由で挨拶する関数"""
    message = f'hello, {name}'
    print(message)
    return message


# テストクラス
class GreetingTestCase(CustomTestCase):
    def test_greet_with_logger(self) -> None:
        """ロガー経由で挨拶する関数のテスト"""
        # printを使っていなければエラーとならない
        self.assertEqual(greet_with_logger('logger'), 'hello, logger')

    def test_greet_with_logger__debug(self) -> None:
        """ロガー経由で挨拶する関数のテスト(デバッグ)"""
        message = greet_with_logger('logger wip')
        # テスト関数側でprintを呼び出すとAssertionErrorが発生する
        print(message)
        self.assertEqual(message, 'hello, logger wip')

    def test_greet_with_print(self) -> None:
        """print経由で挨拶する関数のテスト"""
        # 関数内部でprintが呼ばれるとAssertionErrorが発生する
        self.assertEqual(greet_with_print('print'), 'hello, print')

assert_builtins_printTrue の状態で上記テストを実行すると、 print が呼び出されたタイミングで AssertionError が発生し、テスト実行が中断されます

$ python -m unittest assert_print.py
.FF
======================================================================
FAIL: test_greet_with_logger__debug (assert_print.GreetingTestCase)
ロガー経由で挨拶する関数のテスト(デバッグ)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/tkhs/assert_print.py", line 62, in test_greet_with_logger__debug
    print(message)
  File "/home/tkhs/assert_print.py", line 19, in _assert
    raise AssertionError(f'{target} is not allowed')
AssertionError: builtins.print is not allowed

======================================================================
FAIL: test_greet_with_print (assert_print.GreetingTestCase)
print経由で挨拶する関数のテスト
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/tkhs/assert_print.py", line 68, in test_greet_with_print
    self.assertEqual(greet_with_print('print'), 'hello, print')
  File "/home/tkhs/assert_print.py", line 47, in greet_with_print
    print(message)
  File "/home/tkhs/assert_print.py", line 19, in _assert
    raise AssertionError(f'{target} is not allowed')
AssertionError: builtins.print is not allowed

----------------------------------------------------------------------
Ran 3 tests in 0.001s

FAILED (failures=2)

関数ごとにunittest.mock.patch を利用するのと比較して、テストの共通クラス内で一律 print がモックされるため、関数ごとにデコレータを書かなくてよくなります

また、モック時に new=_assert のような形式でモックオブジェクトを生成することで、テスト関数の引数にモックオブジェクトが渡されなくなるため、既存のテストを大きく壊さずに導入することができるものと思います

そんだけ😌

Discussion