📏

【生成AI】なぜ「対数尤度の最大化」ではなく「ELBOの最大化」が目的関数になるのか【VAE・拡散モデル】

2024/12/10に公開

はじめに

本記事は、生成AI Advent Calendar 2024の10日目の記事です。

画像生成AIを勉強する中で、確率分布をモデル化するために、対数尤度を最大化するように学習しますという話はよく聞く話だと思います。
私も、画像生成AIの一つである「拡散モデル」に関して、簡単な理論的説明の記事を書きました。

https://zenn.dev/asap/articles/4092ab60570b05
https://zenn.dev/asap/articles/8aaa896a02f168

理論的説明と言いましたが、ほとんど数式を使わずに簡単に解説したので、誰でもわかりやすく読めると思います。
(下の記事は少しだけ数式が増えるので、数式的な興味がある方だけがみていただければと思います)

上記の記事でも、説明しましたが、基本的に画像生成AIでは、自然画像の確率分布を学習しており、そのために対数尤度の最大化を行い、確率分布を学習しようとします。

しかしながら、どの参考書や技術記事を見ても、(拡散モデルやVAEでは)対数尤度ではなくELBOを代わりに最大化させるようにしていると思います。
今回は、「なぜ対数尤度の最大化ではなくELBOの最大化をするのか」についての記事になります。

結論を言ってしまうと、「VAEや拡散モデルの枠組みでは、対数尤度を直接計算することができないため」です。

参考文献

ゼロから作るDeep Learning ❺ ―生成モデル編

圧倒的な名著シリーズである「ゼロから作るDeep Learning」シリーズの5作目である、生成モデル編です。
この本では、最終的に拡散モデルを理解することを目的に、VAEや混合ガウスモデル、正規分布まで遡り、本当にゼロから中身を理解しながら実装をすることができる、大変素晴らしい書籍になっています。

本書籍単体でも、非常に分かりやすいですが、今記事のタイトルである「なぜ対数尤度の最大化ではなくELBOの最大化をするのか」の部分は、多少行間が広い部分がある(というか本当にELBOを利用しないといけないのか?)と感じたため、それを解消できれば幸いです。
(もちろん、その他書籍と比較すると、圧倒的に行間が狭いですが、私のような数弱には理解に時間がかかった部分であるので、その部分の行間を埋めれればと思います。将来の自分宛です)

妥協しないデータ分析のための 微積分+線形代数入門

本書の後半にもVAEやELBOについての簡単な記載があり、本記事を書くうえで表現を参考にさせていただきました。
確率分布のどの部分が計算可能で、どの部分が計算不可能なのかが常に明示されながら進んでいくので、初学者には非常に分かりやすい書籍でした!

特に、生成AIをただ使うだけでなく、ちゃんと理論を理解したいと思った場合、論文や理論書を理解するために「線形代数」や「微積分学」の事前知識が必要になります。
本書は、初学者が一番最初に読んで必要な基礎知識が習得できる良本になります。

拡散モデル データ生成技術の数理

拡散モデルの理論本と言えばコレ!と言えるほどおすすめの本です。
数式が増えるので、ゼロから作るDeep Learning ❺ ―生成モデル編妥協しないデータ分析のための 微積分+線形代数入門を読破した上(もしくは知識を持った上で)読むと理解しやすいと思います。
この本のいいところは、数式が増えて厳密な理解ができるようになるのはもちろんのこと、その数式を説明する行間が狭いため、非常に簡単に理解ができる点です。
また、難しい箇所などは図などを使って、視覚的に理解しやすいように工夫されているので、大学の教科書を読むのが苦手な人でも、理系であればついていけると思います!

対数尤度について

(画像)生成モデルがどのようにして、多様な画像を生成しているのかは、こちらの記事で説明していますが、簡単に振り返ります。

大量の自然画像が含まれているデータセットDを用意します。

自然画像一枚一枚をx_k(kは自然数)とすると、

D=\{x_1,x_2,......\}

と表すことができます。

これらの画像x_kがすべて、神が定めた自然画像の確率分布p(x)からサンプリングされて、現実世界に生み出されていると仮定します。
そう考えると、その確率分布を何らかの方法で再現することができれば、再現した確率分布から、サンプリングを行うことで、本物の自然画像に類似した画像を生成することができると考えることができます。

