🔖

HypothesisでpandasのDataFrameを複雑に扱う

2024/01/08に公開

概要

プロパティベーステストというものがあります。簡単に言えば仕様に定義された入力範囲に従ってランダムに値を大量に生成してその全件で狙った性質(Property)を満たすかテストするというものです。PythonにおいてはHypothesisライブラリで行なえます。HypythesisはpandasのDataFrame型も扱え、当記事ではあるケースを想定してその使用方法をまとめます。
基本的なHypothesisの使い方は他の記事を参考にされて下さい。

環境

Ubuntu 20.04
Python 3.12.0
Hypothesis及びDataFrameを扱うため以下が必要です。

pip install hypothesis
pip install hypothesis[numpy,pandas]

他ライブラリは足りなければ適宜、、、

想定ケース

ある大本のデータテーブル(DataFrame)があり、それをコンフィグ情報(DataFrame)に従ってフィルタを掛ける関数filtをテストしたいとします。

大本のデータテーブルは例えば以下のようになっています。

date,type
2023-01-01,A
2023-06-01,A
2024-12-31,A
2020-01-01,B
2020-06-01,B
2020-12-31,B
2020-01-01,C
2020-06-01,C
2020-12-31,C
2024-12-31,C

コンフィグ情報の例は以下です。

type,start,end
A,2023-01-01,2023-12-31
B,2023-01-01,2023-06-30
C,2020-01-01,2022-12-31

今回テスト対象としているfiltは以下です。
コンフィグで指定している各typeのstartとendの範囲で大本データの行を切り出すというものです。

import pandas as pd
import numpy as np

def filt(df_in: pd.DataFrame, df_conf: pd.DataFrame) -> pd.DataFrame:
    def get_in_range(row_conf: pd.DataFrame) -> pd.DataFrame:
        ty = row_conf["type"]
        start = np.datetime64(row_conf["start"])
        end = np.datetime64(row_conf["end"])
        df_out_tmp = df_in[
            (df_in["type"] == ty) &
            (start <= df_in["date"]) & (df_in["date"] <= end)
        ]
        return df_out_tmp
    df_l = df_conf.apply(get_in_range, axis=1)
    df_out = pd.concat(df_l.to_list(), axis=0).reset_index(drop=True)
    # Intended bug !
    df_out.loc[(df_out["date"] > np.datetime64("2125-01-01")), "type"] = "Z"
    return df_out

# Intended bug!下1行はテストの効果実証のためにわざと混入させているバグで、2125年以降の行はすべてのtypeがZに書き換わってしまうというものです。

先述の大本データとコンフィグでfiltを実行すると以下のようになります。
上記のバグはこのケースでは発生しないので、これでテストを終えると潜在バグとして残ってしまいます。

date,type
2023-01-01,A
2023-06-01,A
2020-01-01,C
2020-06-01,C
2020-12-31,C

コンフィグ情報のstrategy

typeはA~Zで重複の無い指定、
startは1900~3000年の範囲で、endはstartから10年後の範囲で指定しています。
compositeが何かはこの記事がシンプルな例を示しています。

import string
from hypothesis import strategies as st
from hypothesis import given, settings
import hypothesis.extra.pandas as hypd
import numpy as np
import pandas as pd
import datetime
from func import filt

date_min = datetime.date(1900, 1, 1)
date_max = datetime.date(3000, 1, 1)

@st.composite
def gen_config_st(draw):
    row_num_src = draw(st.integers(min_value=1, max_value=20))

    ty_src = draw(
        st.lists(
            st.sampled_from(list(string.ascii_uppercase),),
            min_size=row_num_src, max_size=row_num_src
        )
    )
    ty = list(set(ty_src))
    row_num = len(ty)
    start = draw(
        st.lists(
            st.dates(
                min_value=date_min,
                max_value=date_max
            ),
            min_size=row_num, max_size=row_num
        )
    )
    end = [draw(
        st.dates(
            min_value=s_date,
            max_value=s_date+datetime.timedelta(days=365*10))
        ) for s_date in start]

    return pd.DataFrame({
        "type": ty,
        "start": start,
        "end": end,
    })

ここでstrategy設計方針を語る

