🎉

拡散モデルと表データ生成③:【論文】STaSy

2024/03/17に公開

STaSy: Score-based Tabular Data Synthesis (ICLR2023)

TabDDPM, CoDiに続いてSTaSyの論文を読んだので, そのまとめになります. まとめと言いつつかなり長くなっています. 図や表は断りのない限り論文からの引用です. これまではICMLの論文でしたが, 今回はICLRの論文です.

arXiv
https://arxiv.org/abs/2210.04018

OpenReview
https://openreview.net/forum?id=1mNssCWt_v

書籍情報

Kim, J., Lee, C., and Park, N. STasy: Score-based tabular data synthesis. In The Eleventh International Conference on Learning Representations, 2023.

関連情報

TL;DR

  • スコアベースの表データ生成モデルの提案
  • self-paced trainingとfine-tuningを経ることで訓練を安定化させ, 高品質生成を実現
  • 15のデータセットと7つのベースラインを用いた比較ではqualityとdiversityで既存手法を凌駕

まず, 実験のまとめを見てみます. 全データセットの平均を報告しています. これを見ると, 確かに高品質で多様性のあるデータ生成がされていることがわかります (CoDiのときと同様speedはイマイチです).

提案手法は3つの大きな特徴があり, それらを順番に見ていきます.

  1. スコアベースのモデル
  2. self-paced learning
  3. fine-tuning approach

スコアベースのモデル

まず, スコアベースのモデルを表データに適用する前に, 拡散モデルのスコアベースとしての見方をしておきます. 以下の論文などが参考になります.

https://openreview.net/forum?id=PxTIG12RRHS

DDPMをベースとした定式化は, 以下のIto SDEに従います.

d\boldsymbol{x}=\boldsymbol{f}(\boldsymbol{x}, t)dt+g(t)d\boldsymbol{w}

ここで, DDPMの場合には \boldsymbol{f}(\boldsymbol{x}, t)=-\dfrac{1}{2}\beta(t)\boldsymbol{x}, g(t)=\sqrt{\beta(t)} です. より一般の場合 (ガウシアンノイズの平均と分散の両方がパラメータ化される場合)では, \boldsymbol{f}(\boldsymbol{x}, t)=f(t)\boldsymbol{x} です.

fg によってスコアベースの生成モデル (SGM)はVE, VP, sub-VPの3つに分類されます. それに従うと, 逆拡散過程は以下のように表せます.

d\boldsymbol{x}=(\boldsymbol{f}(\boldsymbol{x}, t)-g^2(t)\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}))dt+g(t)d\boldsymbol{w}

ここで, \nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) がスコアと呼ばれる量で, 学習済みの拡散モデルで定数倍の違いを除いて推定可能です. S_{\theta}(\boldsymbol{x}, t) でスコアを近似します. この S_\theta をスコアネットワークと呼ぶことにします.

今回は最初からスコアネットワークを学習することを考えます. すると, 以下の式を満たすネットワークを求めることになります.

\argmin_{\theta}\mathbb{E}_{t, \boldsymbol{x}(t), \boldsymbol{x}(0)}\left[\lambda(t)\|S_{\theta}(\boldsymbol{x}, t)-\nabla_{\boldsymbol{x}(t)}\log p(\boldsymbol{x}(t)\mid\boldsymbol{x}(0))\|_2^2\right]

\lambda(t) は生成品質と尤度のトレードオフを制御します.

さて, これを表データ生成に適用することを考えます. 表データの特徴は以下の3つです.

  • 表データは複雑な分布をしていて, 難易度が高い
  • 一方で次元数は非常に低い (例えばMNISTでは784ピクセルですが, 実験で用いているデータの1つであるCreditは30カラム程度しかありません)
  • カラム間の結合確率がある (カラム同士の関係性が決まっています)

スコアネットワークが十分に学習されている場合に, この壁を越えられるようです.

論文にモデルの概要図などはありませんが, スコアネットワークの設計が示されているので見てみます. ネットワークは全結合層の残差接続で構成されています.

\begin{align*} \boldsymbol{h}_0&=\boldsymbol{x}(t) \\ \boldsymbol{h}_i&=\omega(\texttt{H}_i(\boldsymbol{h}_{i-1}, t)\oplus\boldsymbol{h}_{i-1})\quad 1\leq i\leq d_N \\ S_{\theta}(\boldsymbol{x}, t)&=\texttt{FC}(\boldsymbol{h}_{d_N}) \end{align*}

\omega は活性化関数で, d_N は隠れ層の数です. 各レイヤーは以下のような構造になっています. これは, レイヤーの選択肢に自由度があるというだけです.

\texttt{H}_i(\boldsymbol{h}_{i-1}, t)= \begin{cases} \texttt{FC}_i(\boldsymbol{h}_{i-1})\oplus\psi(\texttt{FC}_i^t(t)) & \mathrm{if\ Squash} \\ \texttt{FC}_i(t\oplus\boldsymbol{h}_{i-1}) & \mathrm{if \ Concat} \\ \texttt{FC}_i(\boldsymbol{h}_{i-1})\oplus\psi(\texttt{FC}_i^{gate}(t)+\texttt{FC}_i^{bias}(t)) & \mathrm{if\ Concatsquash} \end{cases}