この再現確率分布が何らかのパラメータ\thetaにより制御されているとすると、再現確率分布はp_\thetaと書くことができます。

では、このパラメータをどのようにして決定するかというと、パラメータ\thetaが得られている時に、自然画像xの条件付き確率を最大化するように、パラメータ\thetaを更新することで、最適化が可能です。
(くどいようですが、この辺りの説明は過去の記事をご覧ください)

この条件付き確率は、p_\theta(x)もしくはp(x|\theta)と書くことができますが、後述の式変形の分かりやすさのためp_\theta(x)を記述することにします。
こうすることで、パラメータ\thetaを他のパラメータと区別することができるので、おそらく分かりやすくなるかなと思いました。

そして、この条件付き確率は「ある特定の自然画像が得られた際に、そのデータがある確率分布(モデルパラメータ)のもとで観測された確率」を示しており、特別に「尤度」と呼びます。

(画像)生成モデルでは、この尤度を最大化する、つまり「最尤推定」により、パラメータ\thetaを最適化しています。
具体的には以下の式です。

L(\theta; x_1, x_2, \dots, x_n) = \prod_{i=1}^{n} p_{\theta}(x_i)

さらに、最尤推定する際に、尤度は対数尤度に変換します。
具体的には以下の式です。

\ell(\theta; x_1, x_2, \dots, x_n) = \log L(\theta; x_1, x_2, \dots, x_n) = \sum_{i=1}^{n} \log p_{\theta}(x_i)

尤度の積は、データ数が増えれば増えるほど、容易にアンダーフローを発生させます。
したがって対数を取ることで、データ点同士の和の形に目的関数を変形させることができ、安定した学習を可能にします。
加えた、対数は単調増加関数になるため、対数尤度を最大化するパラメータと尤度を最大化するパラメータは完全に一致します。

対数尤度が計算できない理由

VAEや拡散モデルにおける尤度(確率分布)とは?

さて、ここまでで復習は完了です。
上述した通り、画像生成モデルなどでは、対数尤度\log p_{\theta}(x)を最大化するようなパラメータ\thetaします。

しかし、そのためには、そもそも\log p_{\theta}(x)の値を計算する必要があります。

VAEや拡散モデルでは、最終的な出力が生成された画像自体になります。

例えば、「PixelCNN」などの枠組みでは、最終的な出力として、256次元のクラス分類を実施しており、Pixelの値である0-255の値を離散値とした確率分布を出力しています。
このような枠組みであれば、モデルの出力自体が、確率分布になっているため、尤度の計算をするイメージが湧くと思います。

では、最終的な出力が生成された画像自体である場合は、どのような確率分布を仮定するかというと、平均がネットワークの出力\hat{x}、分散をIとする正規分布を仮定します。
後述する式変形でも説明しますが、このように設定することで、最終的な目的関数が、正解データxと生成データ\hat{x}の2乗誤差の最小化に帰着します。

つまり確率分布のパラメータは、平均\hat{x}のみで決定されます。
そして、この平均\hat{x}はネットワークの出力により決定づけられます。

これで、VAEや拡散モデルの枠組みであっても、対数尤度(確率分布)の最大化という議論を行うことができます。

では、実際に議論に入ります。

そもそも対数尤度って何者?

対数尤度の式をあらためてちゃんと見てみましょう。
対数尤度は下記のように表されます。

\ell(\theta; x_1, x_2, \dots, x_n) = \log L(\theta; x_1, x_2, \dots, x_n) = \sum_{i=1}^{n} \log p_{\theta}(x_i)

簡単化のために、ある一つのサンプルに注目した\log p_{\theta}(x_i)を考えます。

この式の意味は、あるネットワークパラメータ\thetaによってのみ条件づけられた、「正解データx_i」の対数尤度を表します。
つまり、このp_{\theta}(x_i)はネットワークパラメータ\theta以外では条件付けされていない、つまり独立な確率分布でなければなりません。

そう、潜在表現に対しても、です。

VAEも拡散モデルも、何らかの潜在表現z_iを入力として、Decoderが画像を復元します。
下記では単純化のために一旦VAEに着目しますが、拡散モデルであっても大まかな議論は同一です。

VAEが作る確率分布の正体

上述した通り、VAEでは何らかの潜在表現z_iを入力として、Decoderが画像を復元します。
したがって、最終的にネットワークの出力(+その後の正規分布)が作る確率分布は、

