☀️

ランダムフォレスト回帰【機械学習】

2024/11/15に公開

ランダムフォレストの回帰タスクについてまとめました。

https://qiita.com/0w0_kaomoji_/items/ef1ee13e62adf4fa4e44

1. ランダムフォレストとは

  • 複数の決定木(というモデル)を組み合わせるアンサンブル手法を用いたモデルです。

  • クラス分類回帰に用いられます。

2. メリット

  • アンサンブル手法なので、決定木よりも外れ値に対して頑健です。またバギングにより、モデルのバイアスは変わらないまま分散が小さくなります。

    image 2.png

  • 決定木を用いるので、変数変換に対しても頑健です。

  • 欠測値を推定してくれるので、埋める必要がありません。

  • 分岐の繰り返しによって説明変数の相互作用を考慮するので、明示的に与える必要がありません。

3. アルゴリズム

  1. 各決定木で、データをブートストラップします。

  2. 各ノードで、以下の操作を情報利得が変化しなくなるまで、繰り返します。

    1. 分岐に用いる説明変数の候補を選択します。
    2. 各ノードで、分岐条件を決定します。
      • 情報利得が最大になるような分岐条件を求めます。
      • つまり、目的は子ノードの不純度の総和を最小にすることです。
  3. 全ての木における予測値を平均します。

    • 最終的な予測値\hat{y}^{(i)}は以下のように表されます。

      \hat{y}^{(i)}=\frac{1}{M}\sum_{j=1}^{M}\hat{y}_j^{(i)}

      M:木の本数

      \hat{y}_j^{(i)}:木jにおける、y^{(i)}の予測値

4. 他のモデルとの違い

  • GBDT(勾配ブースティング決定木)との違い
    • GBDTでは決定木を直列に並べるのに対して、ランダムフォレストでは並列に並べます。そのため、決定木の本数を増やすことで精度が悪くなることはありません。
    • アンサンブル手法としてGBDTではブースティングを用いるのに対し、ランダムフォレストではバギングを用いています。
  • ランダムフォレストクラスタリングとの違い
    • 不純度の指標として、クラスタリングではジニ不純度やエントロピーを用いるのに対して、回帰では平均二乗誤差を使用します。
    • クラスタリングでは予測値は多数決により決定されますが、回帰では平均で計算されます。

5. サンプルコード

  • ライブラリはScikit-learnを使用します

    • Section 3 アルゴリズムの2-aにあるように、説明変数の候補を選択するステップがありますが、Scikit-learnのデフォルトの候補はすべての説明変数 となっています。つまり、全ての説明変数が候補として選択される確率は1です。
    • 図2からもわかるように、Scikit-learnでは決定木は2分木で実装されています。
  • コードは以下のようになります。

    # dfについては補足にある図1のコードを参照
    X = df.iloc[:, :-1].values  # 特徴量データ(最後の列を除く)
    y = df['MEDV'].values       # 目的変数(MEDV)
    from sklearn.model_selection import train_test_split
    
    # データ分割(40%をテストデータに)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=1)
    
    from sklearn.ensemble import RandomForestRegressor as RFR
    
    # ランダムフォレストモデルの設定
    forest = RFR(criterion='squared_error', # 不純度基準
                 random_state=1,            # 乱数シード
                 n_jobs=-1)                 # 並列処理
    
    forest.fit(X_train, y_train)  # モデルの学習
    
    # 予測値の取得
    y_train_pred = forest.predict(X_train)
    y_test_pred = forest.predict(X_test)
    
    from sklearn.metrics import mean_squared_error, r2_score
    
    # MSEとR^2スコアの表示
    print('MSE train: %.3f, test: %.3f' % (mean_squared_error(y_train, y_train_pred),
                                           mean_squared_error(y_test, y_test_pred)))
    print('R^2 train: %.3f, test: %.3f' % (r2_score(y_train, y_train_pred),
                                           r2_score(y_test, y_test_pred)))
    
    import matplotlib.pyplot as plt
    
    # 残差プロット
    plt.scatter(y_train, y_train_pred - y_train, 
                c='steelblue', edgecolor='white', marker='o', 
                s=35, alpha=0.9, label='training_data')
    plt.scatter(y_test, y_test_pred - y_test, 
                c='palevioletred', edgecolor='white', marker='s', 
                s=15, alpha=0.9, label='test_data')
    plt.xlabel('True values')
    plt.ylabel('Residuals')
    plt.legend(loc='upper right')
    plt.title('Residuals')
    plt.hlines(y=0, xmin=-10, xmax=50, lw=2, color='black')
    plt.tight_layout()
    plt.show()
    
    # 特徴量の重要度
    feature_importances = forest.feature_importances_
    
    import numpy as np
    
    # 特徴量重要度の棒グラフ
    plt.figure(figsize=(10, 5))
    y = feature_importances
    x = np.arange(len(y))
    plt.bar(x, y, align="center")
    plt.xticks(x, df.columns[:-1])
    plt.xlabel('Features')
    plt.ylabel('Importance')
    plt.title('Feature Importance')
    plt.tight_layout()
    plt.show()
    
    • 出力

      MSE train: 5.544, test: 15.618
      R^2 train: 0.927, test: 0.844

      image 3.png

      image 4.png

