決定木について簡単に
以下データがあるとする。あたえられたクラス(リンゴ、チェリー、バナナ)
その特徴を表現すると
リンゴ 色=赤 大きさ:大きい
チェリー=色:赤 大きさ:小さい
バナナ=色:黄 大きさ:大きい
さて、この死ぬほどわかりやすいデータたちをどうやったら確実にわけられるだろうか。
それは以下である。
ここにわかりやすく、リンゴ:10 チェリー10 バナナ:10 あるとする。このうち80パーセントを学習にあてるとする。
これをランダムだとするなら、10×0.8=24個のデータを乱雑に抽出する。
これをいわゆる教師データという。この教師データを学習し、テストでどれくらいのスコアがだせるかが基本になる。
そして残りの20%つまり6個のデータがテストデータになる。
以下は学習後のフローだ。つまりテストのこと。
ちなみにソースコードは以下である。
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.tree import export_graphviz
import graphviz
# サンプルデータ
data = {
'色': ['赤', '赤', '赤', '黄', '赤', '黄', '赤', '赤', '黄', '黄', '赤', '黄', '赤', '赤', '黄', '赤', '赤', '赤', '黄', '赤', '赤', '黄', '赤', '赤', '黄', '赤', '黄', '赤', '黄'],
'サイズ': ['大', '大', '小', '大', '小', '大', '小', '大', '大', '大', '小', '大', '小', '小', '大', '小', '小', '大', '大', '大', '小', '大', '大', '小', '大', '大', '大', '小', '大'],
'フルーツ': ['リンゴ', 'リンゴ', 'チェリー', 'バナナ', 'チェリー', 'バナナ', 'チェリー', 'リンゴ', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'リンゴ', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'バナナ', 'チェリー', 'リンゴ']
}
# 正しいデータを表示する。
dataset = pd.DataFrame(data)
# print(dataset)
# 特徴量とラベルに分割
X = pd.get_dummies(dataset.iloc[:, :-1])
y = dataset.iloc[:, -1].values
# 特徴量 色_赤 色_黄色 サイズ_大 サイズ_小
# print(X)
# 目的変数(フルーツ各種)
# print(y)
# トレーニングデータとテストデータに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
# 決定木モデルのトレーニング
classifier = DecisionTreeClassifier(criterion='gini', max_depth=2,random_state=0)
classifier.fit(X_train, y_train)
# テストデータで予測
y_pred = classifier.predict(X_test)
# 混同行列の作成
cm = confusion_matrix(y_test, y_pred)
print(cm)
# 精度の計算
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy: .2f}')
教師データの特徴を考え、上記のツリー上のフローチャートを作っていくのがまさしく決定木だ。
たとえば教師データリンゴ一つをとるなら、そのリンゴの特徴である赤を取ってみる。
その赤という特徴を持つフルーツはリンゴだけでなくサクランボも残る。
すると、データは「不純」なものとなる。それゆえ、純粋なデータになるようノードを増やさなければならない。
この不純なデータを解析することを、「ジニ不純度」という。
ジニ不純度は、あるノードがどれだけ不純であるか(異なるクラスが混在しているか)を示す指標です。
つまり赤かどうかの判定後のノードというのは2:2つまり、半々で別のクラスが存在してるので不純といえるだろう。
それを解析する数式は以下である。
ジニ不純度の計算
ジニ不純度 ( G ) は次の式で表される。:
数式アレルギーの人にもう少し砕いて説明すると、
数式で考えるとわかりづらいが、実際に数字をいれるとバカみたいにわかりやすい
ノード時点でフルーツの数は4内訳はサクランボ2個、リンゴ2個なので、
リンゴもチェリーも
なので数式に当てはめると、
0.5だと無論不純度がたかいので、ノードをかませる。
それが「大きい?」という分岐だ。
このノードを実行すると、サクランボとリンゴがきれいに分かれるので、うまくいきそうだ。
手前みそだがもう一度ジニ不純度で考えてみよう。
よってこのノードの純度は高いので、このアルゴリズムを実行した場合、好スコアが叩き出せると思う。
上記の実行後精度は1.00なので、100パーセントの精度でフルーツを仕分けることができる!!
そのほかエントロピーを使う計算などがあるが、長丁場になるので次回以降にしようと思う。
決定木を馬鹿みたいに簡単に理解するためのガイドなので、割愛させていただく。
Discussion