p_{\theta}(x_i|z_i)

となり、純粋な対数尤度p_{\theta}(x_i)とは異なります。

このように数式で表現すると、全く違うものだということが理解していただけると思います。

では、純粋な対数尤度を計算するために、式変形を試します。
また、今後はDecoderのパラメータ集合を\thetaとします。
つまり、p_{\theta}で表現される確率分布は、Decoderによって作られる確率分布になります。

対数尤度の式変形

対数尤度p_{\theta}(x_i)は確率分布の定義による式変形と、ベイズの定理による式変形を試すことができます。

まずは、確率分布の定義による式変形を考えます。すると対数尤度は下記のように式変形ができます。

p_{\theta}(x_i) = \int p_{\theta}(x_i \mid z_i) p(z_i) \, dz_i

すなわち、取りうる全ての潜在表現に対して、条件付け確率p_{\theta}(x \mid z)を計算して、総和を取る必要があります。これは連続的かつ高次元なzに対しては、現実的に不可能です。

また、忘れてはならないのは、上記の式は、一つのサンプルiに限定した式を提示しています。
本来の目的関数は、全データかつ対数尤度になるはずなので下記のようになります。

\ell(\theta; x_1, x_2, \dots, x_n) = \sum_{i=1}^{n} \log \int p_{\theta}(x_i \mid z_i) p(z_i) \, dz_i

上記の式を見て分かるように、「log-sum」の形になっています。「sum-log」であれば取り扱いの可能性がまだあるのですが、「log-sum」の形になると、解析的に解くことは困難です。
したがって、もし潜在変数zが離散的かつ少数であったとしても、交互最適化などの工夫なしに計算することはできません。

さて、式変形に戻ります。
続いては、ベイズの定理を用いて式変形してみます。すると下記のような式変形が成立します。

p_{\theta}(x_i) = \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{p_{\theta}(z_i|x_i)}

さて、今出てきた下記の式の一つ一つが計算可能かどうかを見ていきます。

p_{\theta}(x_i) = \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{p_{\theta}(z_i|x_i)}

まず、分子のp(z_i)は潜在表現z_iの事前分布です。ベイズ統計などと同様に、この事前分布は我々が決めて導入することができるので、問題ないです。
VAEではしばしば標準正規分布が用いられます。

続いて、分子のp_{\theta}(x_i|z_i)は、Decoderのパラメータ\thetaによる条件付けの環境下において、潜在変数z_iが与えられた場合の、正解データx_iの尤度です。
したがって、Decoderの出力\hat{x}と、\hat{x}を平均、分散をIとした正規分布でモデル化できます。

最後に、分母のp_{\theta}(z_i|x_i)はDecoderのパラメータ\thetaによる条件付けの環境下において、正解データx_iが与えられた場合の、潜在変数z_iの「事後確率」です。
これは、Decoder視点では計算ができないので、さらなる式変形が必要になります。

事後確率なので、ベイズの定理により変形を行います。

p_{\theta}(z_i|x_i) = \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{p_{\theta}(x_i)} = \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{\int p_{\theta}(x_i \mid z_i) p(z_i) \, dz_i}

したがって、分母にまた計算不能な形が現れてしまいました。

長々と書いてしまいましたが、結論として、尤度p_{\theta}(x_i)はどう式変形をしても、計算不能な箇所が現れてしまうため、計算できません。

今後の方針

では、どうするのか。
我々は、ELBOという武器を知っていますが、今回はそれを知らないものとして考えていきます。

軸となるアイデアは、
下記の式を考えて、対数尤度\log p_{\theta}(x_i)の代わりに、mを最大化することです。

\log p_{\theta}(x_i) \geq m

上記の式では、対数尤度\log p_{\theta}(x_i)は何らかの値mより必ず大きいので、mを最大化することで、間接的に対数尤度\log p_{\theta}(x_i)を最大化するというアイデアです。

そのときに、役に立つのがKLダイバージェンス、もしくは、イェンセンの不等式です。

イェンセンの不等式を利用する方が、式は簡潔に変形できますが、KLダイバージェンスを利用する方が式の意図が理解しやすいと思うので、KLダイバージェンスによる式変形を考えていくことにします。

KLダイバージェンスは2つの確率分布(f(x),g(x))を用意して、下記の式で表現できます。

D_{\text{KL}}(f(x) \| g(x)) = \int f(x) \log \frac{f(x)}{g(x)} \, dx

