🐥

決定木について簡単に

2024/10/21に公開

以下データがあるとする。あたえられたクラス(リンゴ、チェリー、バナナ)
その特徴を表現すると
リンゴ 色=赤 大きさ:大きい
チェリー=色:赤 大きさ:小さい
バナナ=色:黄 大きさ:大きい
さて、この死ぬほどわかりやすいデータたちをどうやったら確実にわけられるだろうか。
それは以下である。

ここにわかりやすく、リンゴ:10 チェリー10 バナナ:10 あるとする。このうち80パーセントを学習にあてるとする。
これをランダムだとするなら、10×0.8=24個のデータを乱雑に抽出する。
これをいわゆる教師データという。この教師データを学習し、テストでどれくらいのスコアがだせるかが基本になる。
そして残りの20%つまり6個のデータがテストデータになる。

以下は学習後のフローだ。つまりテストのこと。

ちなみにソースコードは以下である。

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.tree import export_graphviz
import graphviz

# サンプルデータ
data = {
'色': ['赤', '赤', '赤', '黄', '赤', '黄', '赤', '赤', '黄', '黄', '赤', '黄', '赤', '赤', '黄', '赤', '赤', '赤', '黄', '赤', '赤', '黄', '赤', '赤', '黄', '赤', '黄', '赤', '黄'],
'サイズ': ['大', '大', '小', '大', '小', '大', '小', '大', '大', '大', '小', '大', '小', '小', '大', '小', '小', '大', '大', '大', '小', '大', '大', '小', '大', '大', '大', '小', '大'],
'フルーツ': ['リンゴ', 'リンゴ', 'チェリー', 'バナナ', 'チェリー', 'バナナ', 'チェリー', 'リンゴ', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'リンゴ', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'チェリー', 'バナナ', 'リンゴ', 'バナナ', 'チェリー', 'リンゴ']

}
# 正しいデータを表示する。
dataset = pd.DataFrame(data)
# print(dataset)


# 特徴量とラベルに分割
X = pd.get_dummies(dataset.iloc[:, :-1])
y = dataset.iloc[:, -1].values

# 特徴量 色_赤 色_黄色 サイズ_大 サイズ_小
# print(X)
# 目的変数(フルーツ各種)
# print(y)

# トレーニングデータとテストデータに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 決定木モデルのトレーニング
classifier = DecisionTreeClassifier(criterion='gini', max_depth=2,random_state=0)
classifier.fit(X_train, y_train)
# テストデータで予測
y_pred = classifier.predict(X_test)

# 混同行列の作成
cm = confusion_matrix(y_test, y_pred)
print(cm)

# 精度の計算
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy: .2f}')


教師データの特徴を考え、上記のツリー上のフローチャートを作っていくのがまさしく決定木だ。

たとえば教師データリンゴ一つをとるなら、そのリンゴの特徴である赤を取ってみる。
その赤という特徴を持つフルーツはリンゴだけでなくサクランボも残る。
すると、データは「不純」なものとなる。それゆえ、純粋なデータになるようノードを増やさなければならない。
この不純なデータを解析することを、「ジニ不純度」という。
ジニ不純度は、あるノードがどれだけ不純であるか(異なるクラスが混在しているか)を示す指標です。
つまり赤かどうかの判定後のノードというのは2:2つまり、半々で別のクラスが存在してるので不純といえるだろう。

それを解析する数式は以下である。

ジニ不純度の計算

ジニ不純度 ( G ) は次の式で表される。:

G = 1 - \sum_{k=1}^{n} p_k^2

数式アレルギーの人にもう少し砕いて説明すると、

n=\text{クラス数(つまり、リンゴ、サクランボ、バナナのこと)}
p_k = \frac{各フルーツの数}{ノード分割後の教師データ全体の数}

数式で考えるとわかりづらいが、実際に数字をいれるとバカみたいにわかりやすい

G = 1 - (p_1^2 + p_2^2)

ノード時点でフルーツの数は4内訳はサクランボ2個、リンゴ2個なので、

リンゴもチェリーも

\frac{4}{2} = \frac{1}{2} = 0.5

なので数式に当てはめると、

G = 1 - (0.5^2 + 0.5^2)
G = 1 - (0.25 + 0.25) = 1 - 0.5 = 0.5

0.5だと無論不純度がたかいので、ノードをかませる。
それが「大きい?」という分岐だ。

このノードを実行すると、サクランボとリンゴがきれいに分かれるので、うまくいきそうだ。
手前みそだがもう一度ジニ不純度で考えてみよう。

分岐後のアイテム数は2個そしてリンゴも2個なので、\\ \frac{2}{2} =1\\ G= 1-1 = 0

よってこのノードの純度は高いので、このアルゴリズムを実行した場合、好スコアが叩き出せると思う。
上記の実行後精度は1.00なので、100パーセントの精度でフルーツを仕分けることができる!!

そのほかエントロピーを使う計算などがあるが、長丁場になるので次回以降にしようと思う。
決定木を馬鹿みたいに簡単に理解するためのガイドなので、割愛させていただく。

Discussion