🐲

【論文解説】Unified Latents (UL): 拡散モデルで潜在変数を「正しく」学習する方法

に公開

拡散モデルをより効率化しようという意気込みから生まれた「潜在空間(Latent Space)」での処理は、画像生成技術を実用レベルまで押し上げました。Stable Diffusionや最近のFluxとQwen-Imageなど、多くのモデルが画像をVAE(変分オートエンコーダ)に通し、圧縮された潜在変数上で学習と推論を行っています。

潜在空間を使う利点は主に2つあります。

  1. 画像を圧縮することで、高解像度(1MPなど)の生成に計算量的に対応できる。
  2. 潜在空間をある程度ガウス分布に従わせることで、拡散モデル(U-NetやDiT)の入出力が扱いやすくなる。

しかし、ここには 「表現力と潜在空間の難易度のトレードオフ」 という重大なジレンマが潜んでいます。
簡単な空間(綺麗なガウス分布)を作れば拡散モデルは学習しやすくなりますが、元の画像の細部は復元できなくなります。逆に、複雑な空間を作れば画像の再構成は完璧になりますが、拡散モデルの学習が崩壊します。

従来のモデルは、このバランスを「開発者の直感と経験(ヒューリスティック)」で調整していました。
今回紹介する論文 『Unified Latents (UL): How to train your latents』 は、この課題に対し、「VAEのKLダイバージェンス(正則化)も、画像の復元(デコーダ)も、全部拡散モデルに任せてしまおう!」 という非常にエレガントな解決策を提案しています。


1. VAEと拡散モデルのおさらい:全てはELBOとKLダイバージェンス

ULの凄さを理解するために、まずはVAEとDDPM(拡散モデル)の根底にある数学を少しだけおさらいしましょう。長ったらしい証明は省きますが、どちらも「ELBO(変分下界)」と「KLダイバージェンス」を最適化しているという点が重要です。

VAE(変分オートエンコーダ)の学習

VAEは、データ x の対数尤度 \log p(x) を最大化したいのですが、直接計算できないため、代わりにELBO(変分下界) を最大化します。

\log p_\phi(x) = \text{ELBO}(\theta, \phi) + D_{KL}(q_\theta(z|x) \parallel p_\phi(z|x))

KL項は常に0以上なので、ELBOは尤度の下限になります。このELBOを計算可能な損失関数に変形すると以下のようになります。

\text{ELBO} = \underbrace{\mathbb{E}_{z \sim q_\theta(z|x)} [\log p_\phi(x|z)]}_{\text{再構成誤差}} - \underbrace{D_{KL}(q_\theta(z|x) \parallel p(z))}_{\text{正則化項}}
  • 再構成誤差: デコーダが潜在変数 z から元の画像 x をどれだけ復元できるか。
  • 正則化項 (KL): エンコーダの出力分布が、事前分布(標準正規分布)から離れすぎないようにするペナルティ。

この「KL項の重み」をどう設定するかが、先述のジレンマを生んでいました。

DDPM(拡散モデル)への進化

拡散モデルは、VAEのエンコーダを「固定のノイズ付加プロセス(フォワードプロセス)」に置き換え、デコーダを「ノイズ除去プロセス(逆過程)」として学習するモデルとみなせます。

DDPMの学習目標も、実はVAEと同じ変分下界から導かれます。
逆過程の1ステップをニューラルネットワークで近似する際、条件付きKLダイバージェンスを最小化します。

\mathcal{L}_{\text{diffusion}}(\phi) = \sum_{i=2}^L \mathbb{E}_{p(x_i|x_0)} \left[ D_{KL}\Big(p(x_{i-1}|x_i, x_0) \parallel p_\phi(x_{i-1}|x_i)\Big) \right]

「スタート地点(元の画像 x_0)」と「現在地(ノイズ画像 x_i)」が分かっていれば、その間の分布はガウス分布として解析的に計算可能です。結果として、この複雑なKLダイバージェンスは、単純なノイズ予測の二乗誤差(MSE) に帰着します。

\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{t, x_0, \epsilon} \left[ \lambda_t \| \epsilon - \epsilon_\phi(x_t, t) \|_2^2 \right]

2. Unified Latents (UL) の全体像