このKLダイバージェンスは、必ず0以上になることが知られています。
したがって、

\log p_{\theta}(x_i) = m + D_{\text{KL}}

の形にすることにより、\log p_{\theta}(x_i) \geq mを作ることができます。

簡単な補足

したがって、計算できない確率に関しては、このKLダイバージェンスの中に押し込めてやることを考えます。

さらに、KLダイバージェンスは、2つの確率分布の距離を算出する式であり、2つの確率分布が同じであれば0になります。
したがって、2つの確率分布が一致しているとき、mと対数尤度\log p_{\theta}(x_i)が一致することになり、mの最大化は、対数尤度\log p_{\theta}(x_i)の最大化に完全に一致します。

したがって、計算不能な確率分布をKLダイバージェンスに押し込める時は、計算不能な確率分布と高い精度で近似できる何らかの確率分布を用意して、押し込めた方が良いことがわかります。

では、上記の考えをもとに実際に対数尤度\log p_{\theta}(x_i)を式変形し、計算可能な分布のみで表現できるmを探しにいく旅に進みましょう!

実際に式変形をする

武器の整理

さて、ここからが本番です、
実際に対数尤度\log p_{\theta}(x_i)の式変形を行います。

まずは、ベイスの定理を用いて式変形をしておきます。

\log p_{\theta}(x_i) = \log \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{p_{\theta}(z_i|x_i)} = \log p(z_i) + \log p_{\theta}(x_i|z_i) - \log p_{\theta}(z_i|x_i)

ここで計算不能な分布は分母のp_{\theta}(z_i|x_i)でした。
パラメータ\thetaはDecoderのパラメータなので、潜在変数z_iの事後分布を計算することはできません。

ここで、計算できない分布はKLダイバージェンスD_{\text{KL}}に押し込めることを考えましょう。
KLダイバージェンスD_{\text{KL}}は、下記のように定義されます。

D_{\text{KL}}(f(x) \| g(x)) = \int f(x) \log \frac{f(x)}{g(x)} \, dx

ここで、KLダイバージェンスD_{\text{KL}}は非負であり、二つの分布が一致してる時に0になります。ここで、このKLダイバージェンスがなるべく0に近づくように、p_{\theta}(z_i|x_i)に近づく分布を考えることにします。

押し込めたい分布はあくまで潜在変数z_iの事後分布です。
そこで、潜在変数z_iの分布を表現できる、何らかの分布q(z_i|x_i)を考えてみましょう。
(ここでは、一旦q(z_i|x_i)が計算可能かどうかは考えないことにします)

ここでq(z_i|x_i)は確率分布であるため、下記の2式が成立します。
(下の式は当たり前の式ですね)

\int q(z_i|x_i) \, dz_i = 1
\log q(z_i|x_i) - \log q(z_i|x_i) = 0

さて、上記の2式と対数尤度の式を用いると、

D_{\text{KL}}(q(z_i|x_i) \| p_{\theta}(z_i|x_i)) = \int q(z_i|x_i) \log \frac{q(z_i|x_i)}{p_{\theta}(z_i|x_i)} \, dz_i

を作ることができることに気づけるでしょうか。

では、実際に式変形をしていきましょう。

武器を用いて式変形

対数尤度\log p_{\theta}(x_i)を考えます。

ここで、武器1である

\int q(z_i|x_i) \, dz_i = 1

を導入します。
上記の式の値は1なので、対数尤度\log p_{\theta}(x_i)にかけてもよく、また、この式は潜在変数z_iに依存しないため、下記のように式変形が可能です。

\log p_{\theta}(x_i) = \int q(z_i|x_i) \, dz_i \log p_{\theta}(x_i) = \int q(z_i|x_i)\log p_{\theta}(x_i) \, dz_i
補足(依存しないについて)

続いて、ベイスの定理、対数法則により

\int q(z_i|x_i)\log p_{\theta}(x_i) \, dz_i = \int q(z_i|x_i)\log \dfrac{p(z_i)p_{\theta}(x_i|z_i)}{p_{\theta}(z_i|x_i)} \, dz_i
= \int q(z_i|x_i)\{\log p(z_i) + \log p_{\theta}(x_i|z_i) - \log p_{\theta}(z_i|x_i)\} \, dz_i

続いて、武器2である

