TabDiff: A Multi-Modal Diffusion Model for Tabular Data Generation
久しぶりに拡散モデルで表データを生成する論文を見つけたのでまとめておきます (調べてみると, 他にもあるようです). なお, タイトルにmulti-modalとありますがこれは数値とカテゴリのことのようで, 画像やテキストが生成できるわけではないです (表データではmixed-typeというフレーズがこのことを指しますが, この論文では使われていないです).
関連リンク
公式実装は 今(2024/11/13)のところコードの部分はなさそうです.
書籍情報
断りのない限り, 図表は以下の引用となります.
Juntong Shi, Minkai Xu, Harper Hua, Hengrui Zhang, Stefano Ermon, and Jure Leskovec. Tabdiff: a multi-modal diffusion model for tabular data generation, 2024.
導入
どの論文でもそうですが, 表データ生成にとって最大の壁は数値データとカテゴリデータの2つがあるいわゆるmixed-typeなことであると思われます. この論文ではmultimodalと呼んでいますが内容としてはほぼ同じです. もうひとつにデータのtype (型)が同じでも性質が全く異なるとこの論文では表データの課題を示しています.
既存の表データ生成でもこの課題に向き合ってきたわけですが, 著者らは
- additional encoding overhead
- imperfect discrete-time diffusion modeling
- none of them consider the feature-wise distribution heterogeneity issue in a multi-modal framework
という3点によって効果が副次的であると述べています.
それらを解決するのが提案手法であるTabDiffです.
- 表現力のある連続時間の拡散モデルフレームワークにおいて元データ空間のjoint distributionを学習する
- 特徴量の周辺分布に敏感で特徴固有の情報とpair-wise相関を適応的に推論することが可能
という2つのメリットがあるようです. 具体的に手法と実験を通して確認していきます.
手法
まず, Notationを先に定義しておきます.
Notation |
意味 |
\mathcal{T} |
mixed-typeな表データ |
M_{\mathrm{num}},\ M_{\mathrm{cat}} |
数値データとカテゴリデータの数 |
\bold{x}^{\mathrm{num}},\ \bold{x}^{\mathrm{cat}} |
数値データとカテゴリデータ |
\bold{x}=[\bold{x}^{\mathrm{num}}, \bold{x}^{\mathrm{cat}}] |
各データ |
\left(\bold{x}^{\mathrm{num}}\right)_i\in\mathbb{R} |
i 番目の数値データ |
C_j |
カテゴリカラム |
\left(\bold{x}^{\mathrm{cat}}\right)_j\in\{0, 1\}^{(C_j+1)} |
i 番目のカテゴリデータ |
[MASK] |
余分な次元におけるデータ |
\mathrm{Cat}(\dot;\boldsymbol{\pi}) |
カテゴリ分布 |
\pi\in\Delta^K |
K クラスあるときの確率 |
\Delta^K |
K-simplex |
ここからわかるように, カテゴリデータはone-hotで表します.
さて, 表データはmultimodalなデータで, joint distribution p(\bold{x}) を学習する必要があります. 提案手法であるTabDiffはそれを可能にしたとされる手法です. 概要図を示します.
特徴的なこととして, サンプリングなどで用いる時刻 t が連続時間であることが挙げられます. 以降では
の順序でそれぞれを確認します.
Multimodal Diffusion Framework
まず, 拡散モデルでモデリングを行う際に大事なことがあります. それはデータ空間で行うか潜在空間で行うか, ということです. これまでに見た拡散モデルの表データ生成手法ではTabSynを除いて全てデータ空間でモデリングされています. TabDiffでもそれを踏襲し, データ空間でのモデリングを行います. そして, 基本的な方針としてはTabDDPMと同じような定式化を行い, 数値データとカテゴリデータで別々にノイズ付与などをします. ここで, TabDDPMではスケジューリングが同じであったのに対し, TabDiffでは異なるスケジューリング \bold{\sigma}^{\mathrm{num}},\ \bold{\sigma}^{\mathrm{cat}} を用います. 時刻 t\in[0, 1] におけるデータを \{\bold{x}_t:t\sim[0, 1]\} として表します. なお, \bold{x}_0\sim p_0 は日データからのi.i.d.サンプル, \bold{x}_1\sim p_1 はピュアなノイズです. 2つのデータを統合した場合をhybridなケースとすると, hybrid forward processは
q(\bold{x}_t\mid\bold{x}_0)=q(\bold{x}_t^{\mathrm{num}}\mid\bold{x}_0^{\mathrm{num}}, \bold{\sigma}^{\mathrm{num}}(t))\cdot q(\bold{x}_t^{\mathrm{cat}}\mid\bold{x}_0^{\mathrm{cat}}, \bold{\sigma}^{\mathrm{cat}}(t))
で表すことができます. reverse processは
q(\bold{x}_s\mid\bold{x}_t, \bold{x}_0)=q(\bold{x}_s^{\mathrm{num}}\mid\bold{x}_t, \bold{x}_0)\cdot q(\bold{x}_s^{\mathrm{cat}}\mid\bold{x}_t, \bold{x}_0)
です. ここで, s, t は 0<s<t<1 を満たす任意のtimestepです. 目標はdenoising modelである p_{\theta}(\bold{x}_s\mid\bold{x}_t) を学習することです. とは言ってもこの定式化ではまだ具体的に学習することはできないのでさらに細かく見ていきます.
数値データ
数値データ \bold{x}^{\mathrm{num}} は確率微分方程式として学習します. 特に, Ito SDEの形で表され, d\bold{x}=\bold{f}(\bold{x}, t)dt+g(t)d\bold{w} です. これは拡散モデルのtimestepsを無限大まで大きくした場合, reverse processが確率微分方程式に従い, それが先ほどの式ということです. これは確率微分方程式 (SDE)ですが, 実はprobability flow ODEという常微分方程式の解と一致することが知られています. そのことについては以下の論文で詳しく説明されています.
https://openreview.net/forum?id=k7FuTOWMOc7
https://arxiv.org/abs/2011.13456
そのprobabilty flow ODEですが,
d\bold{x}=\left[\bold{f}(\bold{x}, t)-\dfrac{1}{2}g(t)^2\nabla_{\bold{x}}\log p_t(\bold{x})\right]dt
と書けます. ここで, \nabla_{\bold{x}}\log p_t(\bold{x}) という量はスコアと呼ばれるもので, これは学習済みの拡散モデルで計算することができます. 先ほどから出てきている \bold{f} は時刻 t と状態 \bold{x} に依存する確率過程の平均変化率を表しますが, 基本的には簡単のために \bold{0} とすることが多いです. TabDiffでも同様の扱いをします. また, g(t)=\sqrt{2\left[\dfrac{d}{dt}\bold{\sigma}^{\mathrm{num}}(t)\right]\bold{\sigma}^{\mathrm{num}}(t)} とします (VE-SDEの形です). するとforward processは
\bold{x}_t^{\mathrm{num}}=\bold{x}_0^{\mathrm{num}}+\bold{\sigma}^{\mathrm{num}}(t)\boldsymbol{\varepsilon},\quad\bold{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_{M_{\mathrm{num}}})
で, reverse processは
d\bold{x}^{\mathrm{num}}=-\left[\dfrac{d}{dt}\bold{\sigma}^{\mathrm{num}}(t)\right]\bold{\sigma}^{\mathrm{num}}(t)\nabla_{\bold{x}}\log p_t(\bold{x}^{\mathrm{num}})dt
です. TabDiffでは, 拡散モデル \boldsymbol{\mu}_{\theta} をjointでdenoiseします. \boldsymbol{\mu}_{\theta}^{\mathrm{num}} は数値部分のdenoiseを表し, 以下のlossで訓練します.
\mathcal{L}_{\mathrm{num}}(\theta, \rho)=\mathbb{E}_{\bold{x}_0\sim p(\bold{x}_0)}\mathbb{E}_{t\sim U[0, 1]}\mathbb{E}_{\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\|\boldsymbol{\mu}_{\theta}^{\mathrm{num}}(\bold{x}_t, t)-\boldsymbol{\varepsilon}\|_2^2
カテゴリデータ
カテゴリデータは離散空間で表現されます. ここでは言語モデリングでの手法に想起されたものを用います. それはMasked Diffusion Modelです. forward processはmaskign processと解することができます.
q(\bold{x}_t\mid\bold{x}_0)=\mathrm{Cat}(\bold{x}_t;\alpha_t\bold{x}_0+(1-\alpha_t)\bold{m})
\alpha_t\in[0, 1] は, t に関して単調減少の関数で, \alpha_0\approx1, \alpha_1\approx0 です. これはデータ \bold{x}_0 が時刻 t においてmaskされる確率を表します. 実装上は \alpha_t=\exp(-\bold{\sigma}^{\mathrm{cat}}(t)) でパラメータ化します. \bold{\sigma}^{\mathrm{cat}}(t) は単調増加な関数です. \alpha_{t|s}=\alpha_t/\alpha_s と書くことにすれば,
q(\bold{x}_t\mid\bold{x}_s)=\mathrm{Cat}(\bold{x}_t;\alpha_{t|s}\bold{x}_s+(1-\alpha_{t|s})\bold{m})
です. reverse processは数値データと同様になりますが, 定式化はやや異なります.
q(\bold{x}_s\mid\bold{x}_t,\bold{x}_0)=\begin{cases}
\mathrm{Cat}(\bold{x}_s;\bold{x}_t) & \bold{x}_t\neq\bold{m} \\
\mathrm{Cat}\left(\bold{x}_s;\dfrac{(1-\alpha_s)\bold{m}+(\alpha_s-\alpha_t)\bold{x}_0}{1-\alpha_t}\right) & \bold{x}_t=\bold{m}
\end{cases}
です. 数値データと同様に \boldsymbol{\mu}_{\theta}^{\mathrm{cat}}:C\times[0, 1]\rightarrow\Delta^C とすると,
p_{\theta}(\bold{x}_s^{\mathrm{cat}}\mid\bold{x}_t^{\mathrm{cat}})=\begin{cases}
\mathrm{Cat}(\bold{x}_s^{\mathrm{cat}};\bold{x}_t^{\mathrm{cat}}) & \bold{x}_t^{\mathrm{cat}}\neq\bold{m} \\
\mathrm{Cat}\left(\bold{x}_s^{\mathrm{cat}};\dfrac{(1-\alpha_s)\bold{m}+(\alpha_s-\alpha_t)\boldsymbol{\mu}_{\theta}^{\mathrm{cat}}(\bold{x}_t, t)}{1-\alpha_t}\right) & \bold{x}_t^{\mathrm{cat}}=\bold{m}
\end{cases}
となります. lossですが, 既存研究により離散化の解像度を上げることでtightなELBOを近似できるとされています. ここでもそれに則り, 連続時間における尤度の下限を最適化します.
\mathcal{L}_{\mathrm{cat}}(\theta, k)=\mathbb{E}_q\int_{t=0}^{t=1}\dfrac{\alpha_t'}{1-\alpha_t}\log\langle\boldsymbol{\mu}_{\theta}^{\mathrm{cat}}(\bold{x}_t, t), \bold{x}_0^\mathrm{cat}\rangle dt
ここで \alpha_t' は \alpha_t の一次導関数です.
Training
定式化が終わったのでどうやって訓練するかを考えます. 表データでは画像のRGBや同じ語彙空間を共有する単語のトークンなどとは異なり, それぞれのカラムが固有の分布を持っています. そのため, それぞれのデータに対してschedulerを適応的に学習します. schedulerの柔軟性と頑健性のトレードオフのバランスを撮るために2つの関数を設計します. power-mean scheduleとlog-linear categorical scheduleです.
power-mean schedule
ここでは \bold{\sigma}^{\mathrm{num}}(t)=[\sigma_{\rho_i}^{\mathrm{num}}(t)] と定義します. ここで \rho_i は各数値データに対する訓練可能なパラメータです. 任意の i\in\{1, \ldots,M_{\mathrm{num}}\} に対して
\sigma_{\rho_i}^{\mathrm{num}}(t)=\left(\sigma_{\min}^{\frac{1}{\rho_i}}+t(\sigma_{\max}^{\frac{1}{\rho_i}}-\sigma_{\min}^{\frac{1}{\rho_i}})\right)^{\rho_i}
とします.
log-linear categorical schedule
基本的には数値データと同じで \bold{\sigma}^{\mathrm{cat}}(t)=[\sigma_{\rho_{k_j}}^{\mathrm{cat}}(t)] と定義します. 違うのはここからで, 任意の j\in\{1, \ldots,M_{\mathrm{cat}}\} に対して
\sigma_{\rho_{k_j}}^{\mathrm{cat}}(t)=-\log(1-t^{k_j})
とします. 実装中では最初と最後のノイズレベルは固定します. すなわち \sigma_{i}^{\mathrm{num}}(0)=\sigma_{\min},\ \sigma_i^{\mathrm{num}}=\sigma_{\max} です.
これら2つのscheduleは論文内で突然登場して, 特に根拠のある (なにかしら引用で根拠づけされている)関数ではないです. この部分はOpenReviewの査読を見てみると指摘されていたりします.
目的関数
以前に確認したlossを変更することなくここで定義したパラメータも更新したいです. すると重みをつけてあげればよく,
\begin{align*}
\mathcal{L}_{\mathrm{TabDiff}}&(\theta, \rho, k)=\lambda_{\mathrm{num}}\mathcal{L}_{\mathrm{num}}(\theta, \rho)+\lambda_{\mathrm{cat}}\mathcal{L}_{\mathrm{cat}}(\theta, k) \\
&=\mathbb{E}_{t\sim U[0, 1]}\mathbb{E}_{(\bold{x}_t,\bold{x}_0)\sim q(\bold{x}_t,\bold{x}_0)}\left(\lambda_{\mathrm{num}}\|\boldsymbol{\mu}_{\theta}^{\mathrm{num}}(\bold{x}_t, t)-\boldsymbol{\varepsilon}\|_2^2+\dfrac{\lambda_{\mathrm{cat}}\alpha_t'}{1-\alpha_t}\log\langle\boldsymbol{\mu}_{\theta}^{\mathrm{cat}}(\bold{x}_t, t), \bold{x}_0^\mathrm{cat}\rangle\right)
\end{align*}
です. 訓練の詳細を以下に示します.
Eq. (6)は
q(\bold{x}_t\mid\bold{x}_0)=\mathrm{Cat}(\bold{x}_t;\alpha_t\bold{x}_0+(1-\alpha_t)\bold{m})
です.
Sampling
カテゴリデータは確認したように
p_{\theta}(\bold{x}_s^{\mathrm{cat}}\mid\bold{x}_t^{\mathrm{cat}})=\begin{cases}
\mathrm{Cat}(\bold{x}_s^{\mathrm{cat}};\bold{x}_t^{\mathrm{cat}}) & \bold{x}_t^{\mathrm{cat}}\neq\bold{m} \\
\mathrm{Cat}\left(\bold{x}_s^{\mathrm{cat}};\dfrac{(1-\alpha_s)\bold{m}+(\alpha_s-\alpha_t)\boldsymbol{\mu}_{\theta}^{\mathrm{cat}}(\bold{x}_t, t)}{1-\alpha_t}\right) & \bold{x}_t^{\mathrm{cat}}=\bold{m}
\end{cases}
でサンプリングすればいいのですが, 表データはカラム間の相関などもあり複雑な構造をしています. そのため, モデルがサンプリングの最中に誤差があれば修正することを期待します. そのため, ノイズ除去の各ステップにおいてreverse processを追加のforward processで再スタートさせる確率的なサンプラーを用います. この少し戻るという処理は目新しいものではなく, 例えば
https://openreview.net/forum?id=wFuemocyHZ
のような研究があります. そこで使われている図が明快なので引用すると,
Restart Sampling for Improving Generative Processesより引用
といった感じです. 今回はこれを発展させたものを用います. 各時刻 t において, まず小さな時刻を追加し (微小ではないです), t^{+}=t+\gamma_t t とします. その後, t から t^+ に向かってノイズ付与を行い \bold{x}_{t^+} を得ます. そして, t^{+} から t-1 に向かってdenoiseを行います.
アルゴリズムを示します.
Eq. (8)は
p_{\theta}(\bold{x}_s^{\mathrm{cat}}\mid\bold{x}_t^{\mathrm{cat}})=\begin{cases}
\mathrm{Cat}(\bold{x}_s^{\mathrm{cat}};\bold{x}_t^{\mathrm{cat}}) & \bold{x}_t^{\mathrm{cat}}\neq\bold{m} \\
\mathrm{Cat}\left(\bold{x}_s^{\mathrm{cat}};\dfrac{(1-\alpha_s)\bold{m}+(\alpha_s-\alpha_t)\boldsymbol{\mu}_{\theta}^{\mathrm{cat}}(\bold{x}_t, t)}{1-\alpha_t}\right) & \bold{x}_t^{\mathrm{cat}}=\bold{m}
\end{cases}
です.
条件付け
表データにおいて, missing value imputationは重要なタスクです. 拡散モデルを用いた表データ生成モデルはこのタスクもこなすことができます. 先行研究であるTabSynではRepaintの手法を用いることで解決をしていましたが, TabDiffでは条件付き生成とみなして行います. \bold{y}=\{\bold{y}^{\mathrm{num}},\bold{y}^{\mathrm{cat}}\} を与えられた特徴とし, \bold{x}=\{\bold{x}^{\mathrm{num}},\bold{x}^{\mathrm{cat}}\} で表される欠損値補間後のデータを予測したいとします. このとき \bold{x} は \bold{y} を条件として生成されるデータと解釈できます.
画像生成における条件付き生成は, 分類器を導入したclassifier guidanceがありましたが, 現在ではそれを用いないclassifier-free guidance (CFG)が主流です. TabDiffでもCFGを用います. w>0 をCFGの強さとすると, guided conditional sampleの分布は \tilde{p_{\theta}}(\bold{x}_t|\bold{y})\propto p_{\theta}(\bold{x}_t|\bold{y})p_{\theta}(\bold{y}|\bold{x}_t)^w で与えられます. ベイズの定理を用いると
\tilde{p_{\theta}}(\bold{x}_t|\bold{y})\propto p_{\theta}(\bold{x}_t|\bold{y})p_{\theta}(\bold{y}|\bold{x}_t)^w=p_{\theta}(\bold{x}_t|\bold{y})\left(\dfrac{p_{\theta}(\bold{x}_t|\bold{y})p(\bold{y})}{p_{\theta}(\bold{x}_t)}\right)^w=\dfrac{p_{\theta}(\bold{x}_t|\bold{y})^{w+1}}{p_{\theta}(\bold{x}_t)^w}p(\bold{y})^w
となります. p(\bold{y}) は \theta に依存しない定数なので無視すると, 対数確率は
\log\tilde{p_{\theta}}(\bold{x}_t|\bold{y})=(1+w)\log p_{\theta}(\bold{x}_t|\bold{y})-w\log p_{\theta}(\bold{x}_t)
を得ます. これは次のように変えられます. 数値データについては通常のCFG同様に
\tilde{\boldsymbol{\mu}_{\theta}^{\mathrm{num}}}(\bold{x}_t, y, t)=(1+w)\boldsymbol{\mu}_{\theta}^{\mathrm{num}}(\bold{x}_t, y, t)-w\boldsymbol{\mu}_{\theta}^{\mathrm{num}}(\bold{x}_t, t)
とできます. カテゴリデータについては
\log\tilde{p_{\theta}}(\bold{x}_{s}^{\mathrm{cat}}|\bold{x}_{t}^{\mathrm{cat}}\bold{y})=(1+w)\log p_{\theta}(\bold{x}_{s}^{\mathrm{cat}}|\bold{x}_{t}^{\mathrm{cat}}\bold{y})-w\log p_{\theta}(\bold{x}_{s}^{\mathrm{cat}}|\bold{x}_{t}^{\mathrm{cat}})
です.
実験
既存手法との比較とablationを行います.
実験設定
Adult, Default, Shoppers, Magic, Faults, Beijing, News, Diabetesの7つのデータセットを用います. 比較手法はCTGAN, TVAE, GOGGLE, GReaTに加え, 拡散モデルの手法であるTabDDPM, STaSy, CoDi, TabSynです. 評価は3つの観点から行います.
- Fidelity: 各カラムのdensityを見るShape, 実データとの異なるカラム間の相関を見るTrendといった旧来のものに加えて \alpha-Precisionと \beta-Recall とDetectionを用います. しかし, これら3指標はAppendixに結果があるので省略します.
- Downstream Tasks: Machine Learning Efficiency (MLE)とMissing Value Imputationを行います.
- Privacy: DCRを用いて学習データのコピーになっていないかを計測します.
全ての実験で20回のサンプリングを行います. 実験は1枚のNVIDIA RTX A4000 GPU (16GB) を用います.
結果
まずはShapeとTrendを確認します. OOMについてですが, STaSyはDiabetesでhigh cardinalityに起因するout of memoryが出たようです. また, GReaTはNewsでmaximum length limitによって実験ができていません. DiabetesでOOMと書かれている理由は不明です.
まず, 表1のShapeの結果を見ます. TabDiffは5つのデータセットで全ての手法に勝る結果となりました. これはTabDiffが個々のカラムに対して分布を維持できていることを示します. 続いて表2のTrendの結果を見ます. こちらは全てのデータで全ての手法に対して優位性が示されています. 特筆すべきはDiabetesで大きな改善がされていることです. これは10万くらいのサイズでカテゴリカラムは27, maxのcardinalityは716と重いデータセットです. このようなデータではOOMだったり結果が悪かったりしますが, TabDiffは他のデータ同様いい結果を出しています.
次に, 順番は前後しますがPrivacyの結果を見ます.
スコアが50に近ければ近いほどいいです. 結果を見ると概ね50%よりほんの少し大きい値になっています. ここからは個人的な意見ですが, AdultやDefaultでは既存手法の方がいいスコアなのに提案手法が青で書かれていることが気になります. typoだとここでは信じておきます. それはともかく, 全体としてかなり品質がいいことがわかります. ただ, 他の手法と比較して標準偏差が大きいように見えるので安定した生成ではなさそうです.
最後に, 下流タスクでの結果を確認します. まず, MLEです. 本当に本物データと品質が同じであれば, 下流タスク適用時のスコアは本物データと生成データのどちらを用いても同じです. XGBoostを用いて下流タスクのスコアを算出します. 指標は分類タスクがAUC, 回帰タスクがRMSEです.
結果を見るとTabDiffが一貫して最高性能であることがわかります. 注意すべき点として, 例えば生成モデルが訓練データをコピーするような学習を行った場合, MLEは優れた結果となります. そのため, MLE単体で評価することはできず, DCRであったり他の指標と併用することが大事です. ただ, 先ほど見たようにその他指標でもTabDiffは優れた結果を示していますので, ここでは問題ないと考えられます.
続いてMissing Value Imputationの結果を見ます. target columnを欠損しているとみなして生成を行います.
結果を見てみると, TabSynより優れていることがわかります. ただ, 多くの場合でAUCが90を超えており, やや頭打ち感があるのでデータを変えた方がいいかもしれないです. CFG scaleを0.0と0.6の場合で実験を行なっています. 例えば画像生成ではCFGを大きくするとCLIP Scoreは改善するがFIDは悪化するみたいな話がありますが, ここではそのようなことがないのかは気になります. なお, このあとablationに話が移りますがそこではそのような話はありません.
Ablation
SamplerとScheudleについてのablationが行われています.
まず, Samplerです. ここでは決定的 (Deterministic, Det)と確率的 (Stochastic, Sto)の比較を行います.
結果を見ると, noise schedulerを固定したときと学習させたときで, 確率的なsamplerでnoise schedulerを学習させたときが最良の結果であることがわかります.
続いて, noise scheudleについて確認します. 数値データは
\sigma_{\rho_i}^{\mathrm{num}}(t)=\left(\sigma_{\min}^{\frac{1}{\rho_i}}+t(\sigma_{\max}^{\frac{1}{\rho_i}}-\sigma_{\min}^{\frac{1}{\rho_i}})\right)^{\rho_i}
で表されるscheduleを採用していました. fixedとして \rho_i=7 とします. カテゴリデータは
\sigma_{\rho_{k_j}}^{\mathrm{cat}}(t)=-\log(1-t^{k_j})
で表されていましたが, k_j=1 でfixします. 先ほどの表からもわかりますが, fixさせない方がいい結果となります. 実際にAdultデータに対するtrain lossをplotしてみても
のようになり, fixさせない方が一貫して低いlossであることがわかります. \rho_i や k_j の固定する値はこれがベストなのかどうかについては触れられていませんでした.
思ったこと
- 正直TabSynが結構完成されていたと思うのでまだ改善できるんだとは思いました.
- メモリ (OOM)の話をするならTraining Timeとかも比較して欲しいなと思います. 表データ生成は基盤モデルがなく, 基本的には個々のデータに対して都度訓練を行う以前からのスタイルなので, 品質が良くてもそこに到達するまでの時間が長いと微妙な気持ちになります.
- 表データは結局どう扱うのがいいのでしょう. カテゴリ分布を仮定しているものとしていないものがあり, 毎回交互に出てきてる印象です. もしかしたらどっちでもいいのかもしれません. ただ, 実験で少し出てきたhigh cardinalityの状況だとカテゴリ分布の学習は大変になっていくことが予想されます. 画像ではピクセルの値が1違っても問題ないですが, カテゴリのような離散データはそうではないからです.
- scheduleのところでも述べましたが, いろいろ出てくるものが唐突な気がします.
参考文献
- Juntong Shi, Minkai Xu, Harper Hua, Hengrui Zhang, Stefano Ermon, and Jure Leskovec. Tabdiff: a multi-modal diffusion model for tabular data generation, 2024.
- Yilun Xu, Mingyang Deng, Xiang Cheng, Yonglong Tian, Ziming Liu, and Tommi S. Jaakkola. Restart sampling for improving generative processes. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
Discussion