What single latent $z$ should I use for a one-shot estimate of $\log p_\theta(x)$ in a trained Variational Autoencoder?


Let's say you have an already trained Variational Autoencoder with parameters $\phi$ and $\theta$ for the recognition and generative models, respectively. Let's also assume the following Gaussian factorization: $ \newcommand{\Norm}{\mathcal{N}} \newcommand{\g}{\,|\,} \newcommand{\pz}{p(z)} \newcommand{\ptheta}{p_{\theta}} \newcommand{\encoding}{q_{\phi}(z \g x)} \newcommand{\decoding}{p_{\theta}(x \g z)} \newcommand{\qphi}{q_{\phi}} \newcommand{\qphizgivenx}[2]{\qphi(#1 \g #2)} $ $$ \begin{split} \pz &= \Norm(0,I)\\ \decoding &= \Norm(\mu_{\theta}(z),\sigma^2 I) \\ \encoding &= \Norm(\mu_{\phi}(x),\operatorname{diag}(\sigma^2_{\phi}(x))) \end{split} $$ where $x \in \mathbb{R}^d$, $z \in \mathbb{R}^r$, $d \gg r$, the functions $\mu_\phi, \sigma^2_\phi$ are neural network encoders, and $\mu_{\theta}$ is a neural network decoder.
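For concreteness, all three densities above can be written down directly. Here is a minimal NumPy sketch; the linear maps `W` and `A` and the fixed encoder variance are toy stand-ins for the trained networks, not the real models:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 20, 2            # data and latent dimensions, with d >> r
sigma2 = 0.1            # fixed decoder variance sigma^2

# Toy linear stand-ins for the trained networks (pure assumptions):
W = rng.normal(size=(d, r))   # "decoder" mu_theta(z) = W z
A = rng.normal(size=(r, d))   # "encoder mean" mu_phi(x) = A x
enc_var = np.full(r, 0.25)    # "encoder variance" sigma_phi^2(x), fixed here

def log_gauss(y, mu, var):
    """Log density of N(mu, diag(var)) evaluated at y."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (y - mu) ** 2 / var, axis=-1)

def log_pz(z):                      # log p(z), standard normal prior
    return log_gauss(z, np.zeros(r), np.ones(r))

def log_px_given_z(x, z):           # log p_theta(x | z)
    return log_gauss(x, z @ W.T, sigma2 * np.ones(d))

def log_qz_given_x(z, x):           # log q_phi(z | x)
    return log_gauss(z, A @ x, enc_var)
```
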

Now, according to Kingma & Welling (2019, p. 30), $\log \ptheta(x)$ can be approximated by importance sampling with $L$ samples: $$ \log \ptheta(x) \approx \log \frac{1}{L} \sum_{l=1}^{L}\frac{\ptheta(x,z^{(l)})}{\qphizgivenx{z^{(l)}}{x}} $$
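This estimator is easy to implement with a numerically stable log-sum-exp. A sketch, again with toy linear maps as stand-ins for the trained encoder/decoder (the shapes, variances, and linear maps are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, sigma2, L = 20, 2, 0.1, 64   # d >> r; L importance samples

# Toy linear stand-ins for the trained networks (assumptions):
W = rng.normal(size=(d, r))        # decoder mean map
A = rng.normal(size=(r, d))        # encoder mean map
enc_var = np.full(r, 0.25)         # diag(sigma_phi^2(x)), fixed here

def log_gauss(y, mu, var):
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (y - mu) ** 2 / var, axis=-1)

def log_px_estimate(x, L=L):
    """Importance-sampling estimate of log p_theta(x)."""
    mu = A @ x
    z = mu + np.sqrt(enc_var) * rng.normal(size=(L, r))   # z^(l) ~ q_phi(z|x)
    log_w = (log_gauss(x, z @ W.T, sigma2 * np.ones(d))   #   log p(x | z)
             + log_gauss(z, np.zeros(r), np.ones(r))      # + log p(z)
             - log_gauss(z, mu, enc_var))                 # - log q(z | x)
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))         # stable logsumexp - log L
```
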

Here, the $z^{(l)}$ are sampled from $\qphizgivenx{z}{x}$. This operation is too costly for me, so I want to do it in one shot ($L=1$). My question is: what would be the best single $z$ to use? Approximate answers are fine as long as there is some sound backing to them.

I took a shot at this myself, and here is where I am:

$$ \begin{split} \log \ptheta(x) & \approx \log \sum_{l=1}^{L} \exp\bigl\{\log\ptheta(x,z^{(l)}) - \log\qphizgivenx{z^{(l)}}{x}\bigr\} - \log L \\ & \leq \max_{l \in \{1,\dots,L\}}\bigl\{\log\ptheta(x,z^{(l)}) - \log\qphizgivenx{z^{(l)}}{x}\bigr\} + \log L - \log L \\ & = \max_{l \in \{1,\dots,L\}}\bigl\{\log\ptheta(x,z^{(l)}) - \log\qphizgivenx{z^{(l)}}{x}\bigr\} \\ & \leq \sup_{z \in \mathbb{R}^r}\bigl\{\log\ptheta(x,z) - \log\qphizgivenx{z}{x}\bigr\} \\ & = \sup_{z \in \mathbb{R}^r}\bigl\{\log\ptheta(x \g z) + \log p(z) - \log\qphizgivenx{z}{x}\bigr\} \end{split} $$

The second line uses the log-sum-exp bound $\log\sum_l e^{a_l} \le \max_l a_l + \log L$, and the final step is just the factorization of the joint distribution. This gives an upper bound. My hunch is that it is maximized at $z = \mu_\phi(x)$, and I demonstrated this numerically on a few examples with large $L$ samples. I believe it's due to the fact that $d \gg r$, so the choice effectively neutralizes the $\qphizgivenx{z}{x}$ term, but I'm not entirely sure.
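The kind of numerical check described above can be sketched as follows: evaluate the one-shot weight $\log\ptheta(x,z) - \log\qphizgivenx{z}{x}$ at $z = \mu_\phi(x)$ and at $L$ draws from $\qphizgivenx{z}{x}$, then compare. As before, the linear maps and fixed encoder variance are toy assumptions standing in for the trained networks:

```python
import numpy as np

rng = np.random.default_rng(1)
d, r, sigma2, L = 20, 2, 0.1, 1000

# Toy linear stand-ins for the trained networks (assumptions):
W = rng.normal(size=(d, r))
A = rng.normal(size=(r, d))
enc_var = np.full(r, 0.25)

def log_gauss(y, mu, var):
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (y - mu) ** 2 / var, axis=-1)

def log_weight(x, z):
    """log p_theta(x, z) - log q_phi(z | x), for a single z or a batch of z."""
    return (log_gauss(x, z @ W.T, sigma2 * np.ones(d))
            + log_gauss(z, np.zeros(r), np.ones(r))
            - log_gauss(z, A @ x, enc_var))

x = rng.normal(size=d)
mu = A @ x
z_samples = mu + np.sqrt(enc_var) * rng.normal(size=(L, r))

w_mean = log_weight(x, mu)               # one-shot weight at the encoder mean
w_best = log_weight(x, z_samples).max()  # best weight among L posterior samples
print(w_mean, w_best)
```

Whether `w_mean` actually dominates `w_best` will depend on the model; this only reproduces the style of experiment described, not a proof of the hunch.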

Any help is highly appreciated!