\log q(z_i|x_i) - \log q(z_i|x_i) = 0

を導入します。
すると下記のように式変形ができます。

\int q(z_i|x_i)\{\log p(z_i) + \log p_{\theta}(x_i|z_i) - \log p_{\theta}(z_i|x_i)\} \, dz_i

武器2は0なので、括弧内に入れる

= \int q(z_i|x_i)\{\log p_{\theta}(x_i|z_i) + \log p(z_i) - \log p_{\theta}(z_i|x_i) + \{\log q(z_i|x_i) - \log q(z_i|x_i) \}\} \, dz_i

括弧内の計算順の変更

= \int q(z_i|x_i)\{\log p_{\theta}(x_i|z_i) + \{\log p(z_i) - \log q(z_i|x_i)\} - \{\log p_{\theta}(z_i|x_i) - \log q(z_i|x_i)\} \} \, dz_i

対数法則により、差を商に変換

= \int q(z_i|x_i)\log p_{\theta}(x_i|z_i) \, dz_i + \int q(z_i|x_i)\log \dfrac{p(z_i)}{q(z_i|x_i)} \, dz_i - \int q(z_i|x_i) \log \dfrac{p_{\theta}(z_i|x_i)}{q(z_i|x_i)} \, dz_i

対数法則により、真数の逆数を取り、符号を反転し、KLダイバージェンスの形を作る

= \int q(z_i|x_i)\log p_{\theta}(x_i|z_i) \, dz_i - \int q(z_i|x_i)\log \dfrac{q(z_i|x_i)}{p(z_i)} \, dz_i + \int q(z_i|x_i) \log \dfrac{q(z_i|x_i)}{p_{\theta}(z_i|x_i)} \, dz_i
= \mathbb{E}_{q(z_i|x_i)}[\log p_{\theta}(x_i|z_i)] - \mathrm{KL}\left(q(z_i|x_i) \parallel p(z_i)\right) + \mathrm{KL}\left(q(z_i|x_i) \parallel p_{\theta}(z_i|x_i)\right)

確率分布qの正体

さて、ここまで式変形してきましたが、この時点で計算不能な分布がp_{\theta}(z_i|x_i)であり、計算可能かどうかがわからない分布がq(z_i|x_i)となります。

ではここで、分布qについて考えてみたいと思います。

ここで記号を整理するとx_iはデータセットDに含まれる正解画像データのi番目のデータであり、z_iはそれに対応する潜在表現になります。
ここで、qは私たちが式変形のために独自に導入した確率分布であるため、x_iに条件づけられたz_iの分布であることを守れば、自由に設計することが可能な分布です。

一方で、iと言うのは、データセット内のデータ数分用意する必要があります。
したがって、各データx_iごとに、異なる分布を用意する必要があります。

したがって、q(z_0|x_0), q(z_1|x_1), q(z_2|x_2), \cdots , q(z_N|x_N)は全て異なる分布を用意する必要があります。
つまり、データセットD内に数億枚の画像がある場合は、数億個の分布を用意する必要があります。

これは現実的ではありませんが、私たちはたった一つのモデル化で、大量の入出力の写像を近似できる強力な武器を知っています。
はい、ニューラルネットワークです。

そこで、入力を正解画像データx_i、出力を対応する潜在変数z_iであり、パラメータが\psiであるニューラルネットワーク、すなわちEncoderを考えることで、下記のように確率分布qを書き換えられます。

q(z_i|x_i) = q_{\psi}(z_i|x_i)

さて、以上の結果から、対数尤度\log p_{\theta}(x_i)は下記のように式変形できることがわかりました。

\log p_{\theta}(x_i) = \mathbb{E}_{q_{\psi}(z_i|x_i)}[\log p_{\theta}(x_i|z_i)] - \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p(z_i)\right) + \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p_{\theta}(z_i|x_i)\right)

対数尤度とELBOの面白い関係

ELBOの導出

さて、まずは、当初の目的であったELBOを導出しましょう。

対数尤度を式変形した結果、項の一つに計算不能な項があります。
\mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p_{\theta}(z_i|x_i)\right)です。

何度もしつこく記載していますが、このKLダイバージェンスに含まれる事後確率p_{\theta}(z_i|x_i)が計算不能です。

そこで、KLダイバージェンスが非負であるという特徴を利用して、下記のように不等式変形が可能です。

