Proving Convexity of KL Divergence w.r.t. first argument?


I'd like to show that the KL divergence is convex w.r.t. its first argument, where the KL divergence is defined as

$\mathrm{KL}(q\,\|\,p) = \sum_x q(x) \log \frac{q(x)}{p(x)}$

This question suggests that I can show convexity using the log-sum inequality, but so far I've only been able to show that $\mathrm{KL}(\lambda q + (1-\lambda) r \,\|\, p) \geq 0$. Can someone help me?

Let me first rewrite $\mathrm{KL}$ slightly more conveniently: $$ \mathrm{KL}(q\|p) = \sum_x q(x) \log q(x) - \sum_x q(x) \log p(x).$$ The second of these terms is linear in $q$, so you only need to argue that $\varphi(q) := \sum_x q(x) \log q(x)$ is convex. This follows because the function $u \mapsto u \log u$, with the convention $0 \log 0 := 0$, is convex on $\mathbb{R}_{\ge 0}$. One way to show this is the log-sum inequality, which states that for nonnegative numbers $a_i, b_i$, $$\sum_i a_i \log \frac{a_i}{b_i} \ \ge\ \Big(\sum_i a_i\Big) \log \frac{\sum_i a_i}{\sum_i b_i}.$$ For any $u_1, u_2 \ge 0$ and $\lambda \in (0,1)$, take $a_1 = \lambda u_1,\ a_2 = (1-\lambda) u_2,\ b_1 = \lambda,\ b_2 = (1-\lambda)$, in which case the log-sum inequality tells us that $$ \lambda u_1\log u_1 + (1-\lambda) u_2 \log u_2 \ge (\lambda u_1 + (1-\lambda) u_2) \log \frac{\lambda u_1 + (1-\lambda) u_2}{\lambda + 1 - \lambda} = (\lambda u_1 + (1-\lambda) u_2) \log (\lambda u_1 + (1-\lambda) u_2). $$

Now we can use this inequality term by term in the sum defining $\varphi$. Let $\lambda \in (0,1)$. Then $$ \varphi(\lambda q_1 + (1-\lambda)q_2) = \sum_x (\lambda q_1(x) + (1-\lambda)q_2(x)) \log (\lambda q_1(x) + (1-\lambda)q_2(x)) \\ \le \sum_x \big[ \lambda q_1(x) \log q_1(x) + (1-\lambda) q_2(x) \log q_2(x) \big] \\ = \lambda \sum_x q_1(x) \log q_1(x) + (1-\lambda) \sum_x q_2(x) \log q_2(x) \\ = \lambda \varphi(q_1) + (1-\lambda) \varphi(q_2). $$

(Formally, for each $x$, we're using the inequality above with $u_1 = q_1(x)$ and $u_2 = q_2(x)$.)
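As a quick numerical sanity check (not a substitute for the proof), one can sample random distributions and verify the convexity inequality for $\varphi$ directly. This is a minimal sketch; the helper names `phi` and `random_dist` are my own, not from the argument above:

```python
import math
import random

def phi(q):
    # phi(q) = sum_x q(x) log q(x), with the convention 0 log 0 = 0
    return sum(u * math.log(u) for u in q if u > 0)

def random_dist(n, rng):
    # A random probability vector of length n
    w = [rng.random() for _ in range(n)]
    s = sum(w)
    return [x / s for x in w]

rng = random.Random(0)
for _ in range(1000):
    q1 = random_dist(5, rng)
    q2 = random_dist(5, rng)
    lam = rng.random()
    mix = [lam * a + (1 - lam) * b for a, b in zip(q1, q2)]
    # Convexity: phi(lam*q1 + (1-lam)*q2) <= lam*phi(q1) + (1-lam)*phi(q2)
    assert phi(mix) <= lam * phi(q1) + (1 - lam) * phi(q2) + 1e-12
```

Since $\mathrm{KL}(q\|p) = \varphi(q) - \sum_x q(x)\log p(x)$ and the subtracted term is linear in $q$, the same check confirms convexity of $\mathrm{KL}$ in its first argument for any fixed $p$.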

In fact, essentially the same argument gives the joint convexity of $\mathrm{KL}$ in the pair $(q,p)$ directly: keep $a_1 = \lambda u_1,\ a_2 = (1-\lambda) u_2$ as before, but now set $b_1 = \lambda v_1$ and $b_2 = (1-\lambda) v_2$ to show the joint convexity of $(u,v) \mapsto u \log \frac{u}{v}$, and then use the same term-by-term approach.
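The joint-convexity claim can be sanity-checked numerically in the same way, by mixing both arguments with the same $\lambda$. Again a sketch with illustrative helper names (`kl`, `random_dist`), not part of the proof:

```python
import math
import random

def kl(q, p):
    # KL(q||p) = sum_x q(x) log(q(x)/p(x)), with 0 log 0 = 0
    return sum(u * math.log(u / v) for u, v in zip(q, p) if u > 0)

def random_dist(n, rng):
    # A random probability vector of length n (entries positive a.s.)
    w = [rng.random() for _ in range(n)]
    s = sum(w)
    return [x / s for x in w]

rng = random.Random(1)
for _ in range(1000):
    q1, q2 = random_dist(4, rng), random_dist(4, rng)
    p1, p2 = random_dist(4, rng), random_dist(4, rng)
    lam = rng.random()
    qm = [lam * a + (1 - lam) * b for a, b in zip(q1, q2)]
    pm = [lam * a + (1 - lam) * b for a, b in zip(p1, p2)]
    # Joint convexity: KL(lam*(q1,p1) + (1-lam)*(q2,p2)) <= lam*KL(q1,p1) + (1-lam)*KL(q2,p2)
    assert kl(qm, pm) <= lam * kl(q1, p1) + (1 - lam) * kl(q2, p2) + 1e-12
```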