ここからが本題です。ULは通常のVAEと異なり、以下の3つのモデルを同時に学習させます。

  1. エンコーダ (E_\theta): 画像 x から、決定論的な潜在変数 z_{clean} を出力する。
  2. 事前モデル / Prior (P_\theta): z_{clean} にノイズを加えたものをデノイズする潜在空間の拡散モデル
  3. デコーダ (D_\theta): わずかにノイズが乗った潜在変数 z_0 を条件として、画像 x を復元するピクセル空間の拡散モデル

Figure 1: Schematic overview of our model
Figure 1: ULの全体像。エンコーダ、Prior(拡散モデル)、デコーダ(拡散モデル)が連動する。

通常のVAEはエンコーダに「平均」と「分散」を出力させますが、ULのエンコーダは単一の z_{clean} しか出力しません。その代わり、Priorモデルの「最小ノイズレベル(\lambda(0)=5)」に相当する固定のガウスノイズz_{clean} に足して z_0 とします。これにより、エンコーダの学習が不安定になるのを防いでいます。

なぜ拡散モデル(Prior)でKLダイバージェンスを置き換えられるのか?

この論文の最大の醍醐味です。
ULでは、VAEのKL項における事前分布 p(z) を、標準正規分布ではなく学習可能な拡散モデル p_\theta(z_0) に置き換えます。

先ほどDDPMのおさらいで見たように、拡散モデルの対数尤度(ELBO)は「全ノイズレベルにおけるMSEの積分」として展開できます。したがって、VAEのKL項は次のように書き換えられます(論文 Eq. 3)。

D_{KL}(q(z_0|x) \parallel p_\theta(z_0)) \le \mathbb{E}_{t \sim \mathcal{U}(0,1)} \left[ w(\lambda_t) \parallel z_{clean} - \hat{z}(z_t, \theta) \parallel^2 \right]

つまり、「潜在変数が事前分布に従っているか(KL)」という抽象的なペナルティが、「Priorモデルがその潜在変数をどれくらい正確にデノイズできるか(MSE)」という具体的な損失に完全に置き換わるのです!
Priorモデルが簡単にデノイズできる潜在変数ほど「良い潜在変数」として評価されます。

論文には強調されているのは拡散モデルの重み付けだ。潜在空間を上手く調整するためにこのPriorモデルは全部の時刻 tに同じ重みをつける必要がある。

デコーダによる情報量のコントロール

「じゃあ、潜在空間の情報量はどうやって調整するの?」という疑問が湧くと思います。
ここで活躍するのが、もう一つの拡散モデルであるデコーダです。

デコーダの再構成ロスも、MSEやGANロスではなく拡散モデルの損失(Eq. 4)を用います。

-\log p_\theta(x|z_0) \le \mathbb{E}_{t \sim \mathcal{U}(0,1)} \left[ w_x(\lambda_x(t)) \parallel x - \hat{x}(x_t, z_0, \theta) \parallel^2 \right]

ここで、デコーダの損失に対して以下の重み付けを行います。

w_x(\lambda_t) = c_{lf} \cdot \text{sigmoid}(b - \lambda_t)

論文のFigure 3でも示されている通り、この式には潜在空間の情報量(ビットレート)を調整するための2つの重要なパラメータが含まれています。

  1. シグモイドバイアス (b)
    これは、どのノイズレベルの損失を割り引くかという「カーブの形状」を決定します。低ノイズレベル(画像の微細なテクスチャや高周波成分)の損失を意図的に割り引くことで、デコーダは「細かいディテールは損失として怒られないから、潜在変数 z_0 に頼らず、自分の拡散モデルとしての表現力で適当にでっち上げよう」と判断します。これにより、潜在変数には大まかな構図(低周波成分)だけが保持されるようになります。

  2. ロスファクター (c_{lf})
    これは、デコーダ損失全体の「スケール(重み)」を決定します。強力なデコーダ(拡散モデル)を用いると、モデルが潜在変数を完全に無視して画像を生成しようとする現象(Posterior Collapse)が起きやすくなります。そこで、c_{lf} を1より大きく(論文では1.3〜1.7程度)設定し、デコーダ側の損失を全体的に底上げします。これは相対的にPriorモデル(KL項)のペナルティを弱めることと同義であり、デコーダが潜在変数からしっかりと情報を引き出すように強制します。結果として、c_{lf} を大きくするほど潜在変数の情報量(ビットレート)は上がります。

つまり、バイアス b で「細かい情報を捨てさせる基準」を作り、ロスファクター c_{lf} で「潜在変数をどれだけ強く使わせるか」を底上げするという仕組みです。
ULは、この2つのパラメータの組み合わせによって、従来のVAEでは困難だった「潜在変数の情報量(再構成の綺麗さ)」と「Priorモデルの学習のしやすさ」のトレードオフを、数学的かつ直感的にコントロールすることに成功しています。