\log p_{\theta}(x_i) \geq \mathbb{E}_{q_{\psi}(z_i|x_i)}[\log p_{\theta}(x_i|z_i)] - \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p(z_i)\right)

この式の右辺をELBOと呼び、対数尤度の下限になっていることがわかります。

(補足)イェンセンの不等式によるELBOの導出

詳細な説明は割愛しますが、イェンセンの不等式とは、凸関数や、logなどの凹関数による写像を考えたときに、任意の点集合の変換後の重み付け和と、重み付け和後の変換の関係性を不等式で表す定理です。
(使う重みの総和は1とすると、凸関数の時は変換後の重み付け和のほうが大きくなり、凹関数では重み付け和後の変換のほうが大きくなります)

ちなみに、log関数におけるイェンセンの不等式は下記で表せます

\log \int q(z_i)f(z_i) \, dz_i \geq \int q(z_i) \log f(z_i)\, dz_i

これを用いると対数尤度は下記のように式変形できます。

\log p_{\theta}(x_i) = \log \int p_{\theta}(x_i \mid z_i) p(z_i) \, dz_i
= \log \int q_{\psi}(z_i|x_i) \dfrac{p_{\theta}(x_i \mid z_i)p(z_i)}{q_{\psi}(z_i|x_i)} \, dz_i

(ここで、イェンセンの不等式を利用)

\geq \int q_{\psi}(z_i|x_i) \log \dfrac{p_{\theta}(x_i \mid z_i)p(z_i)}{q_{\psi}(z_i|x_i)} \, dz_i
= \mathbb{E}_{q_{\psi}(z_i|x_i)}[\log p_{\theta}(x_i|z_i)] - \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p(z_i)\right) = \mathrm{ELBO}

となり、ELBOが簡単に導出できます。

しかし、KLダイバージェンスを利用して導出するほうが、式の意味は理解しやすいと思いますので、諸学者はそちらの導出を理解することをオススメします。

ELBOが対数尤度に近づく条件

対数尤度の代わりにELBOを最大化するとはいっても、ELBOと対数尤度ができるだけ近い値であったほうが、効率的に最適化ができるはずです。
したがって、どのような条件の時にELBOと対数尤度が近くなるのかを考えることは、意味があります。

これまでの議論から分かるように、ELBOが対数尤度に近づくのは、KLダイバージェンス\mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p_{\theta}(z_i|x_i)\right)が0に近づく時です。

しかしながら、我々ができるのはあくまでELBOの最大化のため、このKLダイバージェンスは最適化対象に含まれていません。
したがって、ELBOを最大化させたときに、KLダイバージェンスがどのように変化するのかをみていく必要があります。

ここで改めて、対数尤度の式変形は下記のようになります。

\log p_{\theta}(x_i) = \mathrm{ELBO}_{\theta, \psi} + \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p_{\theta}(z_i|x_i)\right)

この時、式の左辺に注目すると、あくまで対数尤度\log p_{\theta}(x_i)自体に効いてくるパラメータは、Decoderのパラメータ\theta「のみ」であることがわかります。

つまり、Encoderのパラメータ\psiを変化させても、対数尤度\log p_{\theta}(x_i)自体の分布は変化しないことがわかります。
では、何が変わるのかというと、\mathrm{ELBO}_{\theta, \psi}の値が変わります。

つまり、Encoderのパラメータ\psiを最適化することで、対数尤度の値は一定の条件下で、\mathrm{ELBO}_{\theta, \psi}の値を大きくすることができると言うわけです。

ここで、対数尤度の式変形に戻り、
左辺が一定、右辺の第一項が大きくなる場合、第二項は必然的に「小さく」なります。

第二項はKLダイバージェンスであり、非負であるため、0に近づくと言うわけです。

したがって、ELBOを最大化するようにVAEの学習を進めることで、計算不能な分布p_{\theta}(z_i|x_i)q_{\psi}(z_i|x_i)で近似しています。

前述しましたが、Encoderは、平均と分散のパラメータを出力し、正規分布としてq_{\psi}(z_i|x_i)の分布を構築します。
一方で、計算不能な分布p_{\theta}(z_i|x_i)は、正規分布とは異なり非常に複雑怪奇な分布をしていることが予想されます。

したがって、このKLダイバージェンスは、0にはならないことが想定されますが、それでも、複雑怪奇なp_{\theta}(z_i|x_i)を単純な正規分布q_{\psi}(z_i|x_i)で近似し、可能な限りELBOと対数尤度を近づけています。

