[Generative AI] Why Maximize ELBO instead of Log-Likelihood? [VAE & Diffusion Models]
Introduction
This article is the 10th-day entry for the Generative AI Advent Calendar 2024.
While studying image generation AI, you often hear that to model a probability distribution, the model is trained to maximize the log-likelihood.
I have also written articles providing simple theoretical explanations regarding "Diffusion Models," one of the image generation AI architectures.
I mentioned theoretical explanation, but I explained it simply without using many mathematical formulas, so I believe anyone can read and understand it easily.
(The second article has a bit more math, so it is mainly for those interested in the mathematical aspects.)
As explained in the articles above, image generation AI basically learns the probability distribution of natural images. To do this, it attempts to learn the distribution by maximizing the log-likelihood.
However, in most textbooks and technical articles (regarding Diffusion Models or VAEs), you will find that the ELBO is maximized instead of the log-likelihood.
This article explores "why we maximize the ELBO instead of the log-likelihood."
The conclusion is that "in the framework of VAEs and Diffusion Models, the log-likelihood cannot be calculated directly."
References
Deep Learning from Scratch ❺ — Generative Models
This is the fifth installment of the "Deep Learning from Scratch" series, an outstanding series of books.
With the goal of eventually understanding Diffusion Models, this book allows you to implement while understanding from scratch, going back to VAEs, Gaussian Mixture Models, and even the Normal Distribution. It is a truly wonderful book.
Even on its own, the book is very clear, but I felt the section regarding "why we maximize ELBO instead of log-likelihood" (or rather, why we must use ELBO) had some gaps between the lines. I hope this article can bridge those gaps.
(Of course, compared to other books, the gaps are already very narrow, but for someone like me who is weak at math, it took time to understand, so I want to fill those gaps. This is also for my future self.)
Introduction to Calculus + Linear Algebra for Uncompromising Data Analysis
The latter half of this book contains brief descriptions of VAEs and ELBO, and I referred to its phrasing while writing this article.
It was very easy for beginners to understand because it always explicitly states which parts of the probability distribution are computable and which are not!
Especially if you want to understand the theory properly rather than just using generative AI, you will need prior knowledge of "Linear Algebra" and "Calculus" to understand papers and theoretical texts.
This is an excellent book for beginners to acquire the necessary foundational knowledge first.
Mathematics of Diffusion Models: Data Generation Technology
This is the definitive book for the theory of Diffusion Models.
Since it uses more formulas, it might be easier to understand after finishing Deep Learning from Scratch ❺ — Generative Models or Introduction to Calculus + Linear Algebra for Uncompromising Data Analysis (or having equivalent knowledge).
The great thing about this book is that while it uses more formulas for the sake of rigor, the gaps between consecutive formulas are small, making it very easy to comprehend.
Also, difficult parts are supplemented with diagrams for visual understanding, so even those who find university textbooks challenging should be able to follow along if they have a STEM background!
(Book links are Amazon affiliate links)
About Log-Likelihood
I explained how (image) generation models create diverse images in this article, but let's take a quick look back.
First, prepare a dataset $D = \{x_1, x_2, \ldots, x_N\}$ of natural images.
If each natural image is denoted $x_i$, assume that all these images follow some (unknown) probability distribution $p(x)$.
From this perspective, if we can reproduce that probability distribution by some method, we can sample from it to generate images similar to real natural images.
If we assume this reproduced probability distribution is controlled by some parameter $\theta$, we can write it as $p_{\theta}(x)$.
How do we determine this parameter? Optimization is possible by updating the parameter $\theta$ so that the model assigns as high a probability as possible to the observed data.
(At the risk of being repetitive, please refer to the previous article for this explanation.)
This conditional probability can be written as $p(x \mid \theta)$.
(I write it this way because I thought it would make it easier to distinguish the roles of the data $x$ and the parameters $\theta$.)
This conditional probability represents "the probability that the data was observed under a certain probability distribution (model parameters) when a specific natural image was obtained," and is specifically called the "likelihood."
In (image) generation models, the parameter $\theta$ is estimated so that this likelihood is maximized; this is Maximum Likelihood Estimation.
Specifically:
$$\hat{\theta} = \arg\max_{\theta} \prod_{i=1}^{N} p(x_i \mid \theta)$$
Furthermore, when performing Maximum Likelihood Estimation, the likelihood is converted to log-likelihood.
Specifically:
$$\hat{\theta} = \arg\max_{\theta} \sum_{i=1}^{N} \log p(x_i \mid \theta)$$
The product of likelihoods easily causes underflow as the number of data points increases.
Therefore, by taking the logarithm, the objective function can be transformed into a sum of data points, enabling stable learning.
Additionally, since the logarithm is a monotonically increasing function, the parameters that maximize the log-likelihood are identical to those that maximize the likelihood.
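The stability argument above can be seen in a few lines of Python (a toy illustration with made-up likelihood values):

```python
import math

# Demo: multiplying many small likelihoods underflows to 0.0 in float64,
# while summing log-likelihoods stays perfectly stable.
probs = [1e-5] * 100   # 100 data points, each with likelihood 1e-5

product = 1.0
for p in probs:
    product *= p       # underflows: 1e-500 is far below float64's minimum

log_sum = sum(math.log(p) for p in probs)

print(product)   # 0.0 due to underflow
print(log_sum)   # about -1151.29 (= 100 * log(1e-5)), no underflow
```

The product is lost entirely to underflow, while the sum of logs is exact to machine precision.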
Why the log-likelihood cannot be calculated
Now, the review is complete.
As mentioned above, in image generation models and others, the parameters $\theta$ are optimized to maximize the log-likelihood.
However, to do that, it is first necessary to be able to calculate the value of the log-likelihood $\log p_{\theta}(x)$ itself.
What is likelihood (probability distribution) in VAEs and diffusion models?
In VAEs and diffusion models, the final output is the generated image itself.
For example, in a framework like "PixelCNN," the final output per pixel is a 256-way classification: the model outputs a probability distribution over the discrete pixel values 0 to 255. In such a framework, since the model output itself is a probability distribution, it is easy to imagine how to calculate the likelihood.
So, what kind of probability distribution is assumed when the final output is the generated image itself? We assume a normal distribution whose mean is the network output $\hat{x}$ and whose covariance is the identity matrix $I$, i.e., $\mathcal{N}(x; \hat{x}, I)$.
As shown in the formula derivations later, by setting it up this way, the final objective function boils down to minimizing the squared error between the ground truth data $x$ and the network output $\hat{x}$.
In other words, the parameters of the probability distribution are determined solely by the mean $\hat{x}$ (the covariance is fixed to $I$).
And this mean $\hat{x}$ is itself the network's output, so it is controlled by the network parameters $\theta$.
This allows us to discuss the maximization of the log-likelihood (of this probability distribution) even within the framework of VAEs and diffusion models.
Now, let's dive into the actual discussion.
What exactly is log-likelihood?
Let's take another proper look at the log-likelihood formula.
The log-likelihood is expressed as follows:
$$\log p_{\theta}(D) = \sum_{i=1}^{N} \log p_{\theta}(x_i)$$
For simplicity, let's consider a single data point, $\log p_{\theta}(x_i)$.
The meaning of this expression is the log-likelihood of the "ground truth data $x_i$ alone."
In other words, this $\log p_{\theta}(x_i)$ is a marginal probability that is conditioned on nothing.
Yes, not even on latent representations.
In both VAEs and diffusion models, the decoder reconstructs an image using some latent representation $z$.
Below, for simplicity, we will focus on VAEs, but the general discussion is roughly the same for diffusion models.
The true nature of the probability distribution created by VAE
As mentioned above, in a VAE the decoder reconstructs an image using some latent representation $z_i$.
Therefore, the probability distribution finally created by the network output (+ the subsequent normal distribution) is the conditional distribution:
$$p_{\theta}(x_i \mid z_i) = \mathcal{N}(x_i; \hat{x_i}, I)$$
This differs from the pure log-likelihood $\log p_{\theta}(x_i)$, which is not conditioned on $z_i$.
Expressing it mathematically like this helps in understanding that they are completely different.
Now, let's try transforming the formula to calculate the pure log-likelihood.
Also, from here on, let the set of decoder parameters be $\theta$.
In other words, the probability distribution represented by the decoder is written $p_{\theta}(x_i \mid z_i)$.
Log-likelihood formula transformation
What we want to compute is the log-likelihood $\log p_{\theta}(x_i)$.
First, let's consider the transformation based on the definition of a marginal probability distribution. Then, the log-likelihood can be transformed as follows:
$$\log p_{\theta}(x_i) = \log \int p_{\theta}(x_i, z_i)\, dz_i = \log \int p_{\theta}(x_i \mid z_i)\, p(z_i)\, dz_i$$
That is, we need to evaluate the conditional probability $p_{\theta}(x_i \mid z_i)$ for every possible latent variable $z_i$ and integrate it, which is intractable in practice.
Also, it's important to remember that the formula above is presented for a single sample $x_i$.
The actual objective function should involve the log-likelihood of all the data, as follows:
$$\sum_{i=1}^{N} \log \int p_{\theta}(x_i \mid z_i)\, p(z_i)\, dz_i$$
As you can see from the expression above, it's in the form of "log-sum" (a log of an integral). While "sum-log" might still be manageable, the "log-sum" form is difficult to solve analytically.
Therefore, even if the latent variable $z_i$ could be enumerated so that the integral became a finite sum, the "log-sum" form would remain intractable.
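To get a feel for why the marginalization itself is also a problem, here is a toy Python sketch (hypothetical binary latent variables and a made-up likelihood function): even when the integral becomes a finite sum, the number of terms explodes with the latent dimensionality.

```python
import math
from itertools import product

# Toy: exact marginal likelihood p(x) = sum_z p(x|z) p(z) over binary latents.
# The sum has 2**d terms, so exact evaluation quickly becomes infeasible.
def marginal_likelihood(d, likelihood):
    total = 0.0
    for z in product([0, 1], repeat=d):  # all 2**d latent configurations
        prior = 0.5 ** d                 # uniform prior p(z)
        total += likelihood(z) * prior
    return total

# hypothetical likelihood: prefers latents with many 1s (made up for the demo)
lik = lambda z: math.exp(-sum(1 - zi for zi in z))

for d in [2, 8, 16]:
    print(d, 2 ** d, marginal_likelihood(d, lik))  # term count doubles per dim
```

A realistic latent dimension (hundreds) would require astronomically many terms, which is why the exact marginal is out of reach.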
Now, back to the formula transformation.
Next, let's transform it using Bayes' theorem. The following transformation holds:
$$p_{\theta}(x_i) = \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{p_{\theta}(z_i \mid x_i)}$$
Now, let's check if each part of the following expression is computable.
First, the prior $p(z_i)$ is a distribution we are free to define ourselves, so it is computable.
In VAEs, a standard normal distribution is often used.
Next, $p_{\theta}(x_i \mid z_i)$ is exactly the distribution the decoder models.
Therefore, it can be modeled as a normal distribution with the decoder output $\hat{x_i}$ as its mean (and identity covariance), so it too is computable.
Finally, the denominator $p_{\theta}(z_i \mid x_i)$ is the posterior distribution of the latent variable given the image $x_i$.
Since this cannot be calculated from the decoder's perspective, further transformation is required.
Since it's a posterior probability, we transform it using Bayes' theorem:
$$p_{\theta}(z_i \mid x_i) = \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{p_{\theta}(x_i)}$$
Therefore, the uncomputable marginal likelihood $p_{\theta}(x_i)$ has appeared in the denominator again.
I've written at length, but in conclusion, the likelihood $p_{\theta}(x_i)$ cannot be computed directly within the VAE framework.
Future Direction
So, what do we do?
We know the weapon called ELBO, but let's assume for now that we don't know it.
The core idea is to consider an equation of the form
$$\log p_{\theta}(x_i) = (\text{computable term}) + (\text{uncomputable term})$$
and maximize the computable term.
In the expression above, since the log-likelihood $\log p_{\theta}(x_i)$ itself cannot be computed, we settle for maximizing the computable term on the right-hand side instead.
In doing so, what helps is KL divergence or Jensen's inequality.
Using Jensen's inequality allows for a more concise transformation, but since using KL divergence makes the intent of the expression easier to understand, we will consider the transformation using KL divergence.
KL divergence can be expressed by providing two probability distributions ($q$ and $p$) as follows:
$$\mathrm{KL}\left(q(z) \parallel p(z)\right) = \int q(z) \log \frac{q(z)}{p(z)}\, dz$$
It is known that this KL divergence is always 0 or greater.
Therefore, if we can bring the equation into the form:
$$\log p_{\theta}(x_i) = (\text{computable term}) + \mathrm{KL}(\cdot \parallel \cdot)$$
then, since the KL divergence is non-negative, the computable term becomes a lower bound of the log-likelihood that we can maximize instead.
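The non-negativity (and the zero-when-identical property) is easy to confirm numerically for discrete distributions (made-up probabilities):

```python
import math

def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as lists of probabilities."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

q = [0.7, 0.2, 0.1]
p = [0.3, 0.4, 0.3]

print(kl_divergence(q, p))       # some positive number
print(kl_divergence(q, q))       # 0.0 when the two distributions are identical
print(kl_divergence(q, p) >= 0)  # True: KL divergence is non-negative
```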
Therefore, for probabilities that cannot be calculated, we consider pushing them into this KL divergence.
Furthermore, KL divergence measures the discrepancy between two probability distributions (it is not a true distance, as it is asymmetric), and it becomes 0 if the two distributions are identical.
Therefore, when the two distributions match, the KL term vanishes and the remaining computable term equals the log-likelihood exactly.
Thus, when pushing uncomputable probability distributions into the KL divergence, it's clear that it's better to prepare some probability distribution that can approximate the uncomputable one with high accuracy and pair it with it inside the KL.
Now, based on the above ideas, let's proceed on the journey to transform the log-likelihood $\log p_{\theta}(x_i)$.
Actually Performing the Formula Transformation
Organizing Our Tools
Now, this is where the real work begins. We will actually perform the formula transformation of the log-likelihood $\log p_{\theta}(x_i)$, but first let's organize the tools we will need.
First, let's transform the expression using Bayes' theorem:
$$p_{\theta}(x_i) = \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{p_{\theta}(z_i \mid x_i)}$$
As mentioned earlier, the uncomputable distribution was the denominator $p_{\theta}(z_i \mid x_i)$.
Let's consider pushing this uncomputable distribution into a KL divergence.
The KL divergence in question is:
$$\mathrm{KL}\left(q(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right) = \mathbb{E}_{q(z_i \mid x_i)}\left[\log \frac{q(z_i \mid x_i)}{p_{\theta}(z_i \mid x_i)}\right]$$
The distribution we want to "push in" is ultimately the posterior distribution of the latent variable $z_i$ given $x_i$, so its partner in the KL divergence should also be some distribution $q(z_i \mid x_i)$ over $z_i$ conditioned on $x_i$.
Since $q(z_i \mid x_i)$ is a probability distribution, it integrates to 1: $\int q(z_i \mid x_i)\, dz_i = 1$.
Now, do you see that by using these two equations and the log-likelihood formula, we can construct the following?
$$\log p_{\theta}(x_i) = \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q(z_i \mid x_i) \parallel p(z_i)\right) + \mathrm{KL}\left(q(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)$$
Let's proceed with the actual transformation.
Formula Transformation Using Our Tools
Let's consider the log-likelihood $\log p_{\theta}(x_i)$ for a single sample.
First, we introduce Tool 1:
$$\int q(z_i \mid x_i)\, dz_i = 1$$
Since the value of this expression is 1, we can multiply the log-likelihood by it without changing its value (and, since $\log p_{\theta}(x_i)$ does not depend on $z_i$, it can be moved inside the integral):
$$\log p_{\theta}(x_i) = \log p_{\theta}(x_i) \int q(z_i \mid x_i)\, dz_i = \int q(z_i \mid x_i) \log p_{\theta}(x_i)\, dz_i = \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i)\right]$$
Next, according to Bayes' theorem and the laws of logarithms:
$$= \mathbb{E}_{q(z_i \mid x_i)}\left[\log \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{p_{\theta}(z_i \mid x_i)}\right] = \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i) + \log p(z_i) - \log p_{\theta}(z_i \mid x_i)\right]$$
Next, we introduce Tool 2:
$$\mathbb{E}_{q(z_i \mid x_i)}\left[\log q(z_i \mid x_i) - \log q(z_i \mid x_i)\right] = 0$$
Using this, the formula can be transformed further.
Since Tool 2 equals 0, we can add its contents inside the brackets without changing the value:
$$= \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i) + \log p(z_i) - \log p_{\theta}(z_i \mid x_i) + \log q(z_i \mid x_i) - \log q(z_i \mid x_i)\right]$$
Changing the calculation order within the brackets:
$$= \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i) + \left(\log p(z_i) - \log q(z_i \mid x_i)\right) - \left(\log p_{\theta}(z_i \mid x_i) - \log q(z_i \mid x_i)\right)\right]$$
Converting differences into quotients using the laws of logarithms:
$$= \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] + \mathbb{E}_{q(z_i \mid x_i)}\left[\log \frac{p(z_i)}{q(z_i \mid x_i)}\right] - \mathbb{E}_{q(z_i \mid x_i)}\left[\log \frac{p_{\theta}(z_i \mid x_i)}{q(z_i \mid x_i)}\right]$$
Taking the reciprocal of the arguments and flipping signs to form the KL divergence structure:
$$
= \mathbb{E}_{q(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q(z_i \mid x_i) \parallel p(z_i)\right) + \mathrm{KL}\left(q(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)
$$
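Since this decomposition is an identity that holds for any choice of $q$, we can verify it numerically with a toy discrete model (all probabilities below are made up for illustration):

```python
import math

# Toy check of  log p(x) = E_q[log p(x|z)] - KL(q||p(z)) + KL(q||p(z|x))
# with a single binary latent z in {0, 1}.
p_z = [0.5, 0.5]            # prior p(z)
p_x_given_z = [0.1, 0.6]    # likelihood p(x|z) for a fixed observation x
q = [0.3, 0.7]              # an arbitrary approximate posterior q(z|x)

p_x = sum(p_x_given_z[z] * p_z[z] for z in (0, 1))               # marginal p(x)
p_z_given_x = [p_x_given_z[z] * p_z[z] / p_x for z in (0, 1)]    # true posterior

kl = lambda a, b: sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))
expected_loglik = sum(q[z] * math.log(p_x_given_z[z]) for z in (0, 1))

elbo = expected_loglik - kl(q, p_z)
lhs = math.log(p_x)
rhs = elbo + kl(q, p_z_given_x)
print(abs(lhs - rhs) < 1e-12)  # the identity holds exactly
```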
The Identity of the Probability Distribution q
We have been transforming the formula, but at this point the only uncomputable distribution left is the posterior $p_{\theta}(z_i \mid x_i)$ inside the last KL term.
Now, let's consider the identity of the distribution $q(z_i \mid x_i)$ that we introduced.
Organizing the symbols, $q(z_i \mid x_i)$ is a probability distribution over the latent variable $z_i$, conditioned on the image $x_i$.
On the other hand, the index $i$ means that, in principle, we would need to prepare a separate distribution for every single data point.
While preparing hundreds of millions of distributions for a dataset with hundreds of millions of images is not realistic, we have a powerful tool that can approximate mappings of large amounts of inputs and outputs with a single model.
Yes, neural networks.
So, by considering a neural network that takes the ground truth image data $x_i$ as input and outputs the parameters of the distribution $q(z_i \mid x_i)$, we can cover every data point with a single model. This network is the encoder; letting its parameters be $\psi$, we write the distribution as $q_{\psi}(z_i \mid x_i)$.
From the results above, we find that the log-likelihood $\log p_{\theta}(x_i)$ decomposes into computable terms plus a single uncomputable KL term.
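As a minimal sketch of this idea (a hypothetical toy encoder, not a real VAE implementation), a single parameterized function maps any input to the mean and log-variance of its own $q$, and the reparameterization trick $z = \mu + \sigma \varepsilon$ samples from it:

```python
import math, random

random.seed(0)

class ToyEncoder:
    """Maps a 1-D input x to the parameters of a 1-D Gaussian q(z|x)."""
    def __init__(self):
        # psi: the encoder's learnable parameters (fixed made-up values here)
        self.w_mu, self.b_mu = 0.5, 0.0
        self.w_lv, self.b_lv = -0.1, 0.0

    def __call__(self, x):
        mu = self.w_mu * x + self.b_mu       # mean of q(z|x)
        logvar = self.w_lv * x + self.b_lv   # log-variance of q(z|x)
        return mu, logvar

def sample_z(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1)
    eps = random.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps

encoder = ToyEncoder()
for x in [0.2, 1.5, -3.0]:       # three different "images"
    mu, logvar = encoder(x)
    z = sample_z(mu, logvar)     # a latent sample for this particular x
    print(round(mu, 3), round(logvar, 3))
```

One network thus defines a different $q_{\psi}(z_i \mid x_i)$ for every input, instead of one distribution per data point.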
Interesting Relationship Between Log-Likelihood and ELBO
Derivation of ELBO
Now, let's first derive the ELBO, which was our initial goal.
As a result of transforming the log-likelihood formula, there is one uncomputable term: $\mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)$.
As mentioned repeatedly, the posterior probability $p_{\theta}(z_i \mid x_i)$ cannot be computed.
Therefore, by utilizing the fact that KL divergence is non-negative, the following inequality transformation is possible:
$$\log p_{\theta}(x_i) \geq \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right)$$
The right-hand side of this expression is called the ELBO (Evidence Lower Bound), and it is clearly a lower bound of the log-likelihood (the "evidence").
(Supplement) Derivation of ELBO using Jensen's inequality
I will omit the detailed explanation, but Jensen's inequality is a theorem stating the relationship between the weighted sum of values after transformation by a convex function or a concave function (like log) and the transformation of the weighted sum of values.
(Assuming the sum of the weights is 1, the weighted sum after transformation is greater for convex functions, while for concave functions, the transformation of the weighted sum is greater.)
Incidentally, Jensen's inequality for the (concave) log function can be expressed as follows:
$$\log \mathbb{E}\left[Y\right] \geq \mathbb{E}\left[\log Y\right]$$
Using this, the log-likelihood can be transformed as follows:
$$\log p_{\theta}(x_i) = \log \int q_{\psi}(z_i \mid x_i)\, \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{q_{\psi}(z_i \mid x_i)}\, dz_i = \log \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{q_{\psi}(z_i \mid x_i)}\right]$$
(Using Jensen's inequality here)
$$\geq \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log \frac{p_{\theta}(x_i \mid z_i)\, p(z_i)}{q_{\psi}(z_i \mid x_i)}\right] = \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right)$$
This makes deriving the ELBO simple.
However, deriving it using KL divergence makes the intent of the formula easier to understand, so I recommend that beginners understand that derivation instead.
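Either way, the inequality at the heart of the Jensen-based derivation is easy to confirm numerically (a toy discrete random variable with made-up values and probabilities):

```python
import math

# Numerical check of Jensen's inequality for log: log E[Y] >= E[log Y].
values = [0.5, 1.0, 4.0]
probs = [0.2, 0.5, 0.3]

log_of_mean = math.log(sum(p * y for p, y in zip(probs, values)))
mean_of_log = sum(p * math.log(y) for p, y in zip(probs, values))

print(log_of_mean >= mean_of_log)  # True: log of the mean dominates
```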
Conditions for ELBO to approach log-likelihood
Even if we maximize the ELBO instead of the log-likelihood, optimization should be more efficient if the ELBO and log-likelihood are as close as possible. Therefore, it is meaningful to consider under what conditions the ELBO and log-likelihood become close.
As you can see from the discussion so far, the ELBO approaches the log-likelihood when the KL divergence $\mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)$ approaches 0.
However, since all we can do is maximize the ELBO, this KL divergence is not included in the optimization target.
Therefore, we need to see how the KL divergence changes as we maximize the ELBO.
Recalling the transformation of the log-likelihood:
$$\log p_{\theta}(x_i) = \underbrace{\mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right)}_{\text{ELBO}} + \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)$$
Looking at the left side of the equation, we can see that the only parameters affecting the log-likelihood $\log p_{\theta}(x_i)$ are the decoder parameters $\theta$.
In other words, changing the encoder parameters $\psi$ does not change the log-likelihood itself.
What changes is the value of each term on the right-hand side; for a fixed $\theta$, their sum stays constant.
This means that by optimizing the encoder parameters $\psi$ to maximize the ELBO (the first term on the right), the left side remains constant.
Returning to the log-likelihood transformation, if the left side is constant and the first term on the right becomes larger, the second term must necessarily become "smaller."
Since the second term is a KL divergence and is non-negative, it approaches 0.
Therefore, by proceeding with VAE training to maximize the ELBO, the uncomputable posterior $p_{\theta}(z_i \mid x_i)$ comes to be automatically approximated by the encoder's distribution $q_{\psi}(z_i \mid x_i)$.
As mentioned before, the encoder outputs mean and variance parameters, so the distribution $q_{\psi}(z_i \mid x_i)$ is constructed as a normal distribution.
On the other hand, the uncomputable posterior $p_{\theta}(z_i \mid x_i)$ is in general a complex distribution that is not necessarily normal.
Therefore, it is assumed that this KL divergence will not become exactly 0; instead, the complex posterior $p_{\theta}(z_i \mid x_i)$ is approximated as well as possible by the simple normal distribution $q_{\psi}(z_i \mid x_i)$.
This technique of approximating an uncomputable distribution with a simple distribution such as a normal distribution is called "Variational Approximation."
I heard that the 'V' in VAE comes from this 'Variational.'
Analysis of the ELBO
Next, let's look at each term of the ELBO in detail. For clarity, the ELBO is presented again below:
$$\mathrm{ELBO} = \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right)$$
Considering the First Term
The ELBO's first term is as follows:
$$\mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right]$$
This is the expectation, under the encoder's distribution, of the conditional log-likelihood of the ground truth image data $x_i$ given the latent $z_i$.
In deep learning contexts, an expectation can be viewed as the average of results obtained from a large amount of data. If we approximate the expectation with a sample size of 1 and assume $z_i \sim q_{\psi}(z_i \mid x_i)$, the first term becomes:
$$\log p_{\theta}(x_i \mid z_i) = \log \mathcal{N}(x_i; \hat{x_i}, I)$$
Where:

- $x_i$ is the $i$-th ground truth image data from the dataset $D$.
- $z_i$ is the latent representation sampled from a normal distribution formed by the mean and variance parameters obtained from the encoder with parameters $\psi$ and input $x_i$.
- $\hat{x_i}$ is the output image data from the decoder with parameters $\theta$ and input $z_i$.
- $\mathcal{N}(x_i; \hat{x_i}, I)$ is the probability (density) of the ground truth image data $x_i$ in a normal distribution with mean $\hat{x_i}$ and variance $I$.
Therefore, if you want to maximize the first term of the ELBO, you just need to maximize $\log \mathcal{N}(x_i; \hat{x_i}, I)$.
First, let's look at the probability density function of a multivariate normal distribution. Considering a general form where the variance-covariance matrix is $\Sigma$:
$$\mathcal{N}(x; \mu, \Sigma) = \frac{1}{(2\pi)^{d/2} |\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x - \mu)^{\top} \Sigma^{-1} (x - \mu)\right)$$
Here, $d$ is the dimensionality of the data $x$, and $|\Sigma|$ is the determinant of the covariance matrix.
Now, let's consider the specific case where the covariance matrix is the identity matrix, which is our current problem setting. Taking the logarithm results in the following:
$$\log \mathcal{N}(x_i; \hat{x_i}, I) = -\frac{d}{2}\log(2\pi) - \frac{1}{2}\left\|x_i - \hat{x_i}\right\|^2$$
In solving the optimization problem, the first term is a constant and can be ignored. Therefore:
$$\max_{\theta} \log \mathcal{N}(x_i; \hat{x_i}, I) \iff \min_{\theta} \frac{1}{2}\left\|x_i - \hat{x_i}\right\|^2$$
Maximizing the first term of the ELBO boils down to the problem of minimizing the squared error between the ground truth image data and the decoder output.
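This reduction can be checked numerically in a few lines (toy low-dimensional vectors with made-up values): the log-density computed from the actual normal pdf matches the "constant minus half the squared error" form.

```python
import math

def log_density(x, mu):
    # log of the product of independent 1-D normal pdfs (identity covariance)
    pdf = 1.0
    for xi, mi in zip(x, mu):
        pdf *= math.exp(-0.5 * (xi - mi) ** 2) / math.sqrt(2 * math.pi)
    return math.log(pdf)

x = [0.2, -1.0, 3.0]      # "ground truth image" (made-up numbers)
x_hat = [0.0, -0.8, 2.5]  # "decoder output"

d = len(x)
const = -0.5 * d * math.log(2 * math.pi)
sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))

print(abs(log_density(x, x_hat) - (const - 0.5 * sq_err)) < 1e-9)  # True
```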
Considering the Second Term
The second term of the ELBO is as follows:
$$-\mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right)$$
In maximizing the ELBO, since this term enters with a minus sign, we need to bring the non-negative KL divergence closer to 0.
The distributions to be brought closer together are the encoder's probability distribution $q_{\psi}(z_i \mid x_i)$ and the prior distribution $p(z_i)$ of the latent variable.
As in Bayesian statistics, the prior distribution is one we can define as we see fit. However, setting an arbitrary distribution will degrade accuracy, so we must specify a reasonably valid distribution.
Additionally, a VAE is required to function as a generative model. This means that if we sample a latent variable $z$ from the prior $p(z)$ and pass it through the decoder, a plausible image should come out, so the prior must be a distribution that is easy to sample from.
Thus, VAEs set a standard normal distribution with mean $0$ and covariance $I$ as the prior $p(z_i)$.
Why the Standard Normal Distribution is Used as the Prior
There are several reasons why the standard normal distribution with mean $0$ and covariance $I$ is chosen.
First, the normal distribution is the distribution that appears as a result of solving the entropy maximization problem using Lagrange multipliers; under the condition of mean $0$ and variance $1$, the answer is exactly the standard normal distribution, making it the least presumptuous assumption.
Second, it is a distribution that can be easily analyzed. The normal distribution is among the few distributions for which the KL divergence can be solved analytically, so choosing it for the KL computation is almost inevitable. Furthermore, using a mean of $0$ and a covariance of $I$ simplifies the closed-form expression even further.
Third, it provides regularization for the latent space. By assuming a standard normal distribution as the prior for the latent variable $z$, the latent representations are encouraged to gather smoothly around the origin, which keeps the latent space continuous and easy to sample from.
Analytically Solving the KL Divergence
Next, let's solve the KL divergence of the second term of the ELBO analytically. It is known that the KL divergence between two (univariate) normal distributions is expressed by the following formula (I hope you can accept this as a given):
$$\mathrm{KL}\left(\mathcal{N}(\mu_1, \sigma_1^2) \parallel \mathcal{N}(\mu_2, \sigma_2^2)\right) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
Where the two normal distributions are assumed to be $\mathcal{N}(\mu_1, \sigma_1^2)$ and $\mathcal{N}(\mu_2, \sigma_2^2)$.
Applying this to our specific KL divergence, with $q_{\psi}(z_i \mid x_i) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and $p(z_i) = \mathcal{N}(0, I)$, we get:
$$\mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right) = \frac{1}{2} \sum_{j=1}^{d_z} \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)$$
Here, $\mu_j$ and $\sigma_j^2$ are the $j$-th components of the mean and variance output by the encoder, and $d_z$ is the dimensionality of the latent variable (since the covariance is diagonal, the KL divergence decomposes into a sum over dimensions).
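The closed-form expression can be sanity-checked against a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$ (a 1-D toy with made-up parameters):

```python
import math, random

# Check closed-form KL( N(mu, sigma^2) || N(0, 1) ) against a Monte Carlo estimate.
random.seed(0)
mu, sigma = 0.8, 0.5

closed_form = 0.5 * (mu**2 + sigma**2 - math.log(sigma**2) - 1)

def log_normal_pdf(z, m, s):
    return -0.5 * math.log(2 * math.pi * s**2) - (z - m) ** 2 / (2 * s**2)

n = 100_000
mc = 0.0
for _ in range(n):
    z = random.gauss(mu, sigma)   # sample z ~ q
    mc += log_normal_pdf(z, mu, sigma) - log_normal_pdf(z, 0.0, 1.0)
mc /= n

print(abs(closed_form - mc) < 0.05)  # the two estimates agree closely
```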
The Final Expression
In the end, the ELBO (up to an additive constant) reduces to the following expression:
$$\mathrm{ELBO} \approx -\frac{1}{2}\left\|x_i - \hat{x_i}\right\|^2 - \frac{1}{2} \sum_{j=1}^{d_z} \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right) + \mathrm{const.}$$
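Putting the two terms together, a minimal sketch of the resulting per-sample objective (the negative ELBO up to a constant; plain Python lists stand in for framework tensors, and all inputs are made up) looks like this:

```python
import math

def vae_loss(x, x_hat, mu, var):
    """Per-sample negative ELBO (up to a constant): reconstruction + KL."""
    # reconstruction term: squared error between ground truth and decoder output
    recon = 0.5 * sum((a - b) ** 2 for a, b in zip(x, x_hat))
    # regularization term: closed-form KL( N(mu, diag(var)) || N(0, I) )
    kl = 0.5 * sum(m**2 + v - math.log(v) - 1 for m, v in zip(mu, var))
    return recon + kl  # minimizing this maximizes the ELBO

x = [0.2, -1.0]                   # "ground truth image"
x_hat = [0.1, -0.9]               # "decoder output"
mu, var = [0.0, 0.0], [1.0, 1.0]  # encoder matches the prior -> KL term is 0

print(vae_loss(x, x_hat, mu, var))
```

In a real implementation the same two terms appear, just computed on tensors with automatic differentiation.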
From VAE to Diffusion Models
Well done so far.
By now, I believe you have understood everything from maximizing the log-likelihood to maximizing the ELBO, which are the objective functions of VAEs.
However, the discussion so far has focused on VAEs.
From here on, I would like to consider diffusion models.
But diffusion models are actually simpler.
This is because there are no learnable parameters in the Encoder.
Let me restate the log-likelihood transformation for VAEs below:
$$\log p_{\theta}(x_i) = \mathbb{E}_{q_{\psi}(z_i \mid x_i)}\left[\log p_{\theta}(x_i \mid z_i)\right] - \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p(z_i)\right) + \mathrm{KL}\left(q_{\psi}(z_i \mid x_i) \parallel p_{\theta}(z_i \mid x_i)\right)$$
The structure of a diffusion model is basically very similar to what is called a multi-layered VAE.
The difference is that there are no parameters in the Encoder, and the next stage of latent representation is obtained by repeatedly performing the same process on the current latent representation.
(For more details, please see this article.)
In other words, you can think of it as the VAE log-likelihood transformation without the encoder parameters $\psi$.
And if there are no parameters, they can be ignored in the optimization problem.
I will omit the detailed explanation as I am exhausted, but finally, only the squared error term between the generated image and the ground truth image remains. Therefore, in diffusion models, the objective function becomes simply the minimization of the squared error between the generated image and the ground truth image.
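As a toy sketch of this simplified objective (hypothetical stand-in network; in practice the squared error is usually taken between the true and the predicted noise, the common ε-prediction parameterization of the same idea):

```python
import math, random

random.seed(0)

def forward_diffuse(x0, alpha_bar):
    # fixed (parameter-free) forward process: mix data with Gaussian noise
    eps = [random.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(alpha_bar) * a + math.sqrt(1 - alpha_bar) * e
          for a, e in zip(x0, eps)]
    return xt, eps

def model(xt, t):
    # hypothetical stand-in for the denoising network's noise prediction;
    # a real model would be a trained U-Net conditioned on the timestep t
    return [0.0 for _ in xt]

x0 = [0.5, -0.2, 1.0]                       # a "ground truth image"
xt, eps = forward_diffuse(x0, alpha_bar=0.7)
pred = model(xt, t=10)
loss = sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(eps)  # simple MSE

print(loss >= 0.0)  # the whole objective is just a squared error
```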
It's quite straightforward.
Summary
Thank you for reading this far!
Understanding these discussions will help you understand the objective functions of generative AI systems, which I believe will make papers easier to read.
I hope this is helpful to everyone!