Minimizing the sum of KL divergences


Given a list of probability distributions $q_i$, what distribution $p$ minimizes the sum of KL divergences (if they exist) to and from each of them? That is, how do I determine

$$\operatorname*{argmin}_p \sum_i D_\text{KL}(p \mathbin{\Vert} q_i)$$

and

$$\operatorname*{argmin}_p \sum_i D_\text{KL}(q_i \mathbin{\Vert} p)$$

I recall reading somewhere that, for the Jensen-Shannon divergence, \begin{align*} D_\text{JS}(p, q) &= \frac{D_\text{KL}(p \mathbin{\Vert} r) + D_\text{KL}(q \mathbin{\Vert} r)}{2}, \qquad r = \frac{p + q}{2}, \end{align*} the midpoint distribution $r$ is precisely $$r = \operatorname*{argmin}_s \frac{D_\text{KL}(p \mathbin{\Vert} s) + D_\text{KL}(q \mathbin{\Vert} s)}{2}.$$

I can't find a reference for this, however.
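
For what it's worth, here is a quick numerical sanity check of that midpoint claim for discrete distributions (a sketch assuming NumPy/SciPy are available; the softmax parametrization and all variable names are just for illustration):

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import entropy  # entropy(p, s) computes D_KL(p || s) for discrete p, s

rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(5))
q = rng.dirichlet(np.ones(5))

def softmax(z):
    # Map an unconstrained vector onto the probability simplex.
    e = np.exp(z - z.max())
    return e / e.sum()

# Minimize the average divergence *to* a free distribution s.
res = minimize(lambda z: 0.5 * (entropy(p, softmax(z)) + entropy(q, softmax(z))),
               np.zeros(5), method="BFGS")
print(np.allclose(softmax(res.x), (p + q) / 2, atol=1e-4))  # True: the midpoint wins
```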

Best answer:

A rather brute-force approach, using Lagrange multipliers. Defining the objective functional

$$g(p_x)=\sum_{i=1}^k D(p||q^{(i)})=\sum_{i=1}^k \sum_{x} p_x \log\left( \frac{p_x}{q_x^{(i)}} \right) \tag{1}$$

and the constraint $\sum_x p_x = 1$, the Lagrangian is $g + \lambda\left(\sum_x p_x - 1\right)$; setting its derivative with respect to each $p_x$ to zero gives the critical-point condition

$$ k+\sum_{i=1}^k \log\left( \frac{p_x}{q_x^{(i)}} \right) +\lambda =0 \tag{2}$$

Rearranging, we get

$$ p_x = \gamma \left(\prod_{i=1}^k q_x^{(i)}\right)^{1/k} \tag{3}$$

where $\gamma$ is the normalizing constant. Hence the critical distribution is the normalized geometric mean of the given $q_i$ distributions.

Because $D(p||q)$ is convex in $p$ (indeed jointly convex in both arguments), the objective $(1)$ is a convex function of $p$, so this critical point must be the global minimum.
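
A quick numerical cross-check of $(3)$, under the same discrete setup (again a sketch assuming NumPy/SciPy; the helper names are mine):

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import entropy  # entropy(p, q) = D_KL(p || q)

rng = np.random.default_rng(1)
k, n = 4, 6
qs = rng.dirichlet(np.ones(n), size=k)   # k distributions on n points

geo = np.exp(np.log(qs).mean(axis=0))    # elementwise geometric mean
p_star = geo / geo.sum()                 # normalize: the claimed argmin from (3)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Generic minimization of the objective (1) over the simplex.
res = minimize(lambda z: sum(entropy(softmax(z), q) for q in qs),
               np.zeros(n), method="BFGS")
print(np.allclose(softmax(res.x), p_star, atol=1e-4))  # True
```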


Or, simpler and better: changing the order of summation in $(1)$:

$$g(p_x)= \sum_{x} p_x \log \prod_{i=1}^k \left( \frac{p_x}{q_x^{(i)}}\right) =k \sum_{x} p_x \log \left( \frac{p_x}{\overline{q_x}}\right) \tag{4}$$

where $\overline{q_x}$ is the geometric mean, as in $(3)$. Notice, however, that $\overline{q}$ is not in general a probability distribution. Defining the normalization constant $\gamma=1/\sum_x \overline{q_x}$ we get

$$g(p_x)=k \sum_{x} p_x \log \left( \frac{\gamma p_x}{ \gamma \overline{q_x}}\right)=k \sum_{x} p_x \log \left( \frac{ p_x}{ \gamma \overline{q_x}}\right) + k \log(\gamma) = k D(p||\gamma \overline{q}) + k \log(\gamma) \tag{5}$$

The only term that depends on $p$ is the first, which is a true KL divergence and hence is minimized (at zero) by $p=\gamma \overline{q}$, in agreement with $(3)$. The residual term $k\log(\gamma)$ gives the value of this minimum.

BTW: that $\gamma\ge 1$ (with equality only when all $q_i$ are identical) is easily proved by the AM-GM inequality: pointwise, the geometric mean $\overline{q_x}$ is at most the arithmetic mean $\frac{1}{k}\sum_i q_x^{(i)}$, which sums to $1$ over $x$.
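
Both side facts are easy to check numerically as well (a sketch under the same assumptions as above):

```python
import numpy as np
from scipy.stats import entropy

rng = np.random.default_rng(2)
k, n = 4, 6
qs = rng.dirichlet(np.ones(n), size=k)

geo = np.exp(np.log(qs).mean(axis=0))
gamma = 1.0 / geo.sum()
p_star = gamma * geo

print(gamma >= 1.0)                              # True, by AM-GM
min_value = sum(entropy(p_star, q) for q in qs)  # minimum of (1)
print(np.isclose(min_value, k * np.log(gamma)))  # True: the residual term in (5)
```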


Added: Regarding the other part, it's also simple:

$$\begin{align} \sum_{i=1}^k D(q^{(i)}||p)&=\sum_{i=1}^k \sum_{x} q_x^{(i)} \log\left( \frac{q_x^{(i)}}{p_x} \right)\\ &=\sum_{x} \sum_{i=1}^k q_x^{(i)} \left(\log q_x^{(i)} - \log p_x \right)\\ &=\sum_{i=1}^k \sum_{x} q_x^{(i)} \log q_x^{(i)} - \sum_{x} \log p_x \sum_{i=1}^k q_x^{(i)} \\ &= -\sum _{i=1}^k H(q_i) - k \sum_{x} \tilde {q_x} \log p_x \\ &= -\sum _{i=1}^k H(q_i) + k H(\tilde {q} ) + k D(\tilde {q}|| p)\\ &= k\left(H(\tilde {q} ) - \frac{1}{k}\sum _{i=1}^k H(q_i) + D(\tilde {q}|| p) \right)\\ \end{align}$$

where $\tilde{q}_x = \frac{1}{k}\sum_{i=1}^k q_x^{(i)}$ is the arithmetic mean of the $q_i$, which is a valid probability distribution. Since only the last term depends on $p$, and $D(\tilde{q}||p)\ge 0$ with equality iff $p=\tilde{q}$, the sum is minimized when $p=\tilde{q}$.

BTW 2: $H(\tilde {q} ) - \frac{1}{k}\sum _{i=1}^k H(q_i)\ge 0$ because entropy is concave (Jensen's inequality). Equality occurs only when all $q_i$ are identical.
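
And the corresponding numerical check for this direction (same assumptions; names illustrative):

```python
import numpy as np
from scipy.optimize import minimize
from scipy.stats import entropy  # entropy(q) = H(q); entropy(q, p) = D_KL(q || p)

rng = np.random.default_rng(3)
k, n = 4, 6
qs = rng.dirichlet(np.ones(n), size=k)
q_bar = qs.mean(axis=0)                  # arithmetic mean: the claimed argmin

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

res = minimize(lambda z: sum(entropy(q, softmax(z)) for q in qs),
               np.zeros(n), method="BFGS")
print(np.allclose(softmax(res.x), q_bar, atol=1e-4))         # True

# At the minimum, the value equals k*(H(q~) - mean_i H(q_i)) >= 0.
gap = k * (entropy(q_bar) - np.mean([entropy(q) for q in qs]))
print(np.isclose(sum(entropy(q, q_bar) for q in qs), gap))   # True
```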