🔥

TabNetとは一体何者なのか?

2020/12/06に公開

簡単に

  • Kaggleで最近よく使われるTabnetについて、どのようなモデルか調べた。
  • Tree-basedとDNNのいいとこ取りをしたようなモデル。
  • Feature ImportanceとMaskにより結果の解釈ができる。
  • Titanicにおける精度について、LBの値ではLightGBM、NNよりもやや高い。
  • TitanicにおけるFeature Importanceの上位特徴量について、LightGBMとは異なっている。そのため、TabNetはEnsembleに有用かもしれない。

※ 2021/01/10 14:50 TabNetのコードが一部誤っていることを指摘頂き、コード修正しました。それに伴い記事の下記部分を更新しています。

  • 4 実装の際に用いたNotebook
  • 6.7 精度
  • 6.8 Feature Importance(Global interpretability)
  • 6.10 予測確率の分布

ご指摘ありがとうございます!

1 はじめに

2020/12/1に終了したKaggleコンペのMoAにて皆さん当たり前のようにTabNet というモデルを利用していました。かくいう私もTabNetを最終subのベースモデルとして利用しています。最終sub概要はこちら

また最近のテーブルデータコンペでは、TabNetが頻繁に使われている気がします。そのため、LightGBMなどと並び、テーブルコンペで今後よく使われるモデルになるのかなぁ、と思っています。このTabNet、どんなモデルなのか気になったので調べてみました。

理解不足の部分もあるので、間違いなどあったらコメントもらえると嬉しいです🙇‍♂️

2 この記事に書いてあること

  • TabNetの論文の解釈(私がある程度理解できた部分のみを記載しています)
  • TabNetの実装例(Kaggle Notebookにて記載)
  • TabNetのFeature Importance,Maskの出力例、
  • Titanicデータを用いた精度比較, Feature Importanceの比較、出力分布の比較。
    (TabNet with pretrain, TabNet without pretrain, LightGBM, NN)

3 Reference

論文
https://arxiv.org/pdf/1908.07442.pdf

AI Platform 上の TabNet: 高パフォーマンスで説明可能な表形式ラーニング
https://cloud.google.com/blog/ja/products/ai-machine-learning/ml-model-tabnet-is-easy-to-use-on-cloud-ai-platform

4 実装の際に用いたNotebook

https://www.kaggle.com/sinchir0/selfsupervisedtabnet-for-titanic

5 論文を読んでみる

5.1 論文概要

ICLR2020をrejectされた論文となります。
https://openreview.net/forum?id=BylRkAEKDH

とてもざっくり書くと、論文には下記のようなことが書かれています。

5.1.1 簡単にいうとTabNetって何?

  • 表データ向けのディープラーニングモデル
  • Tree-basedとDNNのいいとこ取りをしたようなモデル
  • Tree-based(決定木をベースにしたアルゴリズム)の解釈可能性を持ちつつ、 大きなデータセットに対してDNNのような高い性能を持つ。
  • Tree-basedのモデルと異なり、一切の特徴量生成を必要としない。

5.1.2 どんな特徴があるの?

  • 教師なしの事前学習を用いてマスクされた特徴を予測し、pre-trainを行っている。
  • Sequetial Attention(逐次注意)を用いることにより、説明可能性を持ち、またfeature selectionはinstance-wise(input毎に用いる特徴量が異なる)に行う。
  • Tabnetは2種類の解釈性を提供する。一つ目はlocal interpretabilityであり、入力特徴量の重要性とそれらがどのように結合されたかを可視化する。二つ目はglobal interpretabilityであり、各特徴量 が学習モデルに対し、どのぐらい貢献したかを示す。

5.1.3 精度は?

  • 表データに対して、MLP, LightGBM, XGBoost, CatBoostなどよりも高い性能を発揮した。

・・・文章だけだとなんだかよく分からんですね🤔 論文に色々と図が載っていたのでそちらを見てみたいと思います。

5.2 どんなFeature Selectionなの?

上図は、TabNetの特徴量選択の説明している図です。例で使われているデータセットは、Adult Census Income prediction(成人の国勢調査の収入予測)のデータであるため、最終的には収入の予測が目的となります。

「Professional occupation related(専門職関連)」、「Investment related(投資関連)」という説明が図にあることや論文の記載から予想すると、下記のようなことを行っているように思えます。

  1. 収入予測を行う際に重要となる特徴を意味のある単位で選択(図の場合は「専門職」と「投資」)
  2. 意味のある単位で選択された特徴量と目的変数に関して、相互情報量が最大になるよう特徴選択

5.3 どんなPretrain?

この図は

教師なしの事前学習を用いてマスクされた特徴を予測し、pre-trainを行っている。

を具体的に説明している図です。左図ではテーブルデータの一部をマスクし、そのマスクした部分を予測出来るように学習を進め、右図での本番学習において、予め学習させた結果を転移学習させているようです。

図の説明には、「表データは,相互依存した特徴量の列を持っており,例えば,職業から教育レベルを推測したり,関係性から性別を推測したりすることができる.」とあります。表データの関係性を推測できるようなweightは、別の特徴量を予測する際の初期値として優秀、ということなのでしょうか。