\psi はシグモイド関数です.

入力は, 数値データをmin-max scaler, カテゴリデータをone-hot encodingで前処理しています.

self-paced learning

self-paced learning (SPL)を用いて事前学習を行います. SPLとは, カリキュラム学習に関連する訓練戦略のひとつです. カリキュラム学習とは, 学習中のモデルに与えるデータを制御することで効率的に訓練を行う手法で, 例えば簡単なデータを最初に与えて徐々に難易度を上げるなどがあります. モデルは以下の目的関数を最小化します.

\min_{\theta, \boldsymbol{v}}\mathbb{E}(\theta, \boldsymbol{v})=\sum_{i=1}^Nv_iL(M(\boldsymbol{x}_i, \theta))-\frac{1}{K}\sum_{i=1}^Nv_i

\boldsymbol{v}=[v_i]_{i=1}^N, v_i\in\{0, 1\} は, データ \boldsymbol{x}_i が簡単かどうかを表すラベルです. L は損失, M はモデルです. 第2項は正則化項で, 下流タスクに応じて変更可能で, K は学習ペースを制御するパラメータです.

i 番目のデータ \boldsymbol{x}_i に対する損失を l_i とすると,

l_i=\mathbb{E}_{t, \boldsymbol{x}_i(t)}\left[\lambda(t)\|S_{\theta}(\boldsymbol{x}_i, t)-\nabla_{\boldsymbol{x}_i(t)}\log p(\boldsymbol{x}_i(t)\mid\boldsymbol{x}_i(0))\|_2^2\right]

です. そのため, STaSyの目的関数は

\min_{\boldsymbol{\theta}, \boldsymbol{v}}\sum_{i=1}^Nv_il_i+r(\boldsymbol{v}; \alpha, \beta)

です. \alpha, \beta は0以上1以下で, 訓練が進むにつれて単調増加します. なお, 初期値を \alpha_0, \beta_0 とすると以下の式に従って変化します.

\begin{align*} &\alpha=\alpha_{0}+\log\left(1+c\left(\dfrac{e-1}{S}\right)(1-\alpha_0)\right) \\ &\beta=\beta_{0}+\log\left(1+c\left(\dfrac{e-1}{S}\right)(1-\beta_0)\right) \end{align*}

e はネイピア数, c はその時点のtraining step, S は全てのデータを使うタイミングを決定し, 実験では S=10000 を採用しています.

正規化関数 r を具体的に定めていきます. そのために記号の定義を行います.

Fをdenoising score matchingの目的関数のCDF (累積分布関数)としたときに, 関数 Q(p) を, \inf \{l\in\mathbb{R}: p\leq F(l)\} で定めます. すなわち, Q(p) は与えられた確率 p よりCDFが大きいか等しい最小値です.

このとき, 正規化関数を

r(\boldsymbol{v}; \alpha, \beta)=-\dfrac{Q(\alpha)-Q(\beta)}{2}\sum_{i=1}^Nv_i^2-Q(\beta)\sum_{i=1}^Nv_i

とします. このとき, \boldsymbol{v}^*=[v_1^*,\ldots,v_N^*] は固定パラメータ \theta が与えられたときに以下の閉形式で書けます.

v_i^*=\begin{cases} 1 & \mathrm{if\ } l_i\leq Q(\alpha) \\ 0 & \mathrm{if\ } l_i\geq Q(\beta) \\ \dfrac{l_i-Q(\beta)}{Q(\alpha)-Q(\beta)} & \mathrm{Otherwise} \end{cases}

証明はここでは省略しますが, Appendix Eに記載されています (\mathcal{L}(v_i)v_i に関して2次関数になります).

この式を具体的にみます. v_i というのはデータ \boldsymbol{x}_i が簡単かどうかを示すラベルでした. すなわち, 損失 l_iQ(\alpha) より小さいときは簡単なデータ, Q(\beta) より大きいときは難しいデータとみなされます. それ以外の場合, そのデータは部分的に使用されることを示します. 触れていませんでしたが, \alpha\leq\beta の関係を保証しますので, l_i-Q(\beta)<0, Q(\alpha)-Q(\beta)<0 です. そのため, 区間 [0, 1] に含まれます.

Fine-Tuning Approach

self-paced learningによる学習が終わったら, パラメータを微調整するフェーズに入ります.

逆拡散過程のSDEによるアプローチとして, さまざまな数値解法があります. そのうちの1つにprobabilistic flowがあります. そこでは次のNODE (Neural ODE)を用います.

d\boldsymbol{x}=\left(\boldsymbol{f}(\boldsymbol{x}, t)-\dfrac{1}{2}g(t)^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})\right)dt

NODEは対数確率を計算するのに有用なので, 正確な対数確率に基づいたfine-tuningを行います.

