🦁

Python で標準入出力のテストを書く方法(unittest, pytest)

2022/06/12に公開

Pythonで標準入出力のテストを書く機会があったのでメモ。

テスト対象のコード

こちらのコードを対象にテストを書くことにします。
標準入力から受け取った2つの数字を足した答えを標準出力に表示する関数です。

main.py
def print_sum():
    x = input()
    y = input()
    print(int(x) + int(y))

unittestでのテスト方法

unittestでの標準入出力のテストはこちらです。

test_main.py
import sys
import io
import unittest


def stub_stdin(testcase_inst, inputs):
    stdin = sys.stdin

    def cleanup():
        sys.stdin = stdin

    testcase_inst.addCleanup(cleanup)
    sys.stdin = io.StringIO(inputs)


def stub_stdouts(testcase_inst):
    stdout = sys.stdout

    def cleanup():
        sys.stdout = stdout

    testcase_inst.addCleanup(cleanup)
    sys.stdout = io.StringIO()


class MainTestCase(unittest.TestCase):
    def test_main(self):
        # 標準入力をモック
        stub_stdin(self, "1\n4\n")
        # 標準出力をモック
        stub_stdouts(self)

        print_sum()

        self.assertEqual(sys.stdout.getvalue(), "5\n")

stub_stdinとstub_stdoutsで、sys.stdinsys.stdoutをStringIOと入れ替えています。
そして、addCleanupのメソッドを使い、テスト終了時には元のsys.stdinsys.stdoutに戻るように設定しています。

Pytestでのテスト方法

Pytestではunittestに比べて大分シンプルに標準入力・標準出力のテストを書くことができます。

test_main.py
def test_main(capsys, monkeypatch):
    # 標準入力をモック
    monkeypatch.setattr('sys.stdin', io.StringIO("1\n4\n"))

    print_sum()

    # 標準出力のキャプチャを取得
    captured = capsys.readout()

    assert captured.out == "5\n"

monkeypatchでsys.stdinをmockして、StringIOを返すようにしています。そしてcapsysで標準入力と標準出力のキャプチャを取得しています。
あとは、キャプチャ結果でアサーションするだけでOKです!
分かりやすいです!

参考

https://stackoverflow.com/questions/38861101/how-can-i-test-the-standard-input-and-standard-output-in-python-script-with-a-un

https://docs.pytest.org/en/7.1.x/how-to/capture-stdout-stderr.html

Discussion