🌛

線形分離できないデータの分類を可視化する

2024/12/24に公開

はじめに

今年、社内の勉強会で[第3版]Python機械学習プログラミング 達人データサイエンティストによる理論と実践を読む機会がありました。そこで紹介されていた、SVMを用いて非線形問題を解くことが面白いと感じたので、題材にして記事を書きます。具体的には、2次元上で線形分類できないデータセットに対し、カーネルSVMの基本的な考え方である 高次元へ射影して線形分離できるようにする といった内容を可視化することで理解を深めることを目的とします。尚、私自身は機械学習初心者であることを、あらかじめ断っておきます。

準備

基本的なライブラリを、インポートしておきます。

import matplotlib.pyplot as plt
import numpy as np

線形分離できないデータ

上記書籍では、NumPyの logical_xor(排他的論理和)を題材としていますが、ここでは別のデータとして、scikit-learnのmake_moons を用います。

from sklearn.datasets import make_moons

X, y = make_moons(n_samples=100, noise=0.1)

散布図を描画すると、moon感が出ます🌛。

from matplotlib.colors import ListedColormap

custom_cmap = ListedColormap(['red', 'green'])

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=custom_cmap)
plt.title("Nonlinear Data (2D)")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

補足

上記では、make_moons のパラメータである n_samplesnoise は以降の内容に特別影響しない箇所のため、適当に設定しています。[scikit-learn] 4. make_moonsによる三日月状データの生成 では、非常にわかりやすく、各パラメータに対する影響を記載してくださっています。

make_moons の返り値である Xy はそれぞれNumPy配列となっており、前者にはデータポイントの座標が、後者には対応するデータポイントのクラスラベル(つまり 0 or 1)が格納されています。plt.scatterによって、座標をクラスラベルに応じた色合いを割り当てて描画しています。

なぜ( coolwarm などを使用するのではなく) ListedColormap でわざわざ色を指定しているかって? それはもちろん投稿日が今日だからです 🎅

非線形分離

上記の図から、(2次元上)線形に分離する(= 赤と緑を分類する直線を引く)ことは難しそうです。これから、モデルを定義し学習することで分離します。ここでは、RBFカーネルSVMを用います。

from sklearn.svm import SVC

model = SVC(kernel='rbf', C=1.0, gamma=0.5)
model.fit(X, y)

モデルの学習をしたので、実際に予測してみます。

xx, yy = np.meshgrid(np.linspace(-1.5, 2.5, 200), np.linspace(-1, 1.5, 200))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

先ほどの散布図に、予測した決定境界をプロットします。

plt.contourf(xx, yy, Z, alpha=0.75, cmap=custom_cmap)

plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=custom_cmap)
plt.title("Decision Boundary with RBF Kernel")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

うまいこと分離してくれています 🤶

補足

SVC(Support Vector Classification)は、SVM(Support Vector Machine)による分類タスク用のモデルです。SVMは、最大マージンを持つ超平面を見つけることでクラス間の決定境界を定義しようとするアルゴリズムで、線形分離可能であれば直接境界を見つけ、不可能であれば(まさにこの文章で今可視化しようとしている)高次元に射影するといった方法で、境界を見つけようとします。

カーネル関数としては、rbf(Radial Basis Function)を指定しており、書籍に則っています。定義はこちら:K(\bm{x}, \bm{x'}) = exp(-\gamma\|\bm{x}-\bm{x'}\|^2)

射影と線形分離

射影後の3次元目の値として利用するため、新しい特徴量を生成します。その後で、元のデータに列方向から結合しています。

def rbf_projection(X, gamma=0.5):
    return np.exp(-gamma * np.linalg.norm(X, axis=1)**2)

Z = rbf_projection(X)
X_projected = np.hstack((X, Z.reshape(-1, 1)))

準備も揃ったので、3次元にプロットします!

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X_projected[:, 0], X_projected[:, 1], X_projected[:, 2], c=y, cmap=custom_cmap)

ax.set_xlabel('Fature 1')
ax.set_ylabel('Fature 2')
ax.set_zlabel('RBF Feature')
plt.title("3D Projection of the Data")

plt.show()

このように捉えると、線形超平面で(概ね)分割できるようになります。

ax.plot_surface(xx, yy, zz, alpha=0.3, color='gray')

射影を引き戻すと、上記で非線形に分離していた図になります。

xx, yy = np.meshgrid(np.linspace(-1.5, 2.5, 50), np.linspace(-1, 1.5, 50))
zz = rbf_projection(np.c_[xx.ravel(), yy.ravel()])
grid_3d = np.c_[xx.ravel(), yy.ravel(), zz]

Z_3d = model.predict(grid_3d[:, :2]) 
Z_3d = Z_3d.reshape(xx.shape)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(X_projected[:, 0], X_projected[:, 1], X_projected[:, 2], c=y, cmap=custom_cmap, edgecolors='k')
ax.contourf(xx, yy, Z_3d, zdir='z', offset=-0.2, alpha=0.3, cmap=custom_cmap)

ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_zlabel('RBF Feature')
plt.title("Linear Separation in Higher Dimensions")

plt.show()

補足

先ほども出てきた linspace ですが、これは点を生成する関数です。上記では、-1.5 から 2.5 まで等間隔に50個の点を生成します。で、それをどうするかといえば、2つの1次元配列を入力として、 meshgrid により2次元の格子点を生成し、予測に使っています。

面白いっすね。上記コードは、こちらで動作確認をしています。

Discussion