Is a Forward KL Divergence assumed in the original Diffusion Model Objective and if so, why?


I'm currently getting into Diffusion Models and Variational Inference in general. I've studied Lilian Weng's blog post on VAEs, where the importance of using the Reverse KL

$$D_{KL}(q_{\Phi}(z|x)||p_{\Theta}(z|x))$$

(with $q_{\Phi}(z|x)$ the estimated posterior and $p_{\Theta}(z|x)$ the true posterior) for fitting the estimated posterior is highlighted. My understanding of the reason for using the Reverse KL instead of the Forward KL is that we aim to "squeeze" the estimated posterior under the true posterior (Reverse KL), instead of making the estimated posterior cover the whole support of the true posterior, which yields a poor fit for a complicated true posterior (Forward KL) (graphical explanation here and in the blog linked above). This Reverse KL formulation yields the ELBO:
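To make the mode-seeking vs. mass-covering distinction concrete, here is a toy example of my own (not from the blog): fitting a single Gaussian $q$ to a fixed bimodal target $p$ by grid search, once minimizing the Reverse KL $D_{KL}(q\|p)$ and once the Forward KL $D_{KL}(p\|q)$, both evaluated numerically on a grid.

```python
import numpy as np

# Toy demo (my own construction): fit a single Gaussian q to a bimodal
# target p by grid search over q's mean and std, under both KL directions.
xs = np.linspace(-10, 10, 2001)
dx = xs[1] - xs[0]

def gauss(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Bimodal "true" distribution: well-separated modes at -3 and +3.
p = 0.5 * gauss(xs, -3, 0.7) + 0.5 * gauss(xs, 3, 0.7)

def kl(a, b):
    # Grid approximation of D_KL(a || b); epsilon avoids log(0).
    eps = 1e-12
    return np.sum(a * (np.log(a + eps) - np.log(b + eps))) * dx

best_rev, best_fwd = None, None
for mu in np.linspace(-5, 5, 41):
    for sigma in np.linspace(0.3, 5, 48):
        q = gauss(xs, mu, sigma)
        rev = kl(q, p)   # Reverse KL: q in the first argument
        fwd = kl(p, q)   # Forward KL: p in the first argument
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, mu, sigma)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, mu, sigma)

# Reverse KL locks onto a single mode (mode-seeking); Forward KL matches
# the moments of p and spreads over both modes (mass-covering).
print("reverse KL optimum: mu=%.2f sigma=%.2f" % best_rev[1:])
print("forward KL optimum: mu=%.2f sigma=%.2f" % best_fwd[1:])
```

The Reverse-KL fit sits under one of the two modes with a narrow std, while the Forward-KL fit is centered between them with a std of roughly $\sqrt{9 + 0.7^2} \approx 3.1$ (exact moment matching, since minimizing $D_{KL}(p\|q)$ over a Gaussian $q$ reduces to matching mean and variance).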

$$- \log p_{\Theta}(x) \leq - E_{z \sim q} [\log p_{\Theta}(x|z)] + D_{KL} (q_{\Phi}(z|x)||p_{\Theta}(z)) = - \log p_{\Theta}(x) + D_{KL}(q_{\Phi}(z|x) || p_{\Theta}(z|x))$$

with the learnable posterior $q_{\Phi}$ in the first argument (the numerator of the log-ratio) of all KL Divergence terms, i.e., formulated as Reverse KL terms.
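The identity between the two right-hand sides above can be checked numerically on a tiny discrete latent-variable model (my own toy numbers, for illustration only):

```python
import numpy as np

# Sanity check of the ELBO identity on a two-state discrete model:
# -E_q[log p(x|z)] + KL(q || p(z))  ==  -log p(x) + KL(q || p(z|x))
p_z = np.array([0.4, 0.6])              # prior p(z)
p_x_given_z = np.array([0.2, 0.7])      # likelihood p(x|z) for the observed x
q_z = np.array([0.5, 0.5])              # arbitrary approximate posterior q(z|x)

p_x = np.sum(p_z * p_x_given_z)         # marginal likelihood p(x)
p_z_given_x = p_z * p_x_given_z / p_x   # true posterior p(z|x) via Bayes' rule

def kl(a, b):
    return np.sum(a * np.log(a / b))

lhs = -np.sum(q_z * np.log(p_x_given_z)) + kl(q_z, p_z)
rhs = -np.log(p_x) + kl(q_z, p_z_given_x)

print(lhs, rhs)  # the two expressions agree up to floating-point error
```

Since $D_{KL}(q_{\Phi}(z|x) \| p_{\Theta}(z|x)) \geq 0$, the computed value also upper-bounds $-\log p_{\Theta}(x)$, which is exactly the ELBO statement.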

Now in the original paper on Diffusion Models, the objective is formulated as:

$$ - \log p_{\Theta}(x_0) \leq - \log p_{\Theta}(x_0) + D_{KL}(q(x_{1:T}|x_0) || p_{\Theta}(x_{1:T}|x_0))$$

whereby $q(x_{1:T}|x_0)$ is the true and $p_{\Theta}(x_{1:T}|x_0)$ the estimated conditional probability of the diffusion process. Note hereby that $q$ is a fixed distribution and $p_{\Theta}$ is learnable, so the opposite of how it is formulated in the VAE ELBO above. Thus, following my understanding, the logic in the Diffusion objective is reversed: the learnable distribution is in the second argument (the denominator of the log-ratio), rendering the KL term a Forward Divergence.
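One standard observation that may be relevant here: when the first argument $q$ is fixed and only the second argument is learnable, minimizing the Forward KL $D_{KL}(q \| p_{\theta})$ over $\theta$ is equivalent to maximizing the expected log-likelihood $E_q[\log p_{\theta}]$, since the entropy of $q$ is a constant in $\theta$. A small numeric illustration of my own (softmax-parameterized $p_{\theta}$ over three states):

```python
import numpy as np

# With q fixed, D_KL(q || p_theta) and the cross-entropy -E_q[log p_theta]
# differ only by the constant H(q), so they share the same minimizer.
q = np.array([0.1, 0.6, 0.3])  # fixed "true" distribution

def softmax(t):
    e = np.exp(t - t.max())
    return e / e.sum()

rng = np.random.default_rng(0)
thetas = rng.normal(size=(200, 3))  # 200 random candidate parameters

kls, nlls = [], []
for t in thetas:
    p = softmax(t)
    kls.append(np.sum(q * np.log(q / p)))   # forward KL in theta
    nlls.append(-np.sum(q * np.log(p)))     # expected negative log-likelihood

# Difference is the constant entropy H(q), independent of theta:
diffs = np.array(nlls) - np.array(kls)
print(diffs.std())  # ~0: same objective up to an additive constant
```

This is one way to read the diffusion objective: with the forward process $q$ fixed, the Forward KL bound is a (tractable surrogate for a) maximum-likelihood objective for $p_{\Theta}$.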

Is my understanding about the assumed Forward Divergence in the Diffusion Objective correct? And if so, why would we use a Forward KL formulation in Diffusion Models as opposed to a Reverse KL in VAEs, when Reverse KL seems to allow the estimated distribution to better fit (parts of) the true distribution?

Thanks a lot!