5.4 どんなところがTree-basedなの?

TabNetは上記のようなDNNを使ってTree-basedのClassificationを模倣しています。

具体的には二つの特徴量x1,x2に対して、Maskにて特徴量が区分され、それぞれに対しWeight(W)とb(bias)がかかったものをFC(Fully-Connected)とReluを通して、0 or 線形結合した値として使うことで、Tree-basedの決定境界を模倣しているようです。なぜこれでTree-basedの決定境界が模倣できるのかは、よく分かりませんでした😇

5.5 TabNetの全体構造は?

この図には、TabNet全体のアーキテクチャが載っています。その中の図左上に示されているEncoderの部分に着目します。拡大すると下記です。

重要なのは「Step1」「Step2」と記載されている単位で、このStepのたびに、特徴選択を行うこととなります。こちらのSTEP数を多くすれば(実装例だとn_stepで指定)、特徴選択を複数回行う深いモデルにすることができます。ただし、Stepと密接に関わるハイパーパラメータとして、同じ特徴量を何度選択していいか、というものがあります。(実装例だとgammaで指定)。n_step=3でgamma=2の場合、3stepを行う間に、同じ特徴量は2度まで使用して良いことになります。

5.6 Maskとは?

マスクを図示したものです。このマスクは、横軸が使用した各特徴量、縦軸がデータの各行を意味しています。色が白いほど重要な特徴量であることを意味します。つまり、1行1列目が黒い場合は、「データの1行目に対して、1列名の特徴量は重要ではなかった。」という解釈となります。

このmaskは恐らくpretrainの予測を行う際に有用だったかどうかで算出されているように思えるのですが、詳細は分かりませんでした😇

5.7 精度は?

論文には様々な評価が載っていますが、最も分かりやすいのは上記かなと思います。Forest Cover Type datasetという地図上の変数から主な樹木被覆の種類を当てるタスクにおいて、GBDT系よりも高い精度が出せていることが分かります。

6 試してみる

なんだが凄そうなモデルだということは分かったので、実際に試してみましょう。

6.1 実装

実装としてよく使われているのは、こちらのrepositoryかと思います。(公式の実装ではなく有志によるものです。2020/12/6時点)
https://github.com/dreamquark-ai/tabnet

こちらのipynbには使い方の例が載っています。
https://github.com/dreamquark-ai/tabnet/blob/develop/census_example.ipynb

6.2 使用Notebook

再掲ですが、以下の評価を行ったNotebookはこちらにて公開しています。詳細が気になる方は是非見てみてください。

6.3 使用データ

みんな大好きKaggle Titanicのデータを用います。
Titanicのデータは数も少なく、精度の比較に向いていないのは重々承知ですが、身近なデータセットということで今回採用しています。

6.4 TabNetのInstall, Import

このデータセットからから下記のようにインストールします。

!pip install ../input/tabnet/pytorch_tabnet-2.0.1-py3-none-any.whl

その後、下記のようにImportすればOKです。簡単。(今回はTabNetClassifierを使います。)

from pytorch_tabnet.pretraining import TabNetPretrainer

from pytorch_tabnet.tab_model import TabNetRegressor
from pytorch_tabnet.tab_model import TabNetClassifier

6.5 Pretrain

TabNetPretrainerというclassを用意してくれているので、これを使うようです。
実装例は下記。

NNやLGBMの書き方とほぼ同様かと思います。パラメータの意味はrepositoryに説明があります。

tabnet_params = dict(n_d=8, n_a=8, n_steps=3, gamma=1.3,
                     n_independent=2, n_shared=2,
                     seed=SEED, lambda_sparse=1e-3, 
                     optimizer_fn=torch.optim.Adam, 
                     optimizer_params=dict(lr=2e-2),
                     mask_type="entmax",
                     scheduler_params=dict(mode="min",
                                           patience=5,
                                           min_lr=1e-5,
                                           factor=0.9,),
                     scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
                     verbose=10
                    )

pretrainer = TabNetPretrainer(**tabnet_params)

pretrainer.fit(
    X_train=train.drop('Survived',axis=1).values,
    eval_set=[train.drop('Survived',axis=1).values],
    max_epochs=200,
    patience=20, batch_size=256, virtual_batch_size=128,
    num_workers=1, drop_last=True)

6.6 Main Training