想定ケースの性質からコンフィグデータを先に定義しました。私のソースコード上でも同様であり大本データのstrategyはconfigのstrategyを参照します。これはfiltの動作を保証するための論理構造をシンプルにするためです。
filtに求められる性質は”狙ったデータを取れる”ことと”狙っていないデータは取らない”ことです。テストのために大本データを生成するには、コンフィグで何を狙っているかを知り、それに従って狙ったデータと狙っていないデータを生成してそれをマージして大本データとします。filtの処理後のデータが、狙ったデータすべてを含み、なおかつ狙っていないデータを1つも含んでいなければfiltの動作が保証されたと言えるわけです。
この考えは少々トリッキーでした。未来の視点を過去の視点に変換するようで、、、

ちなみに大本データのstrategyとコンフィグのstrategyを個別に依存しないように書いて、それらから狙ったデータと狙っていないデータを生成してテストと言う方針も考えられますが、それではテストコード内で殆どfiltと同じ処理を書くこととなり、テストの意味がありません。

大本データのstrategy

上記方針の元、以下のように書きました。
まず先述のコンフィグのstrategyでコンフィグのDataFrameを取得し、その各行で狙ったデータ[startとendの間]と狙っていないデータ[startより前+endより後]を生成し、最後に(コンフィグ、狙ったデータ、狙っていないデータ)を出力しています。

@st.composite
def gen_data_st(draw):
    df_conf = draw(gen_config_st())

    df_bef_list = []
    df_inside_list = []
    df_aft_list = []

    for index, row in df_conf.iterrows():
        # Draw the DataFrames for each row
        df_bef_start = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    max_value=row["start"] - datetime.timedelta(days=1))),
            hypd.column("type", elements=st.just(row["type"]))
        ]))
        df_inside = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    min_value=row["start"],
                    max_value=row["end"])),
            hypd.column("type", elements=st.just(row["type"]))
        ]))
        df_aft_end = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    min_value=row["end"] + datetime.timedelta(days=1))),
            hypd.column("type", elements=st.just(row["type"]))
        ]))

        # Append the drawn DataFrames to the lists
        df_bef_list.append(df_bef_start)
        df_inside_list.append(df_inside)
        df_aft_list.append(df_aft_end)

    # Concatenate the lists of DataFrames
    df_outside = pd.concat(df_bef_list + df_aft_list, axis=0)
    df_inside = pd.concat(df_inside_list, axis=0)

    return df_conf, df_inside, df_outside

テストコード

先述の通り、”狙ったデータを取れる”ことと”狙っていないデータは取らない”ことを確認しています。

def all_rows_isin(df1, df2):
    merged = pd.merge(df1, df2, how='left', indicator=True)
    all_rows_present = (merged['_merge'] == 'both').all()
    return all_rows_present

def all_rows_NOTin(df1, df2):
    merged = pd.merge(df1, df2, how='left', indicator=True)
    all_rows_NOT_present = ~((merged['_merge'] == 'both').any())
    return all_rows_NOT_present
    
@given(gen_data_st())
def test_data(tup_df):
    df_conf, df_inside, df_outside = tup_df
    df_in = pd.concat([df_inside, df_outside], axis=0)
    df_out = filt(df_in, df_conf)
    assert all_rows_isin(df_inside, df_out)
    assert all_rows_NOTin(df_outside, df_out)

pytest実行

vscodeで実行しました。普通のpytestでも動くと思います。実行には数分要しました。
filtの以下の有無でテスト成功/失敗が切り替わることが解ると思います。

    # Intended bug !
    df_out.loc[(df_out["date"] > np.datetime64("2125-01-01")), "type"] = "Z"

成功時

========================= 1 passed, 1 warning in 9.06s =========================
Finished running tests!

失敗時

E       assert False
E        +  where False = all_rows_isin(         date type\n0  2126-01-01    B,          date type\n0  2126-01-01    Z)
E       Falsifying example: test_data(
E           tup_df=(
E                  type       start         end
E                0    A  2000-01-01  2000-01-01
E                1    B  2126-01-01  2126-01-01
E            , 
E                         date type
E                0  2126-01-01    B
E            , 
E                Empty DataFrame
E                Columns: [date, type]
E                Index: []
E            ),
E       )
E       Explanation:
E           These lines were always and only run by failing examples:
E               config.py:911
E               common.py:312
E               common.py:313
E               common.py:314
E               construction.py:779
E               (and 104 more with settings.verbosity >= verbose)

tests/test_func.py:117: AssertionError

ソースコード

func.py
import pandas as pd
import numpy as np


