☀️
ランダムフォレスト回帰【機械学習】
ランダムフォレストの回帰タスクについてまとめました。
1. ランダムフォレストとは
-
複数の決定木(というモデル)を組み合わせるアンサンブル手法を用いたモデルです。
-
クラス分類や回帰に用いられます。
2. メリット
-
アンサンブル手法なので、決定木よりも外れ値に対して頑健です。またバギングにより、モデルのバイアスは変わらないまま分散が小さくなります。
-
決定木を用いるので、変数変換に対しても頑健です。
-
欠測値を推定してくれるので、埋める必要がありません。
-
分岐の繰り返しによって説明変数の相互作用を考慮するので、明示的に与える必要がありません。
3. アルゴリズム
-
各決定木で、データをブートストラップします。
-
各ノードで、以下の操作を情報利得が変化しなくなるまで、繰り返します。
- 分岐に用いる説明変数の候補を選択します。
- 各ノードで、分岐条件を決定します。
- 情報利得が最大になるような分岐条件を求めます。
- つまり、目的は子ノードの不純度の総和を最小にすることです。
-
全ての木における予測値を平均します。
-
最終的な予測値
は以下のように表されます。\hat{y}^{(i)} \hat{y}^{(i)}=\frac{1}{M}\sum_{j=1}^{M}\hat{y}_j^{(i)} :木の本数M :木\hat{y}_j^{(i)} における、j の予測値y^{(i)}
-
4. 他のモデルとの違い
-
GBDT(勾配ブースティング決定木)との違い
- GBDTでは決定木を直列に並べるのに対して、ランダムフォレストでは並列に並べます。そのため、決定木の本数を増やすことで精度が悪くなることはありません。
- アンサンブル手法としてGBDTではブースティングを用いるのに対し、ランダムフォレストではバギングを用いています。
-
ランダムフォレストクラスタリングとの違い
- 不純度の指標として、クラスタリングではジニ不純度やエントロピーを用いるのに対して、回帰では平均二乗誤差を使用します。
- クラスタリングでは予測値は多数決により決定されますが、回帰では平均で計算されます。
5. サンプルコード
-
ライブラリはScikit-learnを使用します
- Section 3 アルゴリズムの2-aにあるように、説明変数の候補を選択するステップがありますが、Scikit-learnのデフォルトの候補はすべての説明変数 となっています。つまり、全ての説明変数が候補として選択される確率は1です。
- 図2からもわかるように、Scikit-learnでは決定木は2分木で実装されています。
-
コードは以下のようになります。
# dfについては補足にある図1のコードを参照 X = df.iloc[:, :-1].values # 特徴量データ(最後の列を除く) y = df['MEDV'].values # 目的変数(MEDV) from sklearn.model_selection import train_test_split # データ分割(40%をテストデータに) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=1) from sklearn.ensemble import RandomForestRegressor as RFR # ランダムフォレストモデルの設定 forest = RFR(criterion='squared_error', # 不純度基準 random_state=1, # 乱数シード n_jobs=-1) # 並列処理 forest.fit(X_train, y_train) # モデルの学習 # 予測値の取得 y_train_pred = forest.predict(X_train) y_test_pred = forest.predict(X_test) from sklearn.metrics import mean_squared_error, r2_score # MSEとR^2スコアの表示 print('MSE train: %.3f, test: %.3f' % (mean_squared_error(y_train, y_train_pred), mean_squared_error(y_test, y_test_pred))) print('R^2 train: %.3f, test: %.3f' % (r2_score(y_train, y_train_pred), r2_score(y_test, y_test_pred))) import matplotlib.pyplot as plt # 残差プロット plt.scatter(y_train, y_train_pred - y_train, c='steelblue', edgecolor='white', marker='o', s=35, alpha=0.9, label='training_data') plt.scatter(y_test, y_test_pred - y_test, c='palevioletred', edgecolor='white', marker='s', s=15, alpha=0.9, label='test_data') plt.xlabel('True values') plt.ylabel('Residuals') plt.legend(loc='upper right') plt.title('Residuals') plt.hlines(y=0, xmin=-10, xmax=50, lw=2, color='black') plt.tight_layout() plt.show() # 特徴量の重要度 feature_importances = forest.feature_importances_ import numpy as np # 特徴量重要度の棒グラフ plt.figure(figsize=(10, 5)) y = feature_importances x = np.arange(len(y)) plt.bar(x, y, align="center") plt.xticks(x, df.columns[:-1]) plt.xlabel('Features') plt.ylabel('Importance') plt.title('Feature Importance') plt.tight_layout() plt.show()
-
出力
MSE train: 5.544, test: 15.618
R^2 train: 0.927, test: 0.844
-
6. 結果の解釈
- MSE
- 過学習する傾向にあることがわかります。
- 決定係数
- モデルはかなり良く当てはまっていることがわかります。
- 残差プロット
- 明らかに負の相関関係が見えることから、モデルがデータの情報をとらえきれていないことがわかります。
- 各説明変数の重要度
- 特に“RM”と"LSTAT"という説明変数が、モデルに大きな影響を与えていることがわかります。
7. 感想
ランダムフォレストは、その構造がシンプルであるにもかかわらず、予想以上に高い精度を示すことに感心しました。また、モデルの構築プロセスが比較的容易であることも、実務上大きな利点だと思いました。
一方で、モデルの改善方法が明確でないという課題があり、その点では、統計的モデルのほうが扱いやすそうです。
参考文献
- Raschka, Sebastian, Vahid Mirjalili, 株式会社クイープ, and 福島真太朗. 2018. [第2版]Python機械学習プログラミング 達人データサイエンティストによる理論と実践. インプレス.
- 門脇大輔, 阪田隆司, 保坂桂佑, and 平松雄司. 2019. Kaggleで勝つデータ分析の技術.
- アンサンブルとバイアス、分散との関係を説明している
- 変数変換に対する頑健性を検証している
- ランダムフォレストの欠損値補完を説明している
- 相互作用に関する検証をしている
補足
-
図1のコード
import pandas as pd import numpy as np import matplotlib.pyplot as plt from sklearn.tree import DecisionTreeRegressor def lin_regplot(X_test,X, y, model): plt.scatter(X, y, c='steelblue', edgecolor='white', s=70,label="Data") y_pred=model.predict(X_test) plt.plot(X_test, y_pred, color='orange', lw=2, label="Prediction") for i in range(1, len(y_pred)): if y_pred[i] != y_pred[i - 1]: plt.axvline(x=X_test[i], color='green', linestyle='--', linewidth=1) return df = pd.read_csv('https://raw.githubusercontent.com/rasbt/' 'python-machine-learning-book-2nd-edition' '/master/code/ch10/housing.data.txt', header=None, sep='\s+') df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV'] df=df.sample(100) X = df[['LSTAT']].values y = df['MEDV'].values tree = DecisionTreeRegressor(max_depth=2) tree.fit(X, y) X_test = np.arange(0, 40, 0.01)[:, np.newaxis] lin_regplot(X_test,X,y, tree) plt.xlabel('% lower status of the population [LSTAT]') plt.ylabel('Price in $1000s [MEDV]') plt.legend() plt.show()
-
図2のコード
from sklearn.tree import export_graphviz from graphviz import Source from sklearn.tree import plot_tree dot_data = export_graphviz(tree, filled=True, rounded=True, feature_names=['LSTAT'], out_file=None) graph = Source(dot_data) graph.format = 'png' graph.render('tree')
Discussion