3. 実装のイメージと学習フロー

論文のAlgorithm 1 Training Unified Latentsを参考に学習コードは結構簡単に実装できます。

def train_step(x):
    # 1. エンコード (決定論的)
    z_clean = encoder(x)

    # 2. Priorモデルの損失計算 (KL項の代わり)
    t_prior = sample_uniform_t()
    z_t = add_noise(z_clean, t_prior) # 潜在変数にノイズを付加
    z_pred = prior(z_t, t_prior)
    # Priorは一様重み付けのMSE
    loss_prior = mse(z_clean, z_pred)

    # 3. デコーダモデルの損失計算 (再構成ロスの代わり)
    # 最小ノイズレベル(log-SNR=5)の固定ノイズを足す
    z_0 = add_fixed_noise(z_clean)

    t_dec = sample_uniform_t()
    x_t = add_noise(x, t_dec) # 画像にノイズを付加
    # z_0 を条件(Cross-Attention等)として画像を予測
    x_pred = decoder(x_t, t_dec, encoder_hidden_states=z_0)

    # デコーダはシグモイド重み付けのMSE
    weight_dec = c_lf * sigmoid(b - log_snr(t_dec))
    loss_dec = mse(x, x_pred) * weight_dec

    # 4. 合計ロスで最適化
    loss = loss_prior + loss_dec
    loss.backward()

2ステージの学習

理論上は、このままPriorモデルを使って画像を生成できるはずですが、論文によると「Priorモデルは一様重み付けで学習しているため、生成品質がイマイチ」とのこと。
そのため、Stage 1でエンコーダとデコーダを学習させた後、それらを凍結し、Stage 2としてPriorモデル(ベースモデル)だけをシグモイド重み付けで再学習させるアプローチをとっています。


4. 実験結果:圧倒的な学習効率

ImageNet-512における学習コストと生成品質(FID)のトレードオフを示したのが以下の図です。

Figure 4: FID vs. training cost
Figure 4: ImageNet-512におけるFID vs. 学習コスト。UL(緑)が他の手法を圧倒している。

ULは、従来のStable DiffusionのVAEを用いたベースラインや、他の最新手法(DiTなど)と比較して、同じ計算量で圧倒的に低いFIDを達成しています。最も効率的な事前学習アプローチであることが証明されました。


5. 個人の感想と著者のこぼれ話

この論文、個人的に非常に面白かったです。最近は「VAEを捨ててピクセル空間で直接学習しよう(JITなど)」というトレンドもありますが、潜在空間の計算効率の高さは捨てがたい。そこに対して「全部拡散モデルで解決する」という力技かつエレガントな手法を持ってきたのは見事です。

論文を読んでいて、著者の人間味や苦労が垣間見えたポイントもありました。

  • 「理論上は1チャンネルで十分」の罠
    序盤で「理論上、連続変数は無限の情報をエンコードできるから1チャンネルで十分」と言いつつ、「現実は浮動小数点の精度(16-32bit)のせいで無理だ」とぼやいています。数学の理想とエンジニアリングの現実ですね。結局普通の16と32がベストらしい。
  • FIDスコアへの不満
    Ablation Studyで、ImageNet以外のデータで学習させると再構成FID(rFID)が劇的に悪化したのに、生成品質(gFID)は良かったそうです。著者は「FIDという指標は、人間の目には見えない高周波の微細な統計の違いに過敏すぎるのではないか」と愚痴をこぼしています。画像生成AI研究者あるあるの悩みです(笑)。大学の教授も良く似ていることを言うから分かります。

Google DeepMindからの研究ということもあり、アイデアは素晴らしい反面、非公式データセットでの実験が多く「もう少し詳細な設定を知りたいな」と思う部分もありました。また、推論時にはPriorとデコーダの両方で拡散モデルの多ステップサンプリングが必要になるため、推論速度は従来のLDMより劇的に遅くなるという明確な弱点(Limitations)も抱えています。2つの拡散モデルの推論が必要になるからだ。

とはいえ、潜在空間の設計を「ロスファクター(c_{lf})」という一つのパラメータで数学的に制御できるようにしたこのアプローチは、今後の基盤モデル開発にこの技術は使われたら結構面白いと思います。

Discussion