このように、計算不能な分布を正規分布という単純な分布で近似する手法を「変分近似(Variational Approximation)」と呼びます。
VAEのVはこのVariationalのVからきているそうです。

ELBOの分析

続いてELBOの各項に対して、詳細にみていきます。
改めて、ELBOを下記に提示します。

第一項を考える

ELBOの第一項は下記です。

\mathbb{E}_{q_{\psi}(z_i|x_i)}[\log p_{\theta}(x_i|z_i)]

これは、Decoderにおいて、潜在表現z_iが得られている条件下での、正解画像データx_iの条件付き対数尤度のEncoderの条件付き分布に基づいた期待値です。

深層学習系において、期待値とは大量のデータによって得られた結果の平均と捉えても、大きく外れず、またq_{\psi}(z_i|x_i)はEncoderによりモデル化されていると考えると、サンプル数1で期待値を近似した場合、下記のように式変形ができます。

\mathbb{E}_{q_{\psi}(z_i|x_i)}[\log p_{\theta}(x_i|z_i)] \approx \log \mathcal{N}(x_i; \hat{x_i}, I)

ただし、
x_iはデータセットDから得られた正解画像データのi番目
z_iはパラメータ\psiを持ち、入力をx_iとするEncoderから得られた、平均と分散パラメータにより構成された正規分布からサンプリングされた潜在表現
\hat{x_i}はパラメータ\thetaを持ち、入力をz_iとするDecoderから得られた出力画像データ
\mathcal{N}(x_i; \hat{x_i}, I)は、平均\hat{x_i}、分散Iとする正規分布における、正解画像データx_iの確率
とします。

したがって、ELBOの第一項を最大化したい場合は、\log \mathcal{N}(x_i; \hat{x_i}, I)を最大化すれば良いことがわかります。

では、まずは多変量正規分布の確率密度関数から見ていきます。
一般的な形を考えて、分散共分散行列が\Sigmaである時を考えると、下記のようになります。

\mathcal{N}(x_i; \hat{x_i}, \Sigma) = \frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp \left( -\frac{1}{2} (x_i - \hat{x_i})^T \Sigma^{-1} (x_i - \hat{x_i}) \right)

ただし、dx_iの次元数(画像で言うと「画素xチャンネル数」)です。
したがって、対数を取ると下記のようになります。

\log \mathcal{N}(x_i; \hat{x_i}, \Sigma) = -\frac{d}{2} \log (2\pi) - \frac{1}{2} \log |\Sigma| - \frac{1}{2} (x_i - \hat{x_i})^T \Sigma^{-1} (x_i - \hat{x_i})

ここで、特に今回の問題設定である共分散行列が単位行列の場合を考えます。
すると下記のようになります。

\log \mathcal{N}(x_i; \hat{x_i}, I) = -\frac{d}{2} \log (2\pi)- \frac{1}{2} (x_i - \hat{x_i})^T (x_i - \hat{x_i})
= -\frac{d}{2} \log (2\pi) - \frac{1}{2} \| x_i - \hat{x_i} \|^2

ここで、最適化問題を解く上で、第一項は定数であるため、無視することができます。
したがって、
ELBOを最大化する上でのELBOの第一項は、正解の画像データとDecoderの出力の画像データの2乗誤差を最小化する問題に帰着します。

第二項を考える

ELBOの第二項は下記です。

- \mathrm{KL}\left(q_{\psi}(z_i|x_i) \parallel p(z_i)\right)

ELBOを最大化するにあたり、第二項は負の項であるため、非負であるKLダイバージェンスを0に近づける必要があります。

ここで近づける分布は、Encoderの確率分布(q_{\psi}(z_i|x_i))と潜在表現z_iの事前分布p(z_i)です。

ここで、事前分布はベイズ統計などと同様に我々が適当に定義しても良い分布になります。
ただし、変な分布を設定すると精度が下がってしまうため、ある程度妥当な分布を指定する必要があります。

加えて、VAEは生成AIであることが求められます。
つまり、適切に潜在変数zをサンプリングしたら、自然画像が生成される必要があります。
したがって、通常のAEと異なり、Denceに潜在表現が分布している必要があります。

したがって、VAEでは平均0分散Iの標準正規分布を、潜在表現z_iの事前分布として、設定しています。