下記のように行います。NNやLGBMと同様の書きっぷりです。

    tabnet_params = dict(n_d=8, n_a=8, n_steps=3, gamma=1.3,
                         n_independent=2, n_shared=2,
                         seed=SEED, lambda_sparse=1e-3,
                         optimizer_fn=torch.optim.Adam,
                         optimizer_params=dict(lr=2e-2,
                                               weight_decay=1e-5
                                              ),
                         mask_type="entmax",
                         scheduler_params=dict(max_lr=0.05,
                                               steps_per_epoch=int(X_train.shape[0] / 256),
                                               epochs=200,
                                               is_batch_level=True
                                              ),
                         scheduler_fn=torch.optim.lr_scheduler.OneCycleLR,
                         verbose=10,
                         cat_idxs=cat_idxs, # comment out when Unsupervised
                         cat_dims=cat_dims, # comment out when Unsupervised
                         cat_emb_dim=1 # comment out when Unsupervised
                        )

    model = TabNetClassifier(**tabnet_params)

    model.fit(X_train=X_train,
              y_train=y_train,
              eval_set=[(X_valid, y_valid)],
              eval_name = ["valid"],
              eval_metric = ["auc"],
              max_epochs=200,
              patience=20, batch_size=256, virtual_batch_size=128,
              num_workers=0, drop_last=False,
              from_unsupervised=pretrainer # comment out when Unsupervised
             )

Pretrainの結果を反映させるために、fitに下記引数が増えていることに注意してください。
この引数を追加するだけでPretrainが反映できます。

from_unsupervised=pretrainer

6.7 精度

精度の比較を行います。

TabNetについては、Pretrainを行なった場合、Pretrainを行わなかった場合の二つを比較しています。

比較対象のモデルとしては、LightGBMとNNを用いました。
LightGBMは初期パラメータ。NNはシンプルな3層モデルです。

CVはStratified 5-fold。OOFはCVに従って出力。LBはStratified 5-foldで予測した確率の平均をとり、Thresholdを0.5にして作成しています。

CPUを用いて計算しています。

結果はこちら。

- TabNet with Pretrain TabNet without Pretrain LightGBM NN
OOF ROC-AUC 0.8620 0.8354 0.8750 0.8643
OOF Accuracy 0.8114 0.7564 0.8260 0.8249
LB Accuracy 0.7775 0.7560 0.7608 0.7656
Time(s) 34.6 37.3 0.24 6.86

比較すると、下記のようなことが分かります。

  • TabNetは、OOFの精度について、LightGBMやNNよりもやや低い。
  • TabNetは、LBの精度について、LightGBMやNNよりもやや高い。
  • TabNetは、OOFとLBの精度の解離が、LightGBMやNNよりも小さい。
  • TabNet with Pretrainは、TabNet without PretrainよりもOOF,LBの精度が高い。
  • CPUを用いた場合、TabNetは時間がかかる

6.8 Feature Importance(Global interpretability)

TabNetはFeature Importanceが出力できます。(論文でいうglobal interpretability)
学習済モデルをmodelという変数に入れていた場合、下記のように出力することができます。

 model.feature_importances_

実際に出力したところ、下記のようになりました。

比較すると、下記のようなことが分かります。

  • TabNet(with Pretrain or without Pretrain)はLightGBMに比べ、特徴量間の重要度の差が小さい。
  • TabNet with PretrainとLightGBMでは、共にFareが1位にくる。一方、1位以外の特徴量に関しては比較的ばらつきがある。

TabNetの重要度の差がLightGBMよりも小さいことは、TabNetは特徴量の使用回数を制限(gammaで指定)するためLightGBMのように特定の特徴量を繰り返す使うことが出来ないことに起因すると考えられます。

6.9 Mask(Local interpretability)

TabNetはmaskを出力することができます。(論文でいうlocal interpretability)
出力するためのコードはこちら

explain_matrix, masks = model.explain(test[feature_col].values)

fig, axs = plt.subplots(1, 3, figsize=(10,7))

for i in range(3):
    axs[i].imshow(masks[i][:25])
    axs[i].set_title(f"mask {i}")

出力されたMaskはこちら。

このMaskはdecision(どの特徴量を使うか決定)するたびに作成されます。decisionを何回行うかは
n_stepsにより指定できます。今回はn_steps=3を指定していたため、Maskは3枚となりました。

横軸が特徴量(0='Pclass', 1='Sex', 2='Age',3='SibSp', 4='Parch', 5='Fare', 6='Embarked')、縦軸がtestの行数(今回は先頭から25行を抜粋)となります。

ぱっと見、mask2については5='Fare'が強く色づいていることが分かります。これは特徴量重要度にてFareが1位にきたことを考えても、妥当な結果のように感じます。

6.10 予測確率の分布

TabNet with Pretrain, TabNet without Pretrain, LightGBM, NNのtestに予測確率の分布はどんな感じかなと思って、それぞれをhistにしました。下図となります。

比較すると、下記のようなことが分かります。

  • TabNetの出力分布は、NNと近い形状となる。

7 まとめ

調べて分かったことを考察を含めてまとめます。

  • TabNetはMaskと,Feature Importanceを用いることで、解釈性が高いモデルとなっている。
  • Titanicデータにおいて、TabNetの精度はLightBMやNNと同等だった。更にOOFとLBの値の解離が少ないことから、過学習を起こしづらいモデルなのかもしれない。
  • TabNetのFeature Importanceにおける上位特徴量は、LightGBMとは異なっている。そのため、TabNetはEnsembleに有用であることが期待出来る。

8 最後に

Twitterやってますのでfollowしてくれたら喜びます🙇‍♂️
https://twitter.com/sinchir0

Discussion