6. 結果の解釈

  • MSE
    • 過学習する傾向にあることがわかります。
  • 決定係数
    • モデルはかなり良く当てはまっていることがわかります。
  • 残差プロット
    • 明らかに負の相関関係が見えることから、モデルがデータの情報をとらえきれていないことがわかります。
  • 各説明変数の重要度
    • 特に“RM”と"LSTAT"という説明変数が、モデルに大きな影響を与えていることがわかります。

参考文献

  • Raschka, Sebastian, Vahid Mirjalili, 株式会社クイープ, and 福島真太朗. 2018. [第2版]Python機械学習プログラミング 達人データサイエンティストによる理論と実践. インプレス.
  • 門脇大輔, 阪田隆司, 保坂桂佑, and 平松雄司. 2019. Kaggleで勝つデータ分析の技術.

https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html

https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_hist_grad_boosting_comparison.html#sphx-glr-auto-examples-ensemble-plot-forest-hist-grad-boosting-comparison-py

https://ja.wikipedia.org/wiki/バギング

https://bellcurve.jp/statistics/course/9706.html?srsltid=AfmBOopgESK3d1ZiHsbUDri-qsqmjQt-QqsOj4eOQLeWCv7Q2LhShub1

https://funatsu-lab.github.io/open-course-ware/machine-learning/random-forest/

https://dropout009.hatenablog.com/entry/2021/07/26/193907

  • アンサンブルとバイアス、分散との関係を説明している

https://amalog.hateblo.jp/entry/decision-tree-scaling

  • 変数変換に対する頑健性を検証している

https://qiita.com/g-k/items/e7dcc4d2b057dada405c

  • ランダムフォレストの欠損値補完を説明している

https://blog.data-hacker.net/2020/07/randomforest.html

  • 相互作用に関する検証をしている

補足

  • 図1のコード

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.tree import DecisionTreeRegressor
    def lin_regplot(X_test,X, y, model):
        plt.scatter(X, y, c='steelblue', edgecolor='white', s=70,label="Data")
        y_pred=model.predict(X_test)
        plt.plot(X_test, y_pred, color='orange', lw=2, label="Prediction")
        for i in range(1, len(y_pred)):
          if y_pred[i] != y_pred[i - 1]:  
              plt.axvline(x=X_test[i], color='green', linestyle='--', linewidth=1)
        
        return 
    
    df = pd.read_csv('https://raw.githubusercontent.com/rasbt/'
                     'python-machine-learning-book-2nd-edition'
                     '/master/code/ch10/housing.data.txt',
                     header=None,
                     sep='\s+')
    
    df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS', 
                  'NOX', 'RM', 'AGE', 'DIS', 'RAD', 
                  'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']
    
    df=df.sample(100)
    
    X = df[['LSTAT']].values
    y = df['MEDV'].values
    
    tree = DecisionTreeRegressor(max_depth=2)
    tree.fit(X, y)
    
    X_test = np.arange(0, 40, 0.01)[:, np.newaxis]
    
    lin_regplot(X_test,X,y, tree)
    plt.xlabel('% lower status of the population [LSTAT]')
    plt.ylabel('Price in $1000s [MEDV]')
    plt.legend()
    plt.show()
    
  • 図2のコード

    from sklearn.tree import export_graphviz
    from graphviz import Source
    from sklearn.tree import plot_tree
    dot_data = export_graphviz(tree, 
                               filled=True, 
                               rounded=True,
                               feature_names=['LSTAT'], 
                               out_file=None)
    graph = Source(dot_data) 
    graph.format = 'png'
    graph.render('tree')
    

Discussion