🦁

DecisionTree分類器の構造を可視化してみる

に公開

今回はscikit-learnのDecisionTree分類器の学習結果を可視化する方法をまとめてみます。

DecisionTree分類器とは?

DecisionTree分類器は決定木を利用した分類器です。構造としては二分木で、各ノードで何かしらの条件により入力データが2つに分割されます。その分割を複数段階適用することでクラス分類をすると言うものになります。DecisionTree分類器については以下をご参照ください。

https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

早速実装!

今回はirisデータを使ってモデルを学習します。

環境構築

uvを使って以下のように環境を構築しました。

uv init iris_decision_tree_classifier_visualize -p 3.12
cd iris_decision_tree_classifier_visualize
uv add scikit-learn matplotlib

コードの実装

今回実装したコードは以下のようになっています。

visualize_dtc.py
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
classifier = DecisionTreeClassifier(max_depth=3)
classifier.fit(X_train, y_train)
accuracy = accuracy_score(classifier.predict(X_test), y_test)
print(f"{accuracy=}")
plot_tree(classifier)
plt.show()

まずはirisのデータを学習用とテストように分割します。

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)

次にモデルを学習します。sklearn.tree.DecisionTreeClassifierを利用します。モデルの学習はscikit-learnお決まりのfit関数を利用します。決定木を利用するときにはいくつかパラメータがあるのですが、今回はmax_depthを3に設定することだけ対応しました。max_depthは何段階までデータの分割をさせるかを表すパラメータであり、今回の例で言うと最大3回までデータが分割されます。

classifier = DecisionTreeClassifier(max_depth=3)
classifier.fit(X_train, y_train)
accuracy = accuracy_score(classifier.predict(X_test), y_test)
print(f"{accuracy=}")

最後に可視化の部分になります。sklearn.tree.plot_treeを利用すると内部的にmatplotlibを利用してツリー構造を可視化してくれます。

plot_tree(classifier)
plt.show()

早速実行してみる!

それではコードを早速実行してみましょう。テストデータに対する精度はおよそ98%弱であり、かなり高い精度になっています。

uv run visualizes_dct.py

# 結果
accuracy=0.9777777777777777

可視化結果ですが、実行すると以下のような表示がされます。結果をみると先ほどmax_depthに設定したように最大3回までデータbが分割されていることがわかります。それぞれの分岐点ではYes/Noでできる質問がされます。例えば一つ目の点では特徴量2について2.35以下かそうでないかでデータが分かれます。Yesだった場合はそれ以上分割がされておらず、そこに流れた全てのデータはクラス0に分類されていることがわかります。今回分割された全ての基準にてNoと回答した場合は一番右側のノードに到着し、全32サンプルはクラス2に分類されることがわかります。

まとめ

今回はDecisionTree分類器の可視化をしてみました。実際に開発するモデルはもっと複雑な深いモデルであることも多く可視化してもみやすいかというとそうでない場合もありますが、モデルの挙動を解釈するためには重要な手法ですので、ぜひ試してみてください。

Discussion