💬

機械学習モデルの予測結果の違いを検定

2023/09/04に公開

前書き

違いの分かる人になりたいAIエンジニアです。

機械学習モデル(AIモデル)の性能を向上させるために、日々様々な試行錯誤が行われると思います。
改善策を考えて、新しく試した方法の結果が良くなったように見えたとしても、それは偶然起こる変化の範囲内なのか、偶然の結果とは言えない差があるのかを確かめる必要があると思います。

今回は機械学習モデルの予測結果を出した時に差があると言えるのか、言えないのかを検定するためのコードをまとめていきます。

やったこと

モデルの予測結果がデータフレームにまとめられていることを仮定し、検定を行う列名をまとめたリストと共に関数を実行する仕様にしています。

回帰問題でRMSE(Root Mean Squared Error)やMAPE(Mean Absolute Percentage Error)などの連続値の予測結果の比較を行うために作成しましたが、分類問題でも誤差などの連続値を使用することで利用できると思います。

同じテストデータでモデル性能を評価することを前提としているため、結果のデータは"対応のあるデータ"であることを仮定しています。

対応のあるデータの場合、正規性があるか否かで検定の方法が変わります。
正規性がある場合は対応のあるt検定、
正規性がない場合はウィルコクソンの符号順位検定
を行う必要があります。
予測結果に正規性があるか否かをシャピロウィルク検定で検定できるようにしています。

検定に関して参考にしたページ
https://datadriven-rnd.com/2021-01-24-154022/

今後も私が検定のために便利だと思った機能があれば順次追加していきます。

必要なライブラリ

import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

検定を行うためのクラス

# 外れ値検出のクラス
class ResultsTest:
    def __init__(self):
        return

    # 正規性を検定する
    # pvalueが0.05より小さいの場合、正規分布である確率が低すぎるので、正規分布とは違う分布だと結論づける
    def shapiro_wilk(self, df, columns, significance_level = 0.05):
        for column in columns:
            # 2つ目の戻り値がp値
            pvalue = stats.shapiro(df[column])[1]
            print("---------------------------------------------------------------------")
            print(column + "のp値: " + str(pvalue))
            if pvalue < significance_level:
                print("p valueが" + str(significance_level) + "より小さいため、正規分布と有意に差がある(正規分布に従っていないとみなせる)")
                print("ウィルコクソンの符号順位検定で結果の違いを検定すべき")
            else:
                print("p valueが" + str(significance_level) + "より大きいため、正規分布と有意に差があるとは言えない")
                print("対応のあるt検定で結果の違いを検定すべき")
        return

    # Q-Qプロット
    def q_q_plot(self, df, columns):
        for column in columns:
            print("---------------------------------------------------------------------")
            print(column + "のQ-Qプロット")
            fig = plt.figure(figsize=(7, 5)) # Figureオブジェクトを作成
            ax = stats.probplot(df[column], plot=plt)
            plt.show()
            plt.close()
        return

    # 結果に違いがあるかを検定する
    # 対応のある正規性のないデータに対して行う
    # pvalueが0.05より小さい場合、同じである確率が低すぎるので有意に差があると判断する
    def wilcoxon_signed_rank(self, df, columns, significance_level = 0.05):
        for index, column in enumerate(columns):
            if index == len(print_columns) - 1:
                break
            for comp_index in range(index+1, len(columns)):
                comparison_column = print_columns[comp_index]
                # 2つ目の戻り値がp値
                pvalue = stats.wilcoxon(df[column], df[comparison_column], alternative='two-sided')[1]
                print("---------------------------------------------------------------------")
                print(column + "と" + comparison_column + "のp値: " + str(pvalue))
                if pvalue < significance_level:
                    print("p valueが" + str(significance_level) + "より小さいため、結果が有意に差がある")
                else:
                    print("p valueが" + str(significance_level) + "より大きいため、結果が有意に差があるとは言えない")
        return

    # 対応のあるt検定
    # 対応のある正規性のあるデータに対して行う
    # pvalueが0.05より小さい場合、同じである確率が低すぎるので有意に差があると判断する
    def ttest_rel(self, df, columns, significance_level = 0.05):
        for index, column in enumerate(columns):
            if index == len(print_columns) - 1:
                break
            for comp_index in range(index+1, len(columns)):
                comparison_column = print_columns[comp_index]
                # 2つ目の戻り値がp値
                pvalue = stats.ttest_rel(df[column], df[comparison_column])[1]
                print("---------------------------------------------------------------------")
                print(column + "と" + comparison_column + "のp値: " + str(pvalue))
                if pvalue < significance_level:
                    print("p valueが" + str(significance_level) + "より小さいため、結果が有意に差がある")
                else:
                    print("p valueが" + str(significance_level) + "より大きいため、結果が有意に差があるとは言えない")
        return

使用例

data_pathに使用するデータのpathを代入します。
columnsのリストの要素に検定を行う列名を指定します。

指定した列の全ての組み合わせで検定した結果を出力します。

import pandas as pd

data_path = "***"
df = pd.read_csv(data_path)
results_test = ResultsTest()

columns = ["column1", "column2", ..., "column"]

# シャピロウィルク検定
results_test.shapiro_wilk(df, columns)

# Q-Qプロット
results_test.q_q_plot(df, columns)

# ウィルコクソンの符号順位検定
results_test.wilcoxon_signed_rank(df, columns)

# 対応のあるt検定
results_test.ttest_rel(df, columns)

Discussion