
Key Points for Tuning TabNet

Published 2022/02/27

Introduction

While taking part in a certain competition, I explored the key points of tuning TabNet's hyperparameters. There weren't many good write-ups on this, so if you spot anything off, please let me know ^q^


Conclusion (tentative)

| Parameter | Search range | Default |
| --- | --- | --- |
| n_d, n_a | 8-64 | 8 |
| n_steps | 1-10 | 3 |
| gamma | 1.0-2.0 | 1.3 |
| mask_type | "entmax" or "sparsemax" | "sparsemax" |

For an explanation of each parameter, see the implementation repository or the Japanese explainer article (both linked in the references below).

In addition, it is recommended to raise batch_size to 1-10% of the dataset size, within your memory budget.

The table basically reproduces the advice from the original paper (quoted below), with a few parameters set the way they are commonly tweaked on Kaggle. My notes from surveying hyperparameters used on Kaggle:

  • n_d and n_a are usually raised as a first move.
  • The original paper recommends n_steps in the 3-10 range, but it is often lowered in practice.
  • gamma, n_independent, n_shared, and so on are sometimes tweaked as well; this seems to depend on your compute budget (see the sketch after this list).
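
For illustration, here is a hypothetical configuration reflecting the notes above (the values are made-up examples, not taken from any particular notebook):

from pytorch_tabnet.tab_model import TabNetRegressor

# Hypothetical example of common Kaggle-style tweaks:
# n_d/n_a raised from the default of 8, n_steps lowered below the paper's 3-10.
model = TabNetRegressor(
    n_d=32, n_a=32,               # usually raised first
    n_steps=2,                    # often lowered in practice
    gamma=1.5,
    n_independent=2, n_shared=2,  # tweak these if the compute budget allows
)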

Hyperparameter advice from the original paper

The original paper has the following to say about hyperparameters:

We consider datasets ranging from ∼10K to ∼10M samples, with varying degrees of fitting difficulty. TabNet obtains high performance on all with a few general principles on hyperparameters:
  • For most datasets, Nsteps ∈ [3, 10] is optimal. Typically, when there are more information-bearing features, the optimal value of Nsteps tends to be higher. On the other hand, increasing it beyond some value may adversely affect training dynamics as some paths in the network becomes deeper and there are more potentially-problematic ill-conditioned matrices. A very high value of Nsteps may suffer from overfitting and yield poor generalization.
  • Adjustment of Nd and Na is an efficient way of obtaining a trade-off between performance and complexity. Nd = Na is a reasonable choice for most datasets. A very high value of Nd and Na may suffer from overfitting and yield poor generalization.
  • An optimal choice of γ can have a major role on the performance. Typically a larger Nsteps value favors for a larger γ.
  • A large batch size is beneficial – if the memory constraints permit, as large as 1-10 % of the total training dataset size can help performance. The virtual batch size is typically much smaller.
  • Initially large learning rate is important, which should be gradually decayed until convergence.

  • They validated on datasets ranging from about 10K to 10M (ten thousand to ten million) samples.
  • n_steps of roughly 3-10 works well.
  • n_d and n_a should be set to the same value.
  • Scale gamma together with n_steps: a larger n_steps favors a larger gamma.
  • A larger batch size is better; around 1-10% of the data helps (though this depends on available memory, so it is a separate concern).
  • The learning rate should start large and be decayed gradually (see the sketch after this list).
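
Below is a minimal sketch of putting this advice into practice with pytorch-tabnet. The variable names (X_train, y_train, X_valid, y_valid) are placeholders, and the 5% batch-size ratio and StepLR decay schedule are my own example choices, not prescriptions from the paper:

import torch
from pytorch_tabnet.tab_model import TabNetRegressor

n_train = len(X_train)                       # placeholder training data
batch_size = max(256, int(n_train * 0.05))   # roughly 1-10% of the data

model = TabNetRegressor(
    n_d=16, n_a=16,            # keep n_d = n_a
    n_steps=3, gamma=1.3,      # a larger n_steps would favor a larger gamma
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),                  # start with a large lr...
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params=dict(step_size=10, gamma=0.9),  # ...and decay it gradually
)

model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=["rmse"],
    max_epochs=200, patience=20,
    batch_size=batch_size,
    virtual_batch_size=128,    # the virtual batch size stays much smaller
)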

Objective function for Optuna

import torch
from pytorch_tabnet.tab_model import TabNetRegressor

def objective(trial):
    mask_type = trial.suggest_categorical("mask_type", ["entmax", "sparsemax"])
    n_da = trial.suggest_int("n_da", 8, 64, step=8)
    n_steps = trial.suggest_int("n_steps", 1, 10, step=3)
    gamma = trial.suggest_float("gamma", 1.0, 2.0, step=0.2)
    n_shared = trial.suggest_int("n_shared", 1, 3)
    lambda_sparse = trial.suggest_float("lambda_sparse", 1e-6, 1e-3, log=True)

    tabnet_params = dict(
        n_d=n_da, n_a=n_da, n_steps=n_steps, gamma=gamma,
        lambda_sparse=lambda_sparse, mask_type=mask_type, n_shared=n_shared,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=2e-2, weight_decay=1e-5),
        # scheduler_fn is a top-level argument, not a key of scheduler_params
        scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
        scheduler_params=dict(mode="min", patience=10),
        verbose=0,
    )

    model = TabNetRegressor(**tabnet_params)
    # trn_x, trn_y, val_x, val_y are assumed to be prepared beforehand;
    # y must have shape (n_samples, 1) for TabNetRegressor.
    model.fit(
        X_train=trn_x, y_train=trn_y,
        eval_set=[(val_x, val_y)],
        patience=15,
        max_epochs=100,
        eval_metric=['rmse']
    )

    # Return the best validation RMSE so that Optuna can minimize it.
    return min(model.history["val_0_rmse"])
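
To run the search, pass the objective to an Optuna study. A minimal usage sketch (assumes trn_x, trn_y, val_x, val_y are already defined in scope; the trial count is arbitrary):

import optuna

study = optuna.create_study(direction="minimize")  # minimize validation RMSE
study.optimize(objective, n_trials=50)
print(study.best_params)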

References

  • A general explanation of TabNet

https://www.guruguru.science/competitions/16/discussions/70f25f95-4dcc-4733-9f9e-f7bc6472d7c0/

  • An Optuna tuning implementation on Kaggle

https://www.kaggle.com/neilgibbons/tuning-tabnet-with-optuna

  • PyTorch implementation of TabNet

https://github.com/dreamquark-ai/tabnet

  • The original paper

https://arxiv.org/abs/1908.07442

  • An example used in a Kaggle competition (1st place solution of the MoA competition):

PyTorch TabNet regressor training setup: a width of 32 for the decision prediction layer, and a width of 32 for the attention embedding for each mask, 1 step in the architecture, a gamma value of 0.9, Adam optimizer with a learning rate of 2e-2 and a weight decay of 1e-5, a sparsity loss coefficient of 0, and entmax as the masking function. It was trained with a batch size of 1024 and a virtual batch size of 128 for 200 epochs before early-stopped by a patience of 50 epochs. The final TabNet model produced a CV log loss of 0.01615.

https://www.kaggle.com/c/lish-moa/discussion/201510
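
For reference, here is my reading of that description translated into pytorch-tabnet arguments (a sketch reconstructed from the quote above, not the winners' actual code; X_train, y_train, X_valid, y_valid are placeholders):

import torch
from pytorch_tabnet.tab_model import TabNetRegressor

model = TabNetRegressor(
    n_d=32, n_a=32,             # widths of the decision and attention layers
    n_steps=1,                  # 1 step in the architecture
    gamma=0.9,
    lambda_sparse=0,            # sparsity loss coefficient of 0
    mask_type="entmax",
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2, weight_decay=1e-5),
)

model.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_valid, y_valid)],
    max_epochs=200, patience=50,    # early stop after 50 epochs without improvement
    batch_size=1024, virtual_batch_size=128,
)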
