🥅

【17日目】Tabnet をやってみる【2021アドベントカレンダー】

2021/12/17に公開

2021年1人アドベントカレンダー(機械学習)、17日目の記事になります。

https://qiita.com/advent-calendar/2021/solo_advent_calendar

テーマは TabNet になります。

TabNet は テーブルデータ用のディープラーニングモデル です。

詳細は下記サイト参照。

https://zenn.dev/sinchir0/articles/9228eccebfbf579bfdf4

Colab のコードはこちら Open In Colab

TabNet

import torch
from pytorch_tabnet.tab_model import TabNetRegressor
model = TabNetRegressor(
                seed=SEED,
                verbose=0,
                device_name = "cuda",
)

# 学習・推論
gkf = GroupKFold(n_splits=5)

groups = X_train_tf[:, 0]

cv_result_tbnt = []

for i, (train_index, test_index) in enumerate(gkf.split(X_train_tf, y_train, groups)):
    X_train_gkf, X_test_gkf = X_train_tf[train_index], X_train_tf[test_index]
    y_train_gkf, y_test_gkf = y_train.iloc[train_index], y_train.iloc[test_index]

    model.fit(
                X_train=X_train_gkf,
                y_train=y_train_gkf.values.reshape(-1, 1),
                eval_set=[(X_train_gkf, y_train_gkf.values.reshape(-1, 1)), (X_test_gkf, y_test_gkf.values.reshape(-1, 1))],
                eval_name = ["train", "valid"],
                eval_metric = ["rmse"],
                max_epochs=200,
                patience=20, 
                batch_size=256, 
                virtual_batch_size=128,
                num_workers=0, 
                drop_last=False,
    )

    # 損失推移
    plt.title(f"Fold {i}")
    plt.plot(model.history['train_rmse'], label="train rmse")
    plt.plot(model.history['valid_rmse'], label="valid rmse")
    plt.xlabel("epoch")
    plt.ylabel("rmse")
    plt.legend()
    plt.show()

    # 推論
    y_pred = model.predict(X_test_gkf)

    # 評価
    rmse = mean_squared_error(y_test_gkf, y_pred, squared=False)
    cv_result_tbnt.append(rmse)

print("RMSE:", cv_result_tbnt)
print("RMSE:", np.mean(cv_result_tbnt))

下記により損失推移を可視化することが可能です。

# 損失推移
plt.title(f"Fold {i}")
plt.plot(model.history['train_rmse'], label="train rmse")
plt.plot(model.history['valid_rmse'], label="valid rmse")
plt.xlabel("epoch")
plt.ylabel("rmse")
plt.legend()
plt.show()

精度は以下のとおりです。

項目 RMSE
LightGBM(ハイパラなし) 0.192
TabNet 0.162

可視化

TabNet は可視化も可能です。

OneHot後のカラムを付けたDataFrameを作成

### Nullのみで削除されるカラムを削除
for column in ["VGChartz_Score", "Total_Shipped"]:
    number_columns.remove(column)

# OneHotを考慮したの全カラム作成
all_columns = number_columns + \
                        many_kinds_category_columns + \
                        pipe["columns_transformers"].transformers_[2][1]["onehot"].get_feature_names(few_kinds_category_columns).tolist()

X_train_tf_pd = pd.DataFrame(
                X_train_tf,
                columns=all_columns
            )

Feature Importance

feat_imp = pd.DataFrame(model.feature_importances_, index=X_train_tf_pd.columns)
feature_importance = feat_imp.copy()

feature_importance["imp_mean"] = feature_importance.mean(axis=1)
feature_importance = feature_importance.sort_values("imp_mean")

plt.figure(figsize=(12, 8))
plt.barh(feature_importance.index.values, feature_importance["imp_mean"])
plt.title("feature_importance", fontsize=18)

mask の可視化

from matplotlib import ticker

explain_matrix, masks = model.explain(X_train_tf)

fig, axs = plt.subplots(1, 3, figsize=(10, 7))

for i in range(3):
    axs[i].imshow(masks[i][:25])
    axs[i].set_title(f"mask {i}")
    axs[i].set_xticklabels([0] + X_train_tf_pd.columns.tolist(), rotation=90)
    axs[i].xaxis.set_major_locator(ticker.MultipleLocator(1))

https://qiita.com/DS27/items/ec63ac776c5411836405

17日目は以上になります、最後までお読みいただきありがとうございました。

https://dreamquark-ai.github.io/tabnet/index.html

https://github.com/dreamquark-ai/tabnet

https://data-analysis-stats.jp/深属学習/tabnet(表形式データ向けの深層学習)/

Discussion