🔥
【21日目】テーブルデータのデータ拡張をやってみる【2021アドベントカレンダー】
2021年1人アドベントカレンダー(機械学習)、21日目の記事になります。
テーマは TGAN を使ったテーブルデータの生成 になります。
Colab のコードはこちら
TGANによるテーブルデータ生成
目的変数と説明変数をセットにしたデータを学習します。
from tgan.model import TGANModel
train_df = pd.concat([X_train_tf, y_train], axis=1)
continuous_value_columns = ["Rank", "Critic_Score", "User_Score", "Total_Shipped", "Global_Sales", "Year"]
tgan = TGANModel(
continuous_value_columns,
max_epoch=5,
steps_per_epoch=5000,
batch_size=100,
)
tgan.fit(train_df)
学習したモデルをもとに 50,000 行ほどデータを生成します。
num_samples = 50000
train_sample = tgan.sample(num_samples)
print(train_sample.shape)
train_sample.head(3)
(50000, 17)
データをfloat値に変換し、欠損補完をしておきます。
train_sample_fillna = train_sample.astype(float)
train_sample_fillna = train_sample_fillna.fillna(method='ffill')
train_sample_fillna = train_sample_fillna.dropna(axis=0)
train_sample_fillna = train_sample_fillna.reset_index(drop=True)
X_train_sample_fillna, y_train_sample_fillna = train_sample_fillna.drop(["Global_Sales"], axis=1), train_sample_fillna["Global_Sales"]
元の学習データと生成したデータを結合しておきます。
X_train_comb = pd.concat([
X_train_tf[all_columns],
X_train_sample_fillna
]).reset_index(drop=True)
y_train_comb = pd.concat([
y_train,
y_train_sample_fillna
]).reset_index(drop=True)
生成データと結合した学習データをもとに、テストデータの推論を実行します。
GroupKFold による 各Fold ごとの予測値の平均を使います。
gkf = GroupKFold(n_splits=5)
groups = X_train_comb["Genre"]
cv_result_tgan = []
pred_df = pd.DataFrame()
model = lgb.LGBMRegressor(random_state=42)
for i, (train_index, test_index) in enumerate(gkf.split(X_train_comb, y_train_comb, groups)):
X_train_gkf, X_test_gkf = X_train_comb.iloc[train_index], X_train_comb.iloc[test_index]
y_train_gkf, y_test_gkf = y_train_comb.iloc[train_index], y_train_comb.iloc[test_index]
# 学習、推論
model.fit(X_train_gkf, y_train_gkf)
y_pred = model.predict(X_test_gkf)
rmse = mean_squared_error(y_test_gkf, y_pred, squared=False)
cv_result_tgan.append(rmse)
pred = pipe.predict(X_test)
pred_df[i] = pred
rmse_test_tgan = mean_squared_error(y_test, pred_df.mean(axis=1), squared=False)
残念ながら、生成データを使ったデータ拡張を行った方が精度が下がってしまいました。
項目 | RMSE |
---|---|
学習データのみを学習・テストデータを推論 | 0.239 |
T-GAN による生成データを含め学習・テストデータを推論 | 0.290 |
21日目は以上になります、最後までお読みいただきありがとうございました。
参考サイト
GPU使用にあたり参考にしました。
Discussion