😎

勾配ブースティング決定木(GBDT)って実際どう動いているの?

に公開

勾配ブースティング決定木(GBDT)って実際どう動いているの?


勾配ブースティング決定木は実際どう動いているのか?
(数値を用いて簡易的にシミュレーション)

0. 前提

  • 使用するデータ

    説明変数 x 目的変数 y
    1 3
    2 2
    3 4
    4 5
  • タスク
    説明変数 x から目的変数 y を予測する回帰タスク


1. 全データの目的変数 yの平均を算出し、それを予測値 \hat{y}_0と置き換える

  • \hat{y}_0 = \frac{3 + 2 + 4 + 5}{4} = 3.5

    説明変数 x 目的変数 y 予測値 \hat{y}_0
    1 3 3.5
    2 2 3.5
    3 4 3.5
    4 5 3.5

2. 誤差 r_1y - \hat{y}_0)を算出する

  • r_1 = y - \hat{y}_0

    説明変数 x 目的変数 y 予測値 \hat{y}_0 誤差 r_1
    1 3 3.5 -0.5
    2 2 3.5 -1.5
    3 4 3.5 0.5
    4 5 3.5 1.5

3. 誤差を予測(説明)する決定木を作成し、誤差予測値 \hat{r}_1を算出する(弱学習器の作成)

  • 決定木
  • データ
    説明変数 x 目的変数 y 予測値 \hat{y}_0 誤差 r_1 誤差予測値 \hat{r}_1
    1 3 3.5 -0.5 -1.0
    2 2 3.5 -1.5 -1.0
    3 4 3.5 0.5 1.0
    4 5 3.5 1.5 1.0

4. 新たな予測値 \hat{y}_1を算出する ※学習率 nは0.5とする

  • \hat{y}_1 = \hat{y}_0 + n \cdot \hat{r}_1

    説明変数 x 目的変数 y 予測値 \hat{y}_0 誤差予測値 \hat{r}_1 予測値\hat{y}_1
    1 3 3.5 -1.0 3.5+0.5\cdot-1.0=3.0
    2 2 3.5 -1.0 3.5+0.5\cdot-1.0=3.0
    3 4 3.5 1.0 3.5+0.5\cdot1.0=4.0
    4 5 3.5 1.0 3.5+0.5\cdot1.0=4.0

5. 新たな誤差 r_2y - \hat{y}_1)を算出する

  • r_2 = y - \hat{y}_1

    説明変数 x 目的変数 y 予測値\hat{y}_1 誤差 r_2
    1 3 3.0 0.0
    2 2 3.0 -1.0
    3 4 4.0 0.0
    4 5 4.0 1.0

6. 誤差を予測(説明)する決定木を作成し、誤差予測値 \hat{r}_2を算出する(弱学習器の作成)

  • 決定木
  • データ
    説明変数 x 目的変数 y 予測値\hat{y}_1 誤差 r_2 誤差予測値 \hat{r}_2
    1 3 3.0 0.0 -0.5
    2 2 3.0 -1.0 -0.5
    3 4 4.0 0.0 0.5
    4 5 4.0 1.0 0.5

7. 新たな予測値 \hat{y}_2を算出する ※学習率 nは0.5とする

  • \hat{y}_2 = \hat{y}_1 + n \cdot \hat{r}_2

    説明変数 x 目的変数 y 予測値\hat{y}_1 誤差予測値 \hat{r}_2 予測値\hat{y}_2
    1 3 3.0 -0.5 3.0+0.5\cdot-0.5=2.75
    2 2 3.0 -0.5 3.0+0.5\cdot-0.5=2.75
    3 4 4.0 0.5 4.0+0.5\cdot0.5=4.25
    4 5 4.0 0.5 4.0+0.5\cdot0.5=4.25

8. その後も、「誤差算出 --> 誤差予測値算出 --> 予測値算出」を繰り返す



終わりに

上記例により、予測値が目的変数に近づいていく様子を理解できたら幸いです!

Discussion