🐈

imbalanced-learnを利用してアンダー/オーバーサンプリングを実施してみた

に公開

今回はimbalanced-learnを利用してデータセットの偏りを調整する方法を試してみました。機械学習ではデータの分布がとても重要であり、偏ったデータ分布は好ましくない場合が多いです。そのような場合にデータの偏りを補正するためのライブラリとしてimbalanced-learnがあり、今回はそれを利用してみました。

データの偏りとは?

文字通りですが、データに偏りがある状態をいいます。

例えばある病気について、診察対象の人が病気に罹患しているかしていないかをまとめたデータセットがあったとします。仮にその病気に罹患している人が極めて少数の場合、このデータセットは大半が罹患していない人のデータであり罹患している方のデータはごく少数になります。例えば罹患していない人のデータ内の割合が90%の場合、学習されるモデルは何も考えずに全ての人が病気に罹患していないと答えても精度が90%となってしまいます。もちろんRecallを計算するとRecall=0となりおかしいことには気づきますが、単純な精度という数字だけで判断すると上記のような問題が発生します。

先ほど挙げた例のようにデータに偏りがあると、モデルの開発が困難になる、または見かけ上精度がよく見えても実際には使い物にならないものが出来上がっている状況になります。このような現象を止めるためにデータセットのサンプル比率を変更する手法がしばしば適用されます。例えば以下のような方法が有名かと思います。

  • アンダーサンプリング:データ比率が多いデータの件数を削減し、比率が少ないデータと同等のデータ数または許容範囲にデータ数を調整する
  • オーバーサンプリング:データ比率が少ないデータの件数を増加させ、比率が多いデータと同等のデータ数または許容範囲までデータ数を調整する

今回紹介するimbalanced-learnはこれらの手法を提供してくれます。

imbalanced-learnとは?

imbalanced-learnはscikit-learnに依存したライブラリであり、クラス間のデータ数に偏りがあるデータに対して調整をかけることができるものとなっています。

https://imbalanced-learn.org/stable/index.html

早速使ってみる

今回はKMeansを利用したアンダーサンプリングとSMOTEを利用したオーバーサンプリングについて試してみました。

環境構築

まずはuvを利用して環境構築をします。

uv init imbalanced_learn_tutorial -p 3.12
cd imbalanced_learn_tutorial
uv add imbalanced-learn scikit-learn matplotlib

アンダーサンプリングの実装

一つ目にアンダーサンプリングを実施します。imblearn.under_sampling.ClusterCentroidsを利用するとKMeansベースのアンダーサンプリングを適用できます。この手法では、KMeansで生成されるクラスターの重心でデータを置き換えていく手法になります。実装をすると以下のようになります。

undersample.py
import matplotlib.pyplot as plt
from collections import Counter
from imblearn.under_sampling import ClusterCentroids
from sklearn.datasets import make_classification


def plot(X, y, X_resampled, y_resampled):
    plt.subplot(1, 2, 1)
    plt.scatter(
        X[y==0, 0], X[y==0, 1],
        color="red", marker="^"
    )
    plt.scatter(
        X[y==1, 0], X[y==1, 1],
        color="green", marker="*"
    )
    plt.scatter(
        X[y==2, 0], X[y==2, 1],
        color="blue", marker="o"
    )
    plt.title("Original data")
    plt.subplot(1, 2, 2)
    plt.scatter(
        X_resampled[y_resampled==0, 0], X_resampled[y_resampled==0, 1],
        color="red", marker="^"
    )
    plt.scatter(
        X_resampled[y_resampled==1, 0], X_resampled[y_resampled==1, 1],
        color="green", marker="*"
    )
    plt.scatter(
        X_resampled[y_resampled==2, 0], X_resampled[y_resampled==2, 1],
        color="blue", marker="o"
    )
    plt.title("Undersampled data")
    plt.show()


X, y = make_classification(n_samples=5000, n_features=2, n_informative=2,
                           n_redundant=0, n_repeated=0, n_classes=3,
                           n_clusters_per_class=1,
                           weights=[0.01, 0.05, 0.94],
                           class_sep=0.8, random_state=0)


cc = ClusterCentroids(random_state=0)
X_resampled, y_resampled = cc.fit_resample(X, y)
plot(X, y, X_resampled, y_resampled)

print("Original data's count: ", sorted(Counter(y).items()))
print("Resampled data's count: ", sorted(Counter(y_resampled).items()))

まずはsklearn.datasets.make_classificationでダミーデータを生成します。weightsでクラスごとに1:5:94の割合でデータを生成させることで、クラス間のデータ数の偏りを実現しています。

