Derivation of M step in EM model with GMM


I follow these notes, specifically page 9, about the EM algorithm. We assume observations $\{x^{1},\ldots,x^{n}\}$ come from a mixture of $k$ Gaussians with parameters $\{\mu_j,\Sigma_j\}_{j=1}^{k}$, with a latent variable model, where for each observation $x^i$ there is a r.v. $z^i\sim\text{Mult}(\phi)$. We iterate through two steps: in the first step, we compute $$w^{i}_j:=p(z^{i}=j|x^i;\phi,\mu,\Sigma)$$ In the second step we take the derivative of the following function, called the ELBO, w.r.t. $\phi, \mu,\Sigma$: $$\sum_{i=1}^{n}\sum_{j=1}^{k}w_j^{i}\log\frac{\mathcal{N}(x^i;\mu_j,\Sigma_j)\cdot\phi_j}{w^i_j}$$

My problem is that when taking the derivative, we treat the $w^i_j$ as constants, not as dependent on the parameters $\phi, \mu,\Sigma$. For example, $\frac{\partial}{\partial\mu_j}\,w^i_j\cdot\mu_j=w^i_j$, even though $w^i_j$ is a function of $\phi, \mu,\Sigma$, so strictly we should use the chain rule and get a more complicated result. The statement that "we computed $w^i_j$ at a previous step and therefore it is constant" seems wrong, since we might as well have plugged the definition of $w^i_j$ into the function we're differentiating, instead of precomputing it. How does breaking the computation into two steps allow us to take simpler derivatives?
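To make the distinction concrete, here is a small finite-difference check on a scalar toy version of this example. The particular choice $w(\mu)=e^{-\mu^2}$ is an arbitrary stand-in for a responsibility that depends on $\mu$, not an actual posterior:

```python
import math

# Toy version of the example: f(mu) = w(mu) * mu, where w itself
# depends on mu. The form w(mu) = exp(-mu^2) is an illustrative assumption.
def w(mu):
    return math.exp(-mu ** 2)

def f(mu):
    return w(mu) * mu

mu0, h = 0.5, 1e-6

# Full derivative: w depends on mu, so the chain rule applies.
full = (f(mu0 + h) - f(mu0 - h)) / (2 * h)

# Derivative with w frozen at its current value w(mu0), as in the M-step.
frozen = (w(mu0) * (mu0 + h) - w(mu0) * (mu0 - h)) / (2 * h)

# The two derivatives genuinely differ; the frozen one equals w(mu0).
print(full, frozen)
```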

Note - my problem isn't specific at all to EM, but it arose in this context, and can be stated in simpler terms.

Best answer:

The notes are correct - the posteriors/conditional distributions $w_j^i$ can be treated as constants in the M-step, and there is no need for the chain rule in what you have outlined above.

The problem you have outlined can be solved by focusing on how EM works algorithmically. To emphasise this, I'm going to amend what you've written to make the role of each iteration more explicit.

Start with current estimates of the parameters $\theta^{(t)} = \{ \phi^{(t)}, \mu^{(t)}, \Sigma^{(t)} \}$, which in the case $t=0$ is an initial guess.

E-step: Compute

$$w_j^{(i)} = Q_i(z^{(i)} = j) = P(z^{(i)} = j | x^{(i)}; \phi^{(t)}, \mu^{(t)}, \Sigma^{(t)})$$

Notice that in the E-step you are computing "the probability of $z^{(i)}$ taking the value $j$ under the distribution $Q_i$", conditional on the observed values $x^{(i)}$ and with the parameters fixed at the current estimates $\phi^{(t)}, \mu^{(t)}, \Sigma^{(t)}$.
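As a sketch, the E-step for a univariate two-component mixture might look like this in pure Python. The helper `normal_pdf`, the data, and the parameter values are illustrative assumptions, not from the notes:

```python
import math

def normal_pdf(x, mu, var):
    """Univariate Gaussian density N(x; mu, var)."""
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def e_step(xs, phi, mu, var):
    """Responsibilities w[i][j] = P(z_i = j | x_i; current parameters).

    Bayes' rule: numerator N(x_i; mu_j, var_j) * phi_j, normalised over j.
    The current parameter estimates are plugged in as fixed numbers.
    """
    w = []
    for x in xs:
        numer = [phi[j] * normal_pdf(x, mu[j], var[j]) for j in range(len(phi))]
        total = sum(numer)
        w.append([n / total for n in numer])
    return w

# Toy data and a current parameter guess (theta^(t)):
xs = [-2.0, -1.5, 1.0, 2.5]
w = e_step(xs, phi=[0.5, 0.5], mu=[-2.0, 2.0], var=[1.0, 1.0])
# Each row of w sums to 1, since it is a distribution over j.
```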

M-step: Maximise the following with respect to $\phi, \mu, \Sigma$ to yield new estimates $\theta^{(t+1)} = \{\phi^{(t+1)}, \mu^{(t+1)}, \Sigma^{(t+1)} \}$:

$$\begin{align} \phi^{(t+1)}, \mu^{(t+1)}, \Sigma^{(t+1)} &= \underset{\phi, \mu, \Sigma}{\text{argmax}} \sum^n_{i=1} \sum^k_{j=1} w_j^{(i)} \log \frac{\mathcal{N}(x^{(i)}; \mu_j, \Sigma_j) \cdot \phi_j}{w_j^{(i)}} \\ &= \underset{\phi, \mu, \Sigma}{\text{argmax}} \sum^n_{i=1} \sum^k_{j=1} P(z^{(i)} = j | x^{(i)}; \phi^{(t)}, \mu^{(t)}, \Sigma^{(t)}) \log \frac{\mathcal{N}(x^{(i)}; \mu_j, \Sigma_j) \cdot \phi_j}{P(z^{(i)} = j | x^{(i)}; \phi^{(t)}, \mu^{(t)}, \Sigma^{(t)})} \end{align}$$

In this step, you now allow the parameters $\phi, \mu, \Sigma$ to vary, and you choose them to maximise the above quantity. However, the parameters are only allowed to vary freely in the numerator of the fraction, not in the conditional distribution $w^i_j$, which is held fixed at the parameter values $\phi = \phi^{(t)}, \mu = \mu^{(t)}, \Sigma = \Sigma^{(t)}$ from the previous step (and is also computed using your observations $x^{(i)}$). Therefore $w^i_j$ does not contain $\phi, \mu, \Sigma$ as arguments, and hence there is no need for the chain rule.

Only the numerator of the fraction, $\mathcal{N}(x^{(i)}; \mu_j, \Sigma_j) \cdot \phi_j$, evaluated at $x^{(i)}$, contains $\phi, \mu, \Sigma$ as arguments; these are treated as unknown variables that need to be estimated. To carry out the maximisation, you compute the derivative with respect to $\phi, \mu, \Sigma$, set it to $0$ and solve to yield $\theta^{(t+1)} = \{\phi^{(t+1)}, \mu^{(t+1)}, \Sigma^{(t+1)} \}$.
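A minimal sketch of the resulting closed-form updates for a univariate mixture, assuming the responsibilities `w` from the E-step are passed in as plain lists (the function name and toy data below are my own, for illustration):

```python
def m_step(xs, w):
    """Closed-form M-step for a univariate GMM.

    With w held fixed, setting the gradient of the ELBO to zero gives
    weighted-average updates: each component's parameters are averages
    of the data weighted by that component's responsibilities.
    """
    n, k = len(xs), len(w[0])
    nk = [sum(w[i][j] for i in range(n)) for j in range(k)]  # effective counts
    phi = [nk[j] / n for j in range(k)]                      # mixing weights
    mu = [sum(w[i][j] * xs[i] for i in range(n)) / nk[j] for j in range(k)]
    var = [sum(w[i][j] * (xs[i] - mu[j]) ** 2 for i in range(n)) / nk[j]
           for j in range(k)]
    return phi, mu, var

# Toy check with hard (0/1) responsibilities held fixed from an E-step:
xs = [0.0, 2.0, 10.0, 12.0]
w = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
phi, mu, var = m_step(xs, w)
# phi == [0.5, 0.5], mu == [1.0, 11.0], var == [1.0, 1.0]
```

Because `w` enters only as fixed numbers, differentiating the ELBO here involves no chain rule, exactly as the notes claim; alternating this update with the E-step is one full EM iteration.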