🤖

[機械学習]KNN(K近傍法)の仕組み

commits10 min read

概要

この記事ではK nearest neighbors/K近傍法(以下、KNNと呼ぶ)について説明する。

KNNは、テストデータが入力されたときにそれに近い\bf{k}個のトレーニングデータを取り、それらのラベルの多数決を採ることで、テストデータのラベルを予測するモデルである。

本記事では、身長・体重(特徴量)のデータを用いて、性別(ラベル)判断をする具体例を用いてKNNの挙動を説明する。

また、KNNの数式での表現についても扱う。

はじめに

KNNは分類手法のひとつだ。

KNNはEvelyn FixとJoseph Hodgesという二人の統計学者によって1951年に発表された[1]。
現在でも使われることがあり、改良もされている。[2]

KNNは「テストケースから近い順に適切な個数のトレーニングデータを取得すればテストケースのラベルを特定出来る」という仮説を置いたモデルだ。

前提知識

特徴量とトレーニングデータ/テストデータ

特徴量とはそれぞれのデータの特徴を表す値の集合で、ベクトルを用いて表すことが出来る。

トレーニングデータとは、モデルが「入力されたデータ」(特徴量)から「何らかの結果」(ラベル)を予測出来るよう訓練をするために用いられるデータだ。
こういった目的に応えるため、トレーニングデータには入力データだけでなく、予測結果であるラベル情報も付与されている。
※トレーニングセット(トレーニングデータの集合)をモデルに学習させることを教師あり学習と言う。

一方でテストデータはモデルの性能を評価するためのデータだ。モデルはテストデータのラベルを参照出来ない。

パラメトリック(parametric)とノンパラメトリック(non-parametric)

パラメーター(parameter)とは分布を規程する変数のことだ。例えば正規分布では平均と分散がパラメーターだ。

パラメトリックモデルのパラメーターの個数が「固定的な数値である」のに対してノンパラメトリックモデルは、パラメーターの個数が「トレーニングデータの増加に伴い増加する」という違いがある。
※ノンパラメトリックモデルはデータそのものがパラメーターの役割をしているとも言える。

パラメトリックなモデルは比較的すぐに使うことが出来るという優れた点を持つ一方でデータ分布の性質に対してより強力な仮定を置かなければならないという欠点もある。
ノンパラメトリックモデルはより柔軟に運用出来る利点を持つが、大規模なデータに対しては計算量的に困難をきたすこともある。

ユークリッド距離(euclidean distance)

ユークリッド空間(要は普通の幾何学の世界)における二点間の直線距離のことをユークリッド距離と言う。

Euclidean distance / Wikipediaより

上記の図よりもわかる通り、ユークリッド距離は座標上の各軸の差の2乗の和の正の平方根で計算出来る。

1.KNNとはなにか

KNNはノンパラメトリックな分類手法の一つだ。KNNはトレーニングセットの中からテスト入力の特徴量(feature)\bf{x}からユークリッド距離的に最も近しい\bf{k}個のデータを見て、その中の何個が指定したラベルと同値となるかを数え上げてその割合を推定値として返してくれる。

Machine Learning: A Probabilistic Perspective (Adaptive Computation and Machine Learning series) P.16 Figure1.14より

上記の図の例で言えばテストケースx_1やx_2がK=3の数だけトレーニングセットのポイントを見て、ラベル0の時の割合やラベル1の時の割合を返すようになっている。
※ちなみにKNNでよく使われる距離の測度基準は今回も用いているユークリッド距離だが、他の測度基準、例えばマンハッタン距離などを用いることも出来る。

2.KNNを実際に用いた具体例

KNNのイメージを掴むために、以下のような架空のデータを用いた分類をしてみる。
身長と体重を特徴量とし、性別をラベルとした。
※説明のために若干極端な数値設定にしてある。

Name 特徴量1(身長) 特徴量2(体重) ラベル(性別)
Train-1 180 60
Train-2 170 55
Train-3 150 45
Train-4 150 50
Train-5 160 55
Train-6 150 52

筆者作成の架空データ:同年代の男女の身長、体重、性別データのテーブル

トレーニングデータに対して下記3つのテストケースを用意する。

Name 特徴量1(身長) 特徴量2(体重) ラベル(性別)
Test-7 175 40
Test-8 155 65
Test-9 150 45

上記のデータセットを用いて(1)実際に計算を行いそれを追いかけていく。その後に(2)scikit-learnを用いて分類をする。

2.1.KNNで行われる計算

テストデータiとトレーニングデータjとのユークリッド距離d_{ij}を計算する。
データiの身長、体重の特徴量をそれぞれ(x_i1, x_i2)とする。

\| d_{ij} \| = \sqrt{(x_i1 - x_j1)^2 + (x_i2 - x_j2)^2}

で計算出来る。それぞれのテストデータ、トレーニングデータ間のユークリッド距離の計算結果を下記のテーブルにまとめた。