X, y = make_classification(n_samples=5000, n_features=2, n_informative=2,
                           n_redundant=0, n_repeated=0, n_classes=3,
                           n_clusters_per_class=1,
                           weights=[0.01, 0.05, 0.94],
                           class_sep=0.8, random_state=0)

次にimblearn.under_sampling.ClusterCentroidsを利用してアンダーサンプラーを作成します。scikit-learnを利用する要領でfit_resampleとすると指定したデータを確認してデータ数が少数のデータ数に合わせるようにサンプリングが実行されます。

cc = ClusterCentroids(random_state=0)
X_resampled, y_resampled = cc.fit_resample(X, y)

最後にplot関数では、アンダーサンプリング前後でデータの分布がどのように変わるかをまとめて表示することができます。

早速コードを実行してみましょう。可視化の結果をみると、青と緑のデータが最初は多かったものの、サンプリングしたあとはデータ数が赤と同等になっていることが確認できます。実際にデータ数をみると、サンプリング後のデータ数は全て64で同じになっています。

uv run undersample.py

# 結果
Original data's count:  [(np.int64(0), 64), (np.int64(1), 262), (np.int64(2), 4674)]
Resampled data's count:  [(np.int64(0), 64), (np.int64(1), 64), (np.int64(2), 64)]

オーバーサンプリングの実装

二つ目にオーバーサンプリングを実施します。imblearn.over_sampling.SMOTEを利用するとSMOTEを利用してオーバーサンプリングができます。SMOTEは簡単に言うと二つのデータからその間に新たなデータを内挿していくことでデータ数を増やす手法になります。早速コードをみてみましょう。

oversample.py
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.over_sampling import SMOTE


def plot(X, y, X_resampled, y_resampled):
    plt.subplot(1, 2, 1)
    plt.scatter(
        X[y==0, 0], X[y==0, 1],
        color="red", marker="^"
    )
    plt.scatter(
        X[y==1, 0], X[y==1, 1],
        color="green", marker="*"
    )
    plt.scatter(
        X[y==2, 0], X[y==2, 1],
        color="blue", marker="o"
    )
    plt.title("Original data")
    plt.subplot(1, 2, 2)
    plt.scatter(
        X_resampled[y_resampled==0, 0], X_resampled[y_resampled==0, 1],
        color="red", marker="^"
    )
    plt.scatter(
        X_resampled[y_resampled==1, 0], X_resampled[y_resampled==1, 1],
        color="green", marker="*"
    )
    plt.scatter(
        X_resampled[y_resampled==2, 0], X_resampled[y_resampled==2, 1],
        color="blue", marker="o"
    )
    plt.title("Oversampled data")
    plt.show()


X, y = make_classification(n_samples=5000, n_features=2, n_informative=2,
                           n_redundant=0, n_repeated=0, n_classes=3,
                           n_clusters_per_class=1,
                           weights=[0.01, 0.05, 0.94],
                           class_sep=0.8, random_state=0)

ros = SMOTE(random_state=0)
X_resampled, y_resampled = ros.fit_resample(X, y)
plot(X, y, X_resampled, y_resampled)

print("Original data's count: ", sorted(Counter(y).items()))
print("Resampled data's count: ", sorted(Counter(y_resampled).items()))

先ほどと違く部分としてはSMOTEの実行部分になります。SMOTEを利用する場合はimbalanced_learn.over_sampling.SMOTEを利用して、fit_resampleを利用することでオーバーサンプリングを実施します。

ros = SMOTE(random_state=0)
X_resampled, y_resampled = ros.fit_resample(X, y)

それでは早速実行してみましょう。可視化結果をみると、SMOTEでオーバーサンプリングすると元のデータにはなかった多数のデータが生成されていることが確認できました。データ数を確認すると一番データがあったクラス2のデータ数に一致するように他のクラスのデータ数が増えていることが確認できます。

uv run oversample.py

# 結果
Original data's count:  [(np.int64(0), 64), (np.int64(1), 262), (np.int64(2), 4674)]
Resampled data's count:  [(np.int64(0), 4674), (np.int64(1), 4674), (np.int64(2), 4674)]

まとめ

今回はimbalanced-learnを利用してKMeansベースのアンダーサンプリングとSMOTEによるオーバーサンプリングを試してみました。どのようなデータを取り扱っているかによって利用できるサンプリング手法も変わりますし、必ずしもデータのバランスが悪い状況を正す必要がない場合もあるかと思うので、ケースバイケースでデータのバランスを調整していただければと思います。今回はデータを調整しただけなので、次回はモデルの精度と合わせて検証できればと思います。

Discussion