Soft-EM: E-step for fitting mixed linear regression model


I want to derive the formulas for the soft EM algorithm for the following model $P[y_i \mid x_i, \pi_{1,\dots,m}, a_{1,\dots,m}] = \sum_{j=1}^m \pi_j \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(a_j^T x_i - y_i)^2}{2 \sigma^2}\right)$ over the data $(x_1,y_1),\dots,(x_n,y_n) \in \mathbb{R}^d \times \mathbb{R}$, hidden labels $z_1,\dots, z_n$, and parameters $\pi_{1,\dots,m} \in \mathbb{R}$, $a_{1,\dots,m} \in \mathbb{R}^d$, which is a mixture of linear regression models. The E-step makes sense to me, but I'm not able to do the ML estimation for the parameter $a$. It is not clear to me how I can bring the log-likelihood $$\log L((x_1, y_1),\dots,(x_n, y_n) \mid a) = \log \prod_{i=1}^n P[y_i \mid x_i, a] = \sum_{i=1}^n \log\left(\sum_{j=1}^m \gamma_j(x_i, y_i) \frac{1}{\sqrt{2\pi}\sigma} \exp\left(-\frac{(a_j^T x_i - y_i)^2}{2 \sigma^2}\right)\right)$$ where $$\gamma_j(x_i, y_i) = \frac{P[y_i\mid x_i,z_i=j]\,P[z_i=j]}{\sum_{k=1}^m P[y_i\mid x_i,z_i=k]\,P[z_i=k]}$$ into a form that can be optimized as a weighted linear regression problem, as described in this paper.


So the variables are:

  • $\pi_{j}$, the probability of any data point being explainable by the $j$th regression model.
  • $\mathbf a_j \in \mathbb R^d$, the set of coefficients for the $j$th regression model.
  • $z_{ij}$, an indicator variable that takes value $1$ if the $i$th datapoint is explainable by the $j$th regression model, or $0$ otherwise. (So for any given $i$, precisely one of the $z_{ij}$'s is equal to $1$.)
  • $\mathbf x_i \in \mathbb R^d$, the $i$th feature vector.
  • $y_i$, the $i$th prediction target.

To make the notation clearer, let's say that $i$ runs from $1$ to $I$, and $j$ runs from $1$ to $J$. So $I$ is the number of datapoints, and $J$ is the number of regression models.

The full likelihood function (with indices suppressed on the left-hand side) is $$ P\left(y, z\mid \mathbf x, \pi, \mathbf a \right) = \prod_{i=1}^I \left( \sum_{j=1}^J \mathbf 1_{z_{ij} = 1}\times \pi_{j} \times \mathcal N \left(y_i \mid \mathbf a_j^T \mathbf x_i, \sigma^2 \right) \right),$$ where $$ \mathcal N\left( y \mid \mu, \sigma^2 \right) := \frac{1}{\sqrt{2\pi \sigma^2}} \exp \left( - \frac{(y-\mu)^2}{2\sigma^2} \right)$$ and $$ \mathbf 1_{z= 1} := \begin{cases} 1 & {\rm if \ } z= 1 \\ 0 & {\rm if \ } z= 0\end{cases}.$$

Our goal is to find the values of $\pi$ and $\mathbf a$ that maximise the marginal likelihood, $$ P\left(y \mid \mathbf x, \pi, \mathbf a \right) = \sum_{z} P\left(y, z \mid \mathbf x, \pi, \mathbf a \right) . $$

To do this, we use the expectation-maximisation algorithm, which is an iterative algorithm. Each step in the iteration consists of an E-step and an M-step.

E-step. Given $(\pi^{\rm old}, \mathbf a^{\rm old})$ from the previous iteration, calculate the conditional distribution $$P\left(z \mid y, \mathbf x, \pi^{\rm old}, \mathbf a^{\rm old} \right).$$

This is the bit you know how to do. By Bayes' theorem, the answer is $$ \gamma_{ij} := P \left( z_{ij} = 1 \mid y, \mathbf x, \pi^{\rm old}, \mathbf a^{\rm old} \right) = \frac{\pi_j^{\rm old} \times \mathcal N \left( y_i \mid (\mathbf a_j^{\rm old})^T \mathbf x_i, \sigma^2 \right) }{ \sum_{j' = 1}^J \pi_{j'}^{\rm old} \times \mathcal N \left( y_i \mid (\mathbf a_{j'}^{\rm old})^T \mathbf x_i, \sigma^2 \right) }.$$
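As a concrete illustration, the E-step can be sketched with numpy. This is a minimal sketch, not code from the answer: the function name `e_step` and the array shapes (`X` is $I \times d$, `A` stacks the $\mathbf a_j$ as rows) are my own choices. The shared normalising constant $1/\sqrt{2\pi\sigma^2}$ cancels between numerator and denominator, so it is omitted, and the ratio is computed in log space for numerical stability.

```python
import numpy as np

def e_step(X, y, pi, A, sigma):
    """Responsibilities gamma[i, j] = P(z_ij = 1 | y, x, pi_old, a_old).

    X: (I, d) features, y: (I,) targets, pi: (J,) mixture weights,
    A: (J, d) per-component coefficients. (Names are illustrative.)
    """
    residuals = y[:, None] - X @ A.T                  # (I, J): y_i - a_j^T x_i
    # log of pi_j * N(y_i | a_j^T x_i, sigma^2), dropping the constant
    # 1/sqrt(2 pi sigma^2) that cancels in the ratio
    log_w = np.log(pi) - residuals**2 / (2 * sigma**2)
    log_w -= log_w.max(axis=1, keepdims=True)         # log-sum-exp stabilisation
    w = np.exp(log_w)
    return w / w.sum(axis=1, keepdims=True)           # each row sums to 1

# Tiny example: point 0 is fitted exactly by component 0,
# point 1 exactly by component 1
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = np.array([1.0, -1.0, 0.5])
gamma = e_step(X, y, np.array([0.5, 0.5]),
               np.array([[1.0, 0.0], [0.0, -1.0]]), 0.5)
```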

M-step. Find $(\pi^{\rm new}, \mathbf a^{\rm new})$, defined as the solutions to the maximisation problem, $$ \pi^{\rm new}, \mathbf a^{\rm new} = {\rm argmax}_{\pi, \mathbf a} \left( \mathbb E_{z \sim P\left(z \mid y, \mathbf x, \pi^{\rm old}, \mathbf a^{\rm old} \right)}\left[ \log \left( P\left(y, z \mid \mathbf x, \pi, \mathbf a \right) \right) \right] \right).$$ (These $(\pi^{\rm new}, \mathbf a^{\rm new})$ are carried forward and become the $(\pi^{\rm old}, \mathbf a^{\rm old})$ for the E-step of the next iteration.)

We can express the expectation in terms of the $\gamma_{ij}$'s computed in the preceding E-step.

\begin{multline} \mathbb E_{z \sim P\left(z \mid y, \mathbf x, \pi^{\rm old}, \mathbf a^{\rm old} \right)}\left[ \log \left( P\left(y, z \mid \mathbf x, \pi, \mathbf a \right) \right) \right] \\ = \sum_{j=1}^J \left( \sum_{i=1}^I \gamma_{ij} \log \pi_j \right) + \sum_{j = 1}^J \left(\sum_{i=1}^I \gamma_{ij} \log \mathcal N \left(y_i \mid \mathbf a_j^T \mathbf x_i, \sigma^2 \right) \right).\end{multline}

(Notice how the $\gamma_{ij}$'s sit outside the logarithms in my expression!)

For each $j$, the optimal $\pi_j$ is given by $$ \pi_j^{\rm new} = \frac{\sum_{i=1}^I \gamma_{ij}}{\sum_{j' = 1}^J\sum_{i=1}^I \gamma_{ij'}}.$$ (Since $\sum_{j=1}^J \gamma_{ij} = 1$ for each $i$, the denominator is simply $I$.)
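In code this update is a one-liner. The responsibilities below are made-up illustrative values, not from the post; the point is that dividing the column sums of $\gamma$ by the total sum is the same as taking column means, because each row sums to $1$.

```python
import numpy as np

# gamma: (I, J) responsibilities from the E-step; each row sums to 1.
# These particular numbers are illustrative only.
gamma = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.7, 0.3]])

# pi_j^new = (sum_i gamma_ij) / (sum_{j'} sum_i gamma_ij'); the denominator
# equals I, so this is just the column mean of gamma.
pi_new = gamma.sum(axis=0) / gamma.sum()
```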

For each $j$, the optimal $\mathbf a_j$ is given by \begin{align} \mathbf a_j^{\rm new} & = {\rm argmax}_{\mathbf a_j} \left( \sum_{i = 1}^I \gamma_{ij} \log \mathcal N \left( y_i \mid \mathbf a_j^T \mathbf x_i , \sigma^2 \right) \right) \\ &= {\rm argmin}_{\mathbf a_j} \left( \sum_{i = 1}^I \gamma_{ij} \left( y_i - \mathbf a_j^T \mathbf x_i \right)^2 \right).\end{align}

This is a standard weighted regression problem, where the $i$th datapoint is assigned a weighting of $\gamma_{ij}$.
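Solving that weighted regression in closed form amounts to the normal equations $X^T W X \,\mathbf a = X^T W y$ with $W = {\rm diag}(\gamma_{1j},\dots,\gamma_{Ij})$. Here is a minimal sketch (the helper name `weighted_ls` and the test data are my own, not from the answer):

```python
import numpy as np

def weighted_ls(X, y, w):
    """Minimise sum_i w_i (y_i - a^T x_i)^2 by solving the normal
    equations (X^T W X) a = X^T W y, with W = diag(w)."""
    Xw = X * w[:, None]          # rows of X scaled by their weights
    return np.linalg.solve(Xw.T @ X, Xw.T @ y)

# Sanity check: with unit weights this is ordinary least squares, and with
# noiseless data generated from a_true it recovers a_true exactly.
a_true = np.array([2.0, -1.0])
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]])
y = X @ a_true
a_hat = weighted_ls(X, y, np.ones(len(y)))
```

In the M-step one would call this once per component $j$, passing `gamma[:, j]` as the weight vector.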

It can be proved that $$ P\left( y \mid \mathbf x, \pi^{\rm new}, \mathbf a^{\rm new} \right) \geq P\left( y \mid \mathbf x, \pi^{\rm old}, \mathbf a^{\rm old} \right).$$ (The proof is by no means obvious, but it's a standard proof that applies to all examples of expectation-maximisation.)

Hence if you follow this iterative procedure for many steps, you generate a sequence, $$ (\pi^{\rm old}, \mathbf a^{\rm old}), \ \ (\pi^{\rm new}, \mathbf a^{\rm new}), \ \ (\pi^{\rm newer}, \mathbf a^{\rm newer}), \ \ (\pi^{\rm even \ newer}, \mathbf a^{\rm even \ newer}), \ \ \dots$$ along which the marginal likelihood $P\left( y \mid \mathbf x, \pi, \mathbf a \right)$ never decreases. This is why the iterative procedure converges to a (local) maximiser of the marginal likelihood $P\left( y \mid \mathbf x, \pi, \mathbf a \right)$.
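The whole loop, including the monotone increase of the log marginal likelihood, can be checked numerically. The following is a self-contained sketch on synthetic data; the dimensions, seed, noise level, and variable names are all illustrative choices of mine, and $\sigma$ is treated as known, as in the question.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5
I, d, J = 200, 2, 2

# Synthetic data from a two-component mixture of linear regressions
A_true = np.array([[1.0, -2.0], [-3.0, 0.5]])
X = rng.normal(size=(I, d))
z = rng.integers(0, J, size=I)
y = np.einsum("ij,ij->i", X, A_true[z]) + sigma * rng.normal(size=I)

# Random initialisation
pi = np.full(J, 1.0 / J)
A = rng.normal(size=(J, d))

def log_marginal(pi, A):
    """log P(y | x, pi, a) = sum_i log sum_j pi_j N(y_i | a_j^T x_i, sigma^2)."""
    r = y[:, None] - X @ A.T
    lc = np.log(pi) - 0.5 * np.log(2 * np.pi * sigma**2) - r**2 / (2 * sigma**2)
    m = lc.max(axis=1)
    return float(np.sum(m + np.log(np.exp(lc - m[:, None]).sum(axis=1))))

history = []
for _ in range(50):
    history.append(log_marginal(pi, A))
    # E-step: responsibilities gamma[i, j]
    r = y[:, None] - X @ A.T
    lw = np.log(pi) - r**2 / (2 * sigma**2)
    lw -= lw.max(axis=1, keepdims=True)
    g = np.exp(lw)
    g /= g.sum(axis=1, keepdims=True)
    # M-step: mixture weights, then one weighted least squares per component
    pi = g.mean(axis=0)
    for j in range(J):
        Xw = X * g[:, j:j + 1]
        A[j] = np.linalg.solve(Xw.T @ X, Xw.T @ y)
```

Inspecting `history` after the run shows a non-decreasing sequence, matching the guarantee above.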