[機械学習]KNNを用いた手書き数字(MNIST)の分類とエラー分析
概要
K nearest neighbor(KNN)とは、テストデータに対してユークリッド距離が近い
このKNNを用いて手書き数字のデータセットとして有名なMNISTの画像データが、0~9のいずれの数にあたるのかの予測を行った。
予測結果を確認した所、ラベル分類の正答率は97.05%であった。
予測結果を確認後、不正解となったテストデータにはどんな特徴があるのかを探るためにエラー分析を行った。
エラー分析にあたって、まず(KNNによる分類で)正解・不正解となったテストデータそれぞれと
作成したヒストグラムによる分布を確認した結果、(トレーニングデータとの)ユークリッド距離が大きいテストデータほど不正解となりやすいことがわかった。
次に混同行列を用いて、「どういった手書き数字が間違われやすいのか」また「どういったラベルであると誤り予測されやすいのか」について確認した。
その結果、7が1と、4が9と、8が3と間違われることが多いことがわかった。
本記事で行った分析処理は、こちらのGitHubのリポジトリに上げたコードで行っているので必要に応じて参考にしていただきたい。
はじめに
前提知識
K nearest neighbor
K nearest neighbor(以下、KNNと呼ぶ)はテストデータが入力されたときにそれに近い
筆者の『KNN(K近傍法)の仕組み』に詳しい解説があるので参考にしていただきたい。
MNIST
※上図はMNISTデータより筆者作成
MNIST(Modified National Institute of Standards)は0~9の手書き数字の画像データセットだ。
それぞれのデータは28x28ピクセルのグレースケール画像となっている。
トレーニング、バリデーション、テスト
教師あり学習ではモデル学習用のトレーニングデータ、モデル評価用のテストデータを用意する。
また最良のスコアを出すパラメータの選択のため、トレーニングデータを分割して作成したバリデーションデータを用いることがある。
使用ライブラリ
今回使用するPythonライブラリは下記だ。
scikit_learn==0.24.1
numpy==1.20.2
matplotlib==3.4.1
Pillow==8.2.0
scikit-learnはKNNによる学習・予測を行うために、NumPyは数値計算に、matplotlibはグラフ描画のために用いた。
また、Pillowは本記事で使用する画像を作成・加工するために使用した。
1. KNNを用いたMNIST分類
KNNを用いて、MNISTのテストデータセット(手書き数字画像)のラベル(0~9)を予測した。
このラベル予測のことを本記事では以後、MNIST分類と呼ぶことにする。
MNISTの取得とトレーニングデータ・テストデータへの分割
MNISTデータセットを提供するサーバーは複数用意されており、サーバーごとに多少のバリエーションがある。
今回はOpenMLが提供しているmnist_784を使用した。
mnist_784のデータの特徴量は784(28x28)ピクセルそれぞれの光度(最も弱い黒から最も強い白までの間の灰色の明暗)の値で、0~255の256段階で表される。
取得したmnist_784を前から順に最初の60,000件をトレーニングデータ、残りの10,000件をテストデータとして分割した。
前から順にトレーニングデータ・テストデータを分割した理由は、mnist_784が元々のMNISTをシャッフルしたデータセットだからだ。
k 数の選択
KNNモデルのパラメータ KNNモデルのパラメーター
以後、本記事ではKNNが取得する
まずトレーニングセットを4:1で分割した。大きい方(4,5000件)をモデルトレーニング用のデータセットとし、もう一方をパラメータ評価用のバリデーションセットとする。
次に
そうして算出された予測スコアを比較することで、最適なパラメーター
その結果、最適なパラメータ
モデルの学習と予測
その学習済みのモデルを用いて、テストデータをMNIST分類させた。
結果、ラベル予測の正答率は97.05%となった。
2. MNIST分類のエラー分析
MNIST分類の際に生じたエラー(誤ったラベル予測)の分析のために(1)正解・不正解テストデータそれぞれのユークリッド距離の分布の比較と(2)混同行列(Confusion Matrix)による頻出エラーの確認を行なう。
ヒストグラムの作成と分布の確認
先ほどMNIST分類したテストデータを正解データと不正解データに分け、それぞれのヒストグラムを作成した。
ヒストグラムの横軸は特徴量空間でのユークリッド距離とし、縦軸はユークリッド距離ごとのデータ数とした。
※上図が正解データ、下図が不正解データの分布
正解したデータの分布は1100~1200に、不正解データ分布には1600~1700に、山(データ数が最も多い箇所)が出来ている。
また不正解データの分布は正解したデータの分布に比べ、全体的に見ても右側に偏っている。
以上からneighborまでのユークリッド距離が大きいデータほど不正解になりやすいということがわかった。
混同行列の作成と頻出するエラーの組み合わせの確認
「どういった手書き数字が間違われやすいのか」また「どういったラベルであると誤り予測されやすいのか」を、混同行列(Confusion Matrix)を用いて確認した。
混同行列はMNIST分類を行った際に、テストデータの「正解のラベル」を縦軸に「予測ラベル」を横軸に置き、その組み合わせごとのデータ数を行列にして作成した。
なお、データ数についてはそれぞれの「正解のラベル」の総数(下記の混同行列で言えば各行の合計)を1として、正規化した。
またデータ数の大小を視覚的に確認するために、混同行列をヒートマップ化した。
「正解ラベル-予測ラベル」のうち、上記の混同行列で0.015以上となる「7-1」、「4-9」、「8-3」の組み合わせを頻出エラーと呼ぶことにする。
この頻出エラーとそのneighborの画像3枚を確認した。
上記の画像を筆者の目視で確認した結果、人間の目で見ても(1)ラベル(数字)を判断出来ないもの(上記で赤星を付けた画像)もあるが、(2)充分にラベル判断出来るデータも含まれていることがわかった。
考察
上記で行なってきたMNIST分類及びエラー分析を踏まえ、モデル改善の方向性を2つ示す。
第1にneighborとのユークリッド距離が遠いテストデータ(外れ値)に対して対策を取ることである。
具体的な対策としては「外れ値検知をして(書き手に)データの再筆依頼を促す」、「外れ値に対応したトレーニングデータを増やす」、などが挙げられるだろう。
こうした外れ値に対する対策によりモデル精度の改善が期待される。
第2に特徴量の項目を増やしそれをモデルの学習に取り込めるようにすることである。これにはモデルに関わるアルゴリズムの修正も含まれる。
「手書き画像はどの数字なのか」について人間とモデルの出す結論は同じであるのが実運用上望ましい。
そのことを考えれば、画像データは人間の目で見て、ラベル判断がつくものであるべきだ。
例えば、下記のようなピクセルをシャッフルした画像は人間の目で見て9と判断することは出来ない。
しかしKNNのアルゴリズムにおいてはピクセルをシャッフルした画像と元の画像の2つは同一に扱われる。
これは、画像データの1つ1つのピクセルの座標情報そのものをモデルの学習や予測に用いていないためである。
※右は左側の画像のピクセルをシャッフルしたもの
そのため、本記事で行なってきたモデルの学習・予測のアルゴリズムでは画像データの特徴量を(全て同じ規則で)シャッフルしてもスコアの精度に影響が出ることはない。
※実際に一定の規則に基づいてシャッフルしたデータを学習・予測に用いた結果、正答率はシャッフルする前と同じく97.05%となった。
「画像データは人間の目で見てラベル判断がつくべきである」という実運用上の要求から考えても、ピクセルの相互の位置関係を学習・予測に組み込むようにモデルのアルゴリズムを修正することでより人間の判断に近いモデルを構築するべきだと考える。
Discussion