なぜ標準正規分布が事前分布になるのか

ではなぜVAEで平均0分散Iの標準正規分布が、潜在表現z_iの事前分布として設定されるのでしょうか。
それにはいくつかの理由があります。

1つ目は、正規分布は、ラグランジュの未定乗数法により、エントロピー最大化問題を解いた結果現れる分布となっており、平均0分散Iと言う条件下では、エントロピーが最大になる分布です。
したがって、平均0分散Iと言う条件以外で、偏りを持たない分布であるため、事前分布に余計な情報を付与しないと言う点で優れた分布になります。

2つ目は簡単に解析できる分布であるからです。
今回、ELBOの式の中には、KLダイバージェンスが含まれています。
KLダイバージェンスを、解析的に解くことのできる分布は実は少ないのですが、そのうちの一つが正規分布になります。

したがって、KLダイバージェンスを計算する上で、正規分布が選択肢になるのは必然です。
また、平均0分散Iとすることで、さらに計算を容易にしています。

3つ目は、潜在空間の正則化です。
潜在変数zの事前分布として標準正規分布を仮定することで、VAEは潜在空間全体にわたる潜在表現を構造化し、潜在表現のサンプルが標準正規分布に従うように正則化されます。
したがって、生成AIとして利用する際に、標準正規分布からサンプリングされた潜在表現を利用することで、意味のある自然画像が再構成されることが、保証されやすくなります。

KLダイバージェンスを解析的に解く

さて、続いては、ELBO第二項のKLダイバージェンスを解析的に解きます。
まず、2つの正規分布間でのKLダイバージェンスは次の式で表されることがわかっています。
(これは認めてもらえると嬉しいです)

\mathrm{KL}(q \parallel p) = \frac{1}{2} \left( \mathrm{tr}(\Sigma_p^{-1} \Sigma_q) + (\mu_p - \mu_q)^T \Sigma_p^{-1} (\mu_p - \mu_q) - k + \log \frac{\det \Sigma_p}{\det \Sigma_q} \right)

ただし、2つの正規分布は下記を想定しています。

\mathcal{N}(\mu_p, \Sigma_p), \mathcal{N}(\mu_q, \Sigma_q)

では、今回のKLダイバージェンスを考えます。
すると下記のようになります。

\mathrm{KL}(q_{\psi}(z_i|x_i) \parallel p(z_i)) = \frac{1}{2} \sum_{j=1}^D \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right)

ただし、Dは潜在表現z_iの次元数です。
上記の式から、Encoderが出力する平均と分散パラメータから、第二項が計算できることがわかります。

最終的に

最終的にELBOは下記のような式に帰着します。

\mathrm{ELBO} \approx - \frac{1}{2} \| x_i - \hat{x_i} \|^2 - \frac{1}{2} \sum_{j=1}^D \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right) + \mathrm{const}

VAEから拡散モデルへ

さて、ここまでお疲れ様でした。
ここまででVAEの目的関数である、対数尤度の最大化からELBOの最大化まで理解していただけたと思います。

しかし、ここまでの議論はVAEでの議論をしていました。
ここからは拡散モデルについて考えてみたいと思います。

しかし、拡散モデルの方が簡単です。
なぜなら、Encoderに学習可能パラメータが存在しないからです。

一旦、下記にVAEにおける対数尤度の式変形を再掲します。

拡散モデルの構造は、基本的に多層VAEと呼ばれるものに酷似しています。
違うのは、Encoderにパラメータはなく、潜在表現に対して、同一の処理を繰り返し実施することで、次の段階の潜在表現が得られます。
(詳しい内容はこちらの記事をご覧ください)

逆に言えば、VAEにおける対数尤度の式変形に対して、Encoderのパラメータをなくしたもの考えれば良いです。

そして、パラメータがないのであれば、最適化問題において無視ができます。

詳細な説明は、著者が疲れたため割愛しますが、最終的に生成画像と正解画像の2乗誤差項のみが残るため、拡散モデルにおいては目的関数が生成画像と正解画像の2乗誤差の最小化のみになります。

わかりやすいですね。

まとめ

ここまで読んでくださってありがとうございました!

この辺りの議論がわかると、生成AI系の目的関数がわかるようになるので、論文が読みやすくなるかなと思います。
ぜひみなさんのお役に立てれば幸いです!

Discussion