Chain rule for the trace of matrix logarithms


I am a little confused about the chain rule for matrix derivatives. For example, let

$f(X) = \text{tr}([\log(W X W^\top + B)]^2)$,

where $\log(\cdot)$ denotes the matrix logarithm, $X$ is an $m\times m$ symmetric positive definite (SPD) matrix, $B$ is an $n \times n$ SPD matrix ($n>m$), and $W\in \mathbb{R}^{n\times m}$ is a rectangular matrix. If I use the chain rule, I get

$\frac{\partial f}{\partial X} = 2\log(W X W^\top + B)\, W^\top(W X W^\top + B)^{-1}W$.

However, $\log(W X W^\top + B)$ is $n \times n$ while $W^\top(W X W^\top + B)^{-1}W$ is $m \times m$, so the product is not even defined. There must be something wrong with my derivation, but I don't know where it is.

Any comments? Thanks a lot!
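A quick NumPy shape check makes the mismatch concrete. This is only a sketch: the sizes $m=3$, $n=5$, the random SPD test matrices, and the helper `logm_spd` (principal logarithm computed with `numpy.linalg.eigh`) are illustrative assumptions, not part of the question.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 5  # illustrative sizes with n > m

def random_spd(k):
    """Random k x k symmetric positive definite matrix."""
    G = rng.standard_normal((k, k))
    return G @ G.T / k + np.eye(k)

def logm_spd(A):
    """Principal logarithm of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

X, B = random_spd(m), random_spd(n)
W = rng.standard_normal((n, m))

Z = W @ X @ W.T + B                     # n x n
L = logm_spd(Z)                         # log(Z): n x n
M = W.T @ np.linalg.inv(Z) @ W          # W^T Z^{-1} W: m x m
print(L.shape, M.shape)                 # (5, 5) (3, 3), so L @ M is undefined

G = 2 * W.T @ L @ np.linalg.inv(Z) @ W  # moving W^T to the front gives m x m
print(G.shape)                          # (3, 3)
```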

**Addition**

What if $f(X) = \text{tr}((\log(Z))^\top\log(Z))$, where $Z=(W X W^\top + B)S$ and $S$ is also an $n\times n$ SPD matrix? Do we have

$\frac{\partial f}{\partial X} = 2W^\top \log[(W X W^\top + B)S]\, (W X W^\top + B)^{-1}W$,

or

$\frac{\partial f}{\partial X} = 2W^\top(W X W^\top + B)^{-1}S^{-1} \log[(W X W^\top + B)S]\, SW$?

Using the notation of this post, my solution is:

Define $Z=(W X W^\top + B)S$, and $\phi=\text{tr}([\log(Z)]^2)$. Then we have

$d\phi = 2\log(Z)\cdot Z^{-\top}:dZ=2\log(Z)\cdot Z^{-\top}: WdXW^\top S = 2W^\top\log(Z)\cdot Z^{-\top}SW: dX$.

Therefore, we have

$\frac{\partial \phi}{\partial X} = 2W^\top\log(Z)\cdot Z^{-\top}SW=2W^\top \log[(W X W^\top + B)S]\, (W X W^\top + B)^{-1}W$.

Is it correct?

**Best answer**

Let $Z=WXW^T+B$; it is a symmetric positive definite matrix. Since $\log(Z)$ and $Z^{-1}$ commute, the derivative is

$Df_X: K\in M_{m,m}\mapsto 2\,\text{tr}(\log(Z)Z^{-1}WKW^T)=2\,\text{tr}(W^T\log(Z)Z^{-1}WK)$.

Then the gradient is

$\nabla(f)(X)=2W^TZ^{-1}\log(Z)W\in M_{m,m}$.
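A finite-difference check of this gradient, as a NumPy sketch. The sizes, seed, and the helper `logm_spd` are illustrative assumptions. Because `logm_spd` relies on `numpy.linalg.eigh`, which assumes symmetric input, the perturbation directions $K$ are kept symmetric so that $Z$ stays symmetric.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 3, 5  # illustrative sizes with n > m

def random_spd(k):
    """Random k x k symmetric positive definite matrix."""
    G = rng.standard_normal((k, k))
    return G @ G.T / k + np.eye(k)

def logm_spd(A):
    """Principal logarithm of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

X, B = random_spd(m), random_spd(n)
W = rng.standard_normal((n, m))

def f(X):
    """f(X) = tr(log(W X W^T + B)^2)."""
    L = logm_spd(W @ X @ W.T + B)
    return np.trace(L @ L)

def grad_f(X):
    """Gradient 2 W^T Z^{-1} log(Z) W with Z = W X W^T + B."""
    Z = W @ X @ W.T + B
    return 2 * W.T @ np.linalg.solve(Z, logm_spd(Z)) @ W

h = 1e-5
errors = []
for _ in range(5):
    K = rng.standard_normal((m, m))
    K = K + K.T                                    # symmetric direction keeps Z symmetric
    fd = (f(X + h * K) - f(X - h * K)) / (2 * h)   # central difference of f along K
    exact = np.sum(grad_f(X) * K)                  # Df_X(K) = <grad f(X), K>
    errors.append(abs(fd - exact) / (1 + abs(exact)))
print(max(errors))  # tiny: only finite-difference error remains
```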

While I was writing this, I saw that greg obtained the same result.

EDIT. A comment on the addition by @user3138073: the answer is no, but the reason is not easy to see.

Assume that $U(t)$ is a matrix-valued function of $t\in \mathbb{R}$ and let $f(t)=\text{tr}((\log(U))^2)$. Then $f'(t)=2\,\text{tr}((\log(U))'\log(U))=2\,\text{tr}(\log(U)U^{-1}U')$. Indeed, there is a series behind this identity: because $\log(U)$ is a polynomial in $U$ (which holds when $U$ has no eigenvalues in $(-\infty,0]$), the cyclic property of the trace lets us move $U'$ to the right end of each term, and the resulting series sums to $U^{-1}$ (this is absolutely not obvious!).
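A numerical sanity check of $f'(t)=2\,\text{tr}(\log(U)U^{-1}U')$ on a straight-line path $U(t)=U_0+tK$, as a NumPy sketch. The path, seed, and helper `logm_spd` are illustrative assumptions; $K$ is chosen so that $U$ and $U'$ do not commute, which is the case the identity has to survive.

```python
import numpy as np

rng = np.random.default_rng(2)
k = 4

def logm_spd(A):
    """Principal logarithm of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

G = rng.standard_normal((k, k))
U0 = G @ G.T / k + np.eye(k)                     # U(0): SPD
K = rng.standard_normal((k, k)); K = K + K.T     # U'(t) = K, symmetric
U = lambda t: U0 + t * K                         # straight-line path U(t)

def f(t):
    L = logm_spd(U(t))
    return np.trace(L @ L)

h = 1e-5
fd = (f(h) - f(-h)) / (2 * h)                                  # f'(0), numerically
formula = 2 * np.trace(logm_spd(U0) @ np.linalg.solve(U0, K))  # 2 tr(log(U) U^{-1} U')
commutator = np.linalg.norm(U0 @ K - K @ U0)                   # nonzero: U, U' do not commute
print(fd, formula, commutator)
```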

Next, you choose $g(t)=\text{tr}((\log(U)S)^2)$. Then $g'(t)=2\,\text{tr}((\log(U))'S\log(U)S)$. Unfortunately, $(\log(U))'=U^{-1}U'$ is false in general (the true derivative is much more complicated). If you move $U'$ to the right inside the trace, you break the series described above, because $S$ and $U$ do not commute.
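The point can be seen numerically in a NumPy sketch (random matrices, seed, and helper `logm_spd` are illustrative assumptions). The finite-difference derivative of $\log U$ differs from $U^{-1}U'$ as a matrix, yet the two give the same trace against $\log(U)$; once $S$ is inserted, as in $g'$, the traces no longer agree.

```python
import numpy as np

rng = np.random.default_rng(2)
k = 4

def logm_spd(A):
    """Principal logarithm of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

def random_spd():
    G = rng.standard_normal((k, k))
    return G @ G.T / k + np.eye(k)

U0, S = random_spd(), random_spd()
K = rng.standard_normal((k, k)); K = K + K.T                     # U'(0)
h = 1e-5
dlog = (logm_spd(U0 + h * K) - logm_spd(U0 - h * K)) / (2 * h)  # (log U)'(0)
naive = np.linalg.solve(U0, K)                                   # U^{-1} U'
L = logm_spd(U0)

gap = np.linalg.norm(dlog - naive)                                    # not small
t_plain = (np.trace(dlog @ L), np.trace(naive @ L))                   # agree
t_with_S = (np.trace(dlog @ S @ L @ S), np.trace(naive @ S @ L @ S))  # differ
print(gap, t_plain, t_with_S)
```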

In other words, $\text{tr}(U^2U'U^3\log(U))=\text{tr}(U^5U'\log(U))$, but in general $\text{tr}(U^2U'U^3S\log(U)S)\neq \text{tr}(U^5U'S\log(U)S)$.
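Both trace statements can be checked numerically with a short NumPy sketch. The random SPD matrices $U$, $S$ and the seed are illustrative assumptions, and `Up` is an arbitrary matrix standing in for $U'$ (the first identity holds for any $U'$).

```python
import numpy as np

rng = np.random.default_rng(3)
k = 4

def logm_spd(A):
    """Principal logarithm of an SPD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

def random_spd():
    G = rng.standard_normal((k, k))
    return G @ G.T / k + np.eye(k)

U, S = random_spd(), random_spd()
Up = rng.standard_normal((k, k))  # stands in for U'
L = logm_spd(U)
mp = np.linalg.matrix_power

# log(U) commutes with U, so U^3 can be cycled around to join U^2
lhs1 = np.trace(mp(U, 2) @ Up @ mp(U, 3) @ L)
rhs1 = np.trace(mp(U, 5) @ Up @ L)
# with S in between, U^3 is trapped and the two traces differ
lhs2 = np.trace(mp(U, 2) @ Up @ mp(U, 3) @ S @ L @ S)
rhs2 = np.trace(mp(U, 5) @ Up @ S @ L @ S)
print(lhs1, rhs1)
print(lhs2, rhs2)
```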

**Another answer**

To answer your second question, consider this calculation:
$$\begin{aligned}
Z &= WXW^TS+BS \\
L &= \log Z \\
\phi &= \text{tr}\,L^2 \\
d\phi &= (2LZ^{-1})^T:dZ \\
&= 2L^TZ^{-T}:W\,dX\,W^TS \\
&= 2W^TL^TZ^{-T}S^TW:dX \\
\frac{\partial\phi}{\partial X} &= 2W^TL^TZ^{-T}S^TW
\end{aligned}$$
Here $A:B=\text{tr}(A^TB)$ denotes the Frobenius inner product, and the step from $(2LZ^{-1})^T$ to $2L^TZ^{-T}$ uses the fact that $L=\log Z$ commutes with $Z^{-1}$. Your first question is a special case of your second, in which $S$ equals the identity matrix. This choice of $S$ makes $L$ and $Z$ symmetric, so you can omit the transposes on those terms.
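A NumPy sketch that compares this gradient, and the formula proposed in the question's addition, against central finite differences of $\phi(X)=\text{tr}(L^2)$ over all $m\times m$ perturbations of $X$. Sizes, seed, and the helper `logm_diag` are illustrative assumptions. Since $Z=(WXW^T+B)S$ is not symmetric, the logarithm is taken from a general eigendecomposition; this is valid here because $Z$ is similar to the SPD matrix $S^{1/2}(WXW^T+B)S^{1/2}$, hence diagonalizable with positive eigenvalues, and small perturbations keep it that way.

```python
import numpy as np

rng = np.random.default_rng(4)
m, n = 3, 5  # illustrative sizes with n > m

def random_spd(k):
    """Random k x k symmetric positive definite matrix."""
    G = rng.standard_normal((k, k))
    return G @ G.T / k + np.eye(k)

def logm_diag(M):
    """Principal log of a diagonalizable matrix with positive real eigenvalues."""
    w, V = np.linalg.eig(M)
    return ((V * np.log(w.astype(complex))) @ np.linalg.inv(V)).real

X, B, S = random_spd(m), random_spd(n), random_spd(n)
W = rng.standard_normal((n, m))
Z_of = lambda X: (W @ X @ W.T + B) @ S

def phi(X):
    L = logm_diag(Z_of(X))
    return np.trace(L @ L)

Z = Z_of(X)
L = logm_diag(Z)
ZinvT = np.linalg.inv(Z).T
grad_answer = 2 * W.T @ L.T @ ZinvT @ S.T @ W   # 2 W^T L^T Z^{-T} S^T W (this answer)
grad_question = 2 * W.T @ L @ ZinvT @ S @ W     # 2 W^T log(Z) Z^{-T} S W (the addition)

h = 1e-5
fd = np.zeros((m, m))
for i in range(m):
    for j in range(m):
        E = np.zeros((m, m)); E[i, j] = 1.0     # perturb a single entry of X
        fd[i, j] = (phi(X + h * E) - phi(X - h * E)) / (2 * h)

print(np.abs(fd - grad_answer).max())    # small: finite-difference error only
print(np.abs(fd - grad_question).max())  # not small for a generic S
```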