パラメータ \theta を学習後, サンプルごとの閾値 \tau_i\log p(\boldsymbol{x}_i) で設定します. その後, fine-tuningするデータを選びます. ここでは \mathcal{F}=\{\boldsymbol{x}_i\mid\log p(\boldsymbol{x}_i), \boldsymbol{x}_i\in\mathcal{D}\} とします. ここで, \boldsymbol{x}_i は対数確率 \log p(\boldsymbol{x}_i) が平均あるいは中央値以下であるという制約を課します. 各エポックで l_i によってパラメータを更新したら, \mathcal{F}\leftarrow\{\boldsymbol{x}_i\mid\log p(\boldsymbol{x}_i)<\tau_i\} によってデータセットを更新します.

実験

15のデータと7つのベースラインを用いて実験を行います. 今回は, 普通にスコアベースのモデルを訓練したNaive-STaSyとself-paced learningとfine-tuningを行うSTaSyを提案手法として区別しています.

評価方法としてTSTR frameworkを用います. 生成品質については分類データに対してはaverage F1を主に使い, 補助的に AUROCとWeighted-F1を用います. 回帰データに対しては R^2 とRMSEを用います. 生成データの多様性については, coverageを用います.

それぞれ結果を見てみます.

Sampling Quality

MedGANとVEEGANは2017年のモデルでかなり初期のもののため, 生成品質は非常に低いです. また, ベースラインの中ではCTGANとOCT-GANが高性能ですが, 今回提案された手法の方が良い結果になっています.

個別のデータについて少し見ていきます. まず, 多クラス分類のデータです.

Crowdsource、Obesity、Robotといったデータでは他の手法を圧倒しています. このことについて, クラスごとのim-balancedが原因ではないかと考察されています.

このことは, 2値分類のCreditでも見られます. このデータは 99.7:0.3 の非常に不均衡なデータのため, 多くのベースラインが不均衡性を再現できず, F1が0に近いものになっていることがわかります. しかし, どの場合でもSTaSyが最高性能を誇っており, 手法の有用性が確認できます.

Sampling Diversity

続いて, coverageを用いた多様性の比較をします.

表からもわかるように, STaSyは非常に優れた結果を示しています. Robotデータをt-SNEで可視化した図を示します. 赤で囲まれた部分は特に実データとの分布の差が大きいことがわかります.

続いてBeanデータのRoundness (左)とCompactness (右)カラムのヒストグラムをみます. 多様性のために提案されたOCT-GANですが, 実データの分布を大きく外していることがわかります. それに対して提案手法はより詳細に分布をとらえていることがわかります.

Sampling Time

続いて, 生成速度を比較します. 訓練データと同じ量生成するために要した時間を比較しています. 全てのデータの平均を示していますが, 個々のデータで訓練データのサイズに違いがあるのであまり比較できないような気がします. 例えば, Creditは264.8Kですが, Contraceptiveは1.2Kです. 著者らはquality, diversity, speedのトレードオフがバランスよく取れていると主張していますが, それはどうなのでしょうか. TVAEと比較してみると100倍の時間がかかっていますがそれに見合った品質と多様性かは主観的な判断な気がします.

Ablation Study

ここでは, fine-tuning&SPLなし, fine-tuningなし, SPLなし, STaSyとの比較をして手法の要素が有用であることを示します.

概ね要素を追加していくと品質が上昇していることがわかります. 特にSPLは多様性に貢献します. Beijingの生成例を可視化すると, SPLを行ったことで多様性が生まれていることが確認できます. 特に, SPL無しではややモード崩壊が見られます (赤で囲まれた部分). 確かに, 図ではSPLなしでモード崩壊が確認できます. しかし, これがcherry pickingであることを否定できないのでcoverageで出すべきだと思いました.

Sensitive Study

提案手法では一般的なハイパーパラメータ以外にも, 様々な設定要素がありました. 例えば \alpha_0, \beta_0 であったり, SDE Solverの選択の自由度です.

いくつかのパラメータを変更して実験した結果が以下の表です. 基本的にどの設定でも性能が高いです. 著者らは \alpha_{0} には0.2または0.25を, \beta_{0} には0.9または0.95を推奨しています.

まとめ

  • スコアベースの表データ生成モデルの提案
  • 提案手法は3つの大きな特徴がある
    1. スコアベースのモデル
    2. self-paced learning
    3. fine-tuning approach
  • 15のデータセットと7つのベースラインを用いた比較ではqualityとdiversityで既存手法を凌駕

思ったこと

  • 最新のスコアベースの手法を取り入れていて, かつパフォーマンスも良い
  • 定量的指標がに加えて可視化された図も多く, 説得力がある
  • OpenReviewの査読に対する反論が全くないのが気になる
  • CoDiの際にも思ったことであるが, speedがよくないのでgenerative trilemmaを持ち出すのはやめた方がいいように思う (査読者の1人もruntimeがmajor concernと述べています)

参考文献

Discussion