def filt(df_in: pd.DataFrame, df_conf: pd.DataFrame) -> pd.DataFrame:
    def get_in_range(row_conf: pd.DataFrame) -> pd.DataFrame:
        ty = row_conf["type"]
        start = np.datetime64(row_conf["start"])
        end = np.datetime64(row_conf["end"])
        df_out_tmp = df_in[
            (df_in["type"] == ty) &
            (start <= df_in["date"]) & (df_in["date"] <= end)
        ]
        return df_out_tmp
    df_l = df_conf.apply(get_in_range, axis=1)
    df_out = pd.concat(df_l.to_list(), axis=0).reset_index(drop=True)
    # Intended bug !
    # df_out.loc[(df_out["date"] > np.datetime64("2125-01-01")), "type"] = "Z"
    return df_out


if __name__ == "__main__":
    df_in = pd.read_csv("data.csv", parse_dates=["date"])
    df_conf = pd.read_csv("config.csv", parse_dates=["start", "end"])
    df_out = filt(df_in, df_conf)
    print(df_out)
    df_out.to_csv("df_out.csv", index=False)
test_func.py
import string
from hypothesis import strategies as st
from hypothesis import given, settings
import hypothesis.extra.pandas as hypd
import numpy as np
import pandas as pd
import datetime
from func import filt

date_min = datetime.date(1900, 1, 1)
date_max = datetime.date(3000, 1, 1)

@st.composite
def gen_config_st(draw):
    row_num_src = draw(st.integers(min_value=1, max_value=20))

    ty_src = draw(
        st.lists(
            st.sampled_from(list(string.ascii_uppercase),),
            min_size=row_num_src, max_size=row_num_src
        )
    )
    ty = list(set(ty_src))
    row_num = len(ty)
    start = draw(
        st.lists(
            st.dates(
                min_value=date_min,
                max_value=date_max
            ),
            min_size=row_num, max_size=row_num
        )
    )
    end = [draw(
        st.dates(
            min_value=s_date,
            max_value=s_date+datetime.timedelta(days=365*10))
        ) for s_date in start]

    return pd.DataFrame({
        "type": ty,
        "start": start,
        "end": end,
    })


@st.composite
def gen_data_st(draw):
    df_conf = draw(gen_config_st())

    df_bef_list = []
    df_inside_list = []
    df_aft_list = []

    for index, row in df_conf.iterrows():
        # Draw the DataFrames for each row
        df_bef_start = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    max_value=row["start"] - datetime.timedelta(days=1))),
            hypd.column("type", elements=st.just(row["type"]))
        ]))
        df_inside = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    min_value=row["start"],
                    max_value=row["end"])),
            hypd.column("type", elements=st.just(row["type"]))
        ]))
        df_aft_end = draw(hypd.data_frames([
            hypd.column(
                "date",
                elements=st.dates(
                    min_value=row["end"] + datetime.timedelta(days=1))),
            hypd.column("type", elements=st.just(row["type"]))
        ]))

        # Append the drawn DataFrames to the lists
        df_bef_list.append(df_bef_start)
        df_inside_list.append(df_inside)
        df_aft_list.append(df_aft_end)

    # Concatenate the lists of DataFrames
    df_outside = pd.concat(df_bef_list + df_aft_list, axis=0)
    df_inside = pd.concat(df_inside_list, axis=0)

    return df_conf, df_inside, df_outside


def all_rows_isin(df1, df2):
    merged = pd.merge(df1, df2, how='left', indicator=True)
    all_rows_present = (merged['_merge'] == 'both').all()
    return all_rows_present


def all_rows_NOTin(df1, df2):
    merged = pd.merge(df1, df2, how='left', indicator=True)
    all_rows_NOT_present = ~((merged['_merge'] == 'both').any())
    return all_rows_NOT_present


@given(gen_config_st())
def test_config(df_conf):
    print(df_conf.shape)
    print(df_conf)
    assert (df_conf['start'] <= df_conf['end']).all()


# @settings(max_examples=500)
@given(gen_data_st())
def test_data(tup_df):
    df_conf, df_inside, df_outside = tup_df
    df_in = pd.concat([df_inside, df_outside], axis=0)
    df_out = filt(df_in, df_conf)
    assert all_rows_isin(df_inside, df_out)
    assert all_rows_NOTin(df_outside, df_out)


if __name__ == "__main__":
    # example_df = gen_config_st().example()
    # print(example_df)
    df_conf, df_inside, df_outside = gen_data_st().example()
    df_conf.to_csv("df_conf.csv", index=False)
    df_inside.to_csv("df_inside.csv", index=False)
    df_outside.to_csv("df_outside.csv", index=False)

Discussion