Train-1 Train-2 Train-3 Train-4 Train-5 Train-6
Test-7 20.6 15.8 25.4 26.9 21.2 27.7
Test-8 25.4 18.0 20.6 15.8 11.1 13.9
Test-9 33.5 22.3 0.0 5.0 14.1 7.0

※少数第1位以下は切り捨て

以上から、

  • Test-7
    近しいトレーニングデータはTrain-2、Train-1、Train-5。
    男性割合は\frac{2}{3}女性割合は\frac{1}{3}
    =>以上より男性
  • Test-8
    近しいトレーニングデータはTrain-5、Train-6、Train-4。
    男性割合は\frac{2}{3}女性割合は\frac{1}{3}
    =>以上より男性
  • Test-9
    近しいトレーニングデータはTrain-3、Train-4、Train-6。
    男性割合は\frac{1}{3}女性割合は\frac{2}{3}
    =>以上より女性

上記を実際にグラフにしたものが下記だ。

念のため上記をPythonスクリプトでも試せるように記載しておく。
Pythonを普段使いされている方はコピペ等して実環境で試してみて欲しい。

Pythonスクリプト
import math

X_train = [
        [180, 60],
        [170, 55],
        [150, 45],
        [150, 50],
        [160, 55],
        [150, 52],
        ]

X_test = [
        [175, 40],
        [155, 65],
        [150, 45]
        ]

if __name__ == "__main__":
    distances = [[None] * len(X_train) for _ in range(len(X_test))]
    for i, test in enumerate(X_test):
        for j, train in enumerate(X_train):
            distance = math.sqrt((train[0] - test[0]) ** 2 + (train[1] - test[1]) ** 2)
            distances[i][j] = [j+1, distance]
    for dl in distances:
        s_dl = sorted(dl, key=lambda x: x[1])
        print(s_dl)
--------------
[[2, 15.811388300841896], [1, 20.615528128088304], [5, 21.213203435596427], [3, 25.495097567963924], [4, 26.92582403567252], [6, 27.730849247724095]]

[[5, 11.180339887498949], [6, 13.92838827718412], [4, 15.811388300841896], [2, 18.027756377319946], [3, 20.615528128088304], [1, 25.495097567963924]]

[[3, 0.0], [4, 5.0], [6, 7.0], [5, 14.142135623730951], [2, 22.360679774997898], [1, 33.54101966249684]]

さてここまで行ってきた分析は一見正しいように思われる。

しかし数値を変えることなく、単位を変えることで上記がおかしいことがわかる。例えば身長をcmではなくmで表記して計算する。

Name Train-1 Train-2 Train-3 Train-4 Train-5 Train-6
Test-7 20.0 15.0 5.0 10.9 15.0 12.0
Test-8 5.0 10.0 20.0 15.0 10.0 13.0
Test-9 15.0 10.0 0.0 5.0 10.0 7.0
  • Test-7
    近しいトレーニングデータはTrain-3、Train-4、Train-6。
    男性割合は\frac{1}{3}女性割合は\frac{2}{3}
    =>以上より女性
  • Test-8
    近しいトレーニングデータはTrain-1、Train-2、Train-5。
    男性割合は\frac{2}{3}女性割合は\frac{1}{3}
    =>以上より男性
  • Test-9
    近しいトレーニングデータはTrain-3、Train-4、Train-6。
    男性割合は\frac{1}{3}女性割合は\frac{2}{3}
    =>以上より女性

上記から最初の計算結果のTest-7と分析の結果が変わることがわかった。これは身長の幅が小さくなったことによる影響だ。

以上からデータ幅がより大きい特徴量が、相対的に分析結果に大きい影響力を持つ(≒ユークリッド距離において大きい割合を占める) ため、異なる属性の特徴量を前処理せずに分析に用いるのは不適切であることがわかった。

そこでデータの最低値、最高値を設定してやり0~1に丸め込む(縮尺する)ことでデータを前処理し、上の問題を解決することを試みる。
※説明を簡略化するために本記事では上記のような手法を採るが、正規分布に当てはめて標準化(standardization)してやるなどして前処理を行っても良い。

今回はトレーニングデータにおける身長の最低値150cmを0に、最高値の180cmを1にする。同様に体重は最低値45kg、最高値を60kgに設定する。

トレーニングデータは下記のようになる。

Name 特徴量1(身長) 特徴量2(体重) ラベル(性別)
Train-1 1.000 1.000
Train-2 0.666 0.666
Train-3 0.000 0.000
Train-4 0.000 0.333
Train-5 0.333 0.666
Train-6 0.000 0.466

テストデータは下記のようになる。

Name 特徴量1(身長) 特徴量2(体重) ラベル(性別)
Test-7 0.833 -0.333
Test-8 0.166 1.333
Test-9 0.000 0.000

