🔥

【21日目】テーブルデータのデータ拡張をやってみる【2021アドベントカレンダー】

2021/12/21に公開

2021年1人アドベントカレンダー(機械学習)、21日目の記事になります。

https://qiita.com/advent-calendar/2021/solo_advent_calendar

テーマは TGAN を使ったテーブルデータの生成 になります。

Colab のコードはこちら Open In Colab

https://sdv.dev/TGAN/index.html

TGANによるテーブルデータ生成

目的変数と説明変数をセットにしたデータを学習します。

from tgan.model import TGANModel

train_df = pd.concat([X_train_tf, y_train], axis=1)

continuous_value_columns = ["Rank", "Critic_Score", "User_Score", "Total_Shipped", "Global_Sales", "Year"]

tgan = TGANModel(
    continuous_value_columns,
    max_epoch=5,
    steps_per_epoch=5000,
    batch_size=100,
    )

tgan.fit(train_df)

学習したモデルをもとに 50,000 行ほどデータを生成します。

num_samples = 50000

train_sample = tgan.sample(num_samples)

print(train_sample.shape)

train_sample.head(3)

(50000, 17)

データをfloat値に変換し、欠損補完をしておきます。

train_sample_fillna = train_sample.astype(float)

train_sample_fillna = train_sample_fillna.fillna(method='ffill')

train_sample_fillna = train_sample_fillna.dropna(axis=0)

train_sample_fillna = train_sample_fillna.reset_index(drop=True)

X_train_sample_fillna,  y_train_sample_fillna = train_sample_fillna.drop(["Global_Sales"], axis=1), train_sample_fillna["Global_Sales"]

元の学習データと生成したデータを結合しておきます。

X_train_comb = pd.concat([
           X_train_tf[all_columns],
           X_train_sample_fillna
]).reset_index(drop=True)

y_train_comb = pd.concat([
           y_train,
           y_train_sample_fillna
]).reset_index(drop=True)

生成データと結合した学習データをもとに、テストデータの推論を実行します。

GroupKFold による 各Fold ごとの予測値の平均を使います。

gkf = GroupKFold(n_splits=5)

groups = X_train_comb["Genre"]

cv_result_tgan = []

pred_df = pd.DataFrame()

model = lgb.LGBMRegressor(random_state=42)

for i, (train_index, test_index) in enumerate(gkf.split(X_train_comb, y_train_comb, groups)):
    X_train_gkf, X_test_gkf = X_train_comb.iloc[train_index], X_train_comb.iloc[test_index]
    y_train_gkf, y_test_gkf = y_train_comb.iloc[train_index], y_train_comb.iloc[test_index]

    # 学習、推論
    model.fit(X_train_gkf, y_train_gkf)

    y_pred = model.predict(X_test_gkf)

    rmse = mean_squared_error(y_test_gkf, y_pred, squared=False)
    cv_result_tgan.append(rmse)

    pred = pipe.predict(X_test)

    pred_df[i] = pred

rmse_test_tgan = mean_squared_error(y_test, pred_df.mean(axis=1), squared=False)

残念ながら、生成データを使ったデータ拡張を行った方が精度が下がってしまいました。

項目 RMSE
学習データのみを学習・テストデータを推論 0.239
T-GAN による生成データを含め学習・テストデータを推論 0.290

21日目は以上になります、最後までお読みいただきありがとうございました。

参考サイト

GPU使用にあたり参考にしました。
https://github.com/sdv-dev/TGAN/issues/34

Discussion