上のデータを用いて再度、KNNの計算結果を確認する。

結果は下記のようになった。

Name Train-1 Train-2 Train-3 Train-4 Train-5 Train-6
Test-7 1.343 1.013 0.897 1.067 1.118 1.155
Test-8 0.897 0.833 1.343 1.013 0.687 0.882
Test-9 1.414 0.942 0.000 0.333 0.745 0.466

※少数第3位以下は切り捨て

  • Test-7
    近しいトレーニングデータはTrain-2、Train-1、Train-5。
    男性割合は\frac{2}{3}女性割合は\frac{1}{3}
    =>以上より男性
  • Test-8
    近しいトレーニングデータはTrain-5、Train-2、Train-6。
    男性割合は\frac{1}{3}女性割合は\frac{2}{3}
    =>以上より女性
  • Test-9
    近しいトレーニングデータはTrain-3、Train-4、Train-6。
    男性割合は\frac{1}{3}女性割合は\frac{2}{3}
    =>以上より女性

上記より最初の計算結果からTest-8の結果が変わることがわかった。これは身長と体重のデータ幅が0~1に丸め込みされたことにより、予測に対する身長の相対的な影響力が下がり、体重の影響力が上がったためだ。

散布図のグラフも下記のようになる。

2.2.scikit-learnでやってみる

scikit-learnとはPythonのオープンソース機械学習ライブラリだ。[3]
上記までの分析をscikit-learnを用いて再度行う。なお、ここではscikit-learnの詳細な説明は省略するため参考程度に確認して欲しい。(KNNをscikit-learnで実際に用いて分析を行う解説記事は別途執筆予定)

scikit-learn
from sklearn.neighbors import KNeighborsClassifier
# index-0:height, index-1:weight
X_train = [
        [180, 60],
        [170, 55],
        [150, 45],
        [150, 50],
        [160, 55],
        [150, 52],
        ]
# Male:0, Female:1
y_train = [0, 1, 1, 0, 0, 1]
if __name__ == "__main__":
    # n_neighbors is the number of traning data to `look at`
    neigh = KNeighborsClassifier(n_neighbors=3)
    neigh.fit(X_train, y_train)
    # test data
    X_test = [
            [175, 40],
            [155, 75],
            [150, 45]
            ]
    print(neigh.predict(X_test))
--------------
[0 1 1]

3.KNNの動作を数式で表現する

Section1の「KNNとはなにか」で述べたようにKNNは「テストデータの入力から最も近しい\bf{k}個のポイントを見て、指定したラベルと同値となるのが何個かを数え上げ、その割合を推定値として返す。

これを数式で表すと以下のようになる。

p(y=c|\bf{x}, \mathcal{D},K) = \frac{1}{K} \sum_{i\in{N_K(\bf{x}, \mathcal{D})}}^n \mathbb{I}(yi = c)

s.t.(ここで)、N_K(\bf{x}, \mathcal{D})はデータセット\mathcal{D}の中で特徴量\bf{x}を持つテストケースに最も近いK個のトレーニングデータのインデックスの集合である。

また、インジケーター関数\mathbb{I}は下記のように定義される。

ただしe関係式とする。

\mathbb{I}(e) = \begin{cases} 1 \ \rm{if} \ e \ \rm{is} \ true \\ 0 \ \rm{if} \ e \ \rm{is} \ false \end{cases}

まとめ

今回はKNNというモデルの背景の考え方をまとめた。続いてKNNが実際にどういった分析を行っているのかを実際に計算し検証した。またKNNをより厳密に捉えるために数式で表した。

一連の検証の結果、KNNに用いデータの特徴量を適切に前処理してやらないと、分析結果にデータ幅の大きい特徴量が過度に影響することがわかった。そこで今回はデータ前処理のために、特徴量の最大値を1に、最小値を0に縮尺することで対応した。

次の記事では、KNNをscikit-learnを用いてMNISTの分析とその結果から得られた考察についてまとめることでさらに深くKNNについて学んでいく。

参考文献

[1]: Fix, Evelyn; Hodges, Joseph L. (1951). Discriminatory Analysis. Nonparametric Discrimination: Consistency Propertie

[2]:Toyama, Jun, Mineichi Kudo, and Hideyuki Imai. "Probably correct k-nearest neighbor search in high dimensions." Pattern Recognition 43.4 (2010): 1361-1372.

[3]:Fabian Pedregosa; Gaël Varoquaux; Alexandre Gramfort; Vincent Michel; Bertrand Thirion; Olivier Grisel; Mathieu Blondel; Peter Prettenhofer; Ron Weiss; Vincent Dubourg; Jake Vanderplas; Alexandre Passos; David Cournapeau

Machine Learning: A Probabilistic Perspective (Adaptive Computation and Machine Learning series) Kevin P. Murphy

GitHubで編集を提案

Discussion

ログインするとコメントできます