I am a little confused about the chain rule for matrix derivatives. For example, let
$f(X) = \text{tr}([\log(W X W^\top + B)]^2)$,
where $\log(X)$ is the matrix logarithm of the matrix $X$, $X$ is an $m\times m$ symmetric positive definite (SPD) matrix, $B$ is an $n \times n$ SPD matrix ($n>m$), and $W\in \mathbb{R}^{n\times m}$ is a rectangular matrix. If I use the chain rule, I should have
$\frac{\partial f}{\partial X} = 2\log(W X W^\top + B) W^\top(W X W^\top + B)^{-1}W$.
However, the dimensions of $\log(W X W^\top + B)$ and $W^\top(W X W^\top + B)^{-1}W$ are $n \times n$ and $m \times m$ respectively, so the product is not even defined. There must be something wrong with my derivation, but I don't know where.
Any comments? Thanks a lot!
**Addition**
Let $Z=(W X W^\top + B)S$.
What if $f(X) = \text{tr}((\log(Z))^\top\log(Z))$, where $S$ is also an SPD matrix? Do we have
$\frac{\partial f}{\partial X} = 2W^\top \log[(W X W^\top + B)S] (W X W^\top + B)^{-1}W$,
or
$\frac{\partial f}{\partial X} = 2W^\top(W X W^\top + B)^{-1}S^{-1} \log[(W X W^\top + B)S] SW$?
Using the notation of this post, my solution is:
Define $Z=(W X W^\top + B)S$, and $\phi=\text{tr}([\log(Z)]^2)$. Then we have
$d\phi = 2\log(Z)\cdot Z^{-\top}:dZ=2\log(Z)\cdot Z^{-\top}: WdXW^\top S = 2W^\top\log(Z)\cdot Z^{-\top}SW: dX$.
Therefore, we have
$\frac{\partial \phi}{\partial X} = 2W^\top\log(Z)\cdot Z^{-\top}SW=2W^\top \log[(W X W^\top + B)S] (W X W^\top + B)^{-1}W$.
Is it correct?
Let $Z=WXW^T+B$; it is a symmetric positive definite matrix. Since $\log(Z)$ and $Z^{-1}$ commute, the derivative is
$Df_X:K\in M_{m,m}\mapsto 2\,\text{tr}(\log(Z)Z^{-1}WKW^T)=2\,\text{tr}(W^T\log(Z)Z^{-1}WK)$.
Then the gradient is
$\nabla(f)(X)=2W^TZ^{-1}\log(Z)W\in M_{m,m}$.
While writing this, I see that greg obtained the same result.
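The gradient formula above can be sanity-checked numerically. The following is only a sketch under assumptions: it uses NumPy, the helper names `spd_log`, `f`, `grad_f`, and `random_spd` are made up for illustration, and the matrix logarithm is computed by eigendecomposition, which is valid only for symmetric positive definite inputs. It compares $\text{tr}(\nabla f(X)\,K)$ with a central finite difference of $f$ along random symmetric directions $K$.

```python
import numpy as np

def spd_log(A):
    """Matrix logarithm of a symmetric positive definite matrix (eigendecomposition)."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T  # V diag(log w) V^T

def f(X, W, B):
    """f(X) = tr(log(W X W^T + B)^2)."""
    L = spd_log(W @ X @ W.T + B)
    return np.trace(L @ L)

def grad_f(X, W, B):
    """Candidate gradient 2 W^T Z^{-1} log(Z) W with Z = W X W^T + B."""
    Z = W @ X @ W.T + B
    return 2.0 * W.T @ np.linalg.solve(Z, spd_log(Z)) @ W

def random_spd(k, rng):
    """Hypothetical helper: a well-conditioned random SPD matrix."""
    A = rng.standard_normal((k, k))
    return A @ A.T + k * np.eye(k)

rng = np.random.default_rng(0)
m, n = 3, 5
W = rng.standard_normal((n, m))
X = random_spd(m, rng)
B = random_spd(n, rng)

G = grad_f(X, W, B)

# Compare tr(G K) with a central finite difference along symmetric directions K.
eps = 1e-6
max_err = 0.0
for _ in range(5):
    K = rng.standard_normal((m, m))
    K = (K + K.T) / 2  # keep X + eps*K symmetric
    numeric = (f(X + eps * K, W, B) - f(X - eps * K, W, B)) / (2 * eps)
    analytic = np.trace(G @ K)
    max_err = max(max_err, abs(numeric - analytic))

print("gradient shape:", G.shape)
print("largest finite-difference discrepancy:", max_err)
```

The gradient comes out as an $m\times m$ symmetric matrix, which resolves the dimension mismatch in the question: the two $W$ factors have to sit on the outside.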
EDIT. Comment on the addition by @user3138073 . The answer is no but you will have trouble understanding why...
Assume that $U(t)$ is a function of $t\in \mathbb{R}$ and let $f(t)=\text{tr}((\log(U))^2)$. Then $f'(t)=2\,\text{tr}((\log(U))'\log(U))=2\,\text{tr}(\log(U)U^{-1}U')$. Indeed, behind $\log$ there is a power series. Thanks to the trace, and because $\log(U)$ is a polynomial in $U$ (this is true when $U$ has no eigenvalues $\leq 0$), we can move $U'$ to the right end of each term inside the trace and recover the series that gives $U^{-1}$ (this is absolutely not obvious!).
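To illustrate that the trace is what makes this work, here is a small numerical sketch. It is an illustration under assumptions, not a proof: it uses NumPy, takes the straight line $U(t)=U_0+tD$ with $U_0$ SPD and $D$ symmetric as a hypothetical example curve, and computes logarithms by eigendecomposition (helper name `spd_log` is made up). It compares $(\log U)'$ with $U^{-1}U'$ as matrices, and then the two scalar expressions for $f'(0)$.

```python
import numpy as np

def spd_log(A):
    """Matrix logarithm of a symmetric positive definite matrix (eigendecomposition)."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

rng = np.random.default_rng(1)
k = 4
A = rng.standard_normal((k, k))
U0 = A @ A.T / k + np.eye(k)   # SPD base point U(0)
D = rng.standard_normal((k, k))
D = (D + D.T) / 2              # symmetric direction, plays the role of U'(0)

def U(t):
    """Assumed example curve U(t) = U0 + t D (SPD for small |t|)."""
    return U0 + t * D

def f(t):
    L = spd_log(U(t))
    return np.trace(L @ L)

h = 1e-6
# Central differences at t = 0.
dlog_numeric = (spd_log(U(h)) - spd_log(U(-h))) / (2 * h)  # (log U)'
f_prime_numeric = (f(h) - f(-h)) / (2 * h)                 # f'(0)

naive = np.linalg.solve(U0, D)                    # U^{-1} U'
f_prime_formula = 2 * np.trace(spd_log(U0) @ naive)

matrix_gap = np.linalg.norm(dlog_numeric - naive)  # not small in general
scalar_gap = abs(f_prime_numeric - f_prime_formula)  # small

print("||(log U)' - U^{-1} U'|| =", matrix_gap)
print("|f'(0) - 2 tr(log(U) U^{-1} U')| =", scalar_gap)
```

In this sketch the matrix-level gap stays clearly away from zero while the scalar gap is at the level of finite-difference noise, in line with the remark that the formula for $f'$ holds only because of the trace.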
Next, consider $g(t)=\text{tr}((\log(U)S)^2)$. Then $g'(t)=2\,\text{tr}((\log(U))'S\log(U)S)$. Unfortunately, the identity $(\log(U))'=U^{-1}U'$ is absolutely false (the true derivative is much more complicated than that!). If you move $U'$ to the right end of the trace, then you break the series (cf. above) because $S$ and $U$ do not commute.
In other words, $tr(U^2U'U^3\log(U))=tr(U^5U'\log(U))$ but $tr(U^2U'U^3S\log(U)S)\not= tr(U^5U'S\log(U)S)$.
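These two statements are easy to test numerically. The snippet below is a hedged sketch with NumPy: the matrices are random, `random_spd` and `spd_log` are hypothetical helpers, and `dU` is an arbitrary matrix standing in for $U'$. It evaluates both sides of each relation and prints the gaps.

```python
import numpy as np

def spd_log(A):
    """Matrix logarithm of a symmetric positive definite matrix (eigendecomposition)."""
    w, V = np.linalg.eigh(A)
    return (V * np.log(w)) @ V.T

def random_spd(k, rng):
    """Hypothetical helper: a random SPD matrix with eigenvalues >= 1."""
    A = rng.standard_normal((k, k))
    return A @ A.T / k + np.eye(k)

rng = np.random.default_rng(2)
k = 4
U = random_spd(k, rng)
S = random_spd(k, rng)            # generically does not commute with U
dU = rng.standard_normal((k, k))  # stands in for U'

L = spd_log(U)
U2, U3, U5 = (np.linalg.matrix_power(U, p) for p in (2, 3, 5))

# Without S: every factor except dU is a function of U, so the powers can be merged.
lhs_plain = np.trace(U2 @ dU @ U3 @ L)
rhs_plain = np.trace(U5 @ dU @ L)

# With S: S log(U) S no longer commutes with U, so the powers cannot be merged.
M = S @ L @ S
lhs_S = np.trace(U2 @ dU @ U3 @ M)
rhs_S = np.trace(U5 @ dU @ M)

gap_plain = abs(lhs_plain - rhs_plain)
gap_S = abs(lhs_S - rhs_S)
scale = np.linalg.norm(U5) * np.linalg.norm(dU) * np.linalg.norm(L)

print("gap without S:", gap_plain)
print("gap with S:   ", gap_S)
```

The first gap should be at round-off level relative to the size of the matrices, while the second is generically far from zero, which is why the shortcut $(\log U)' \to U^{-1}U'$ cannot be used once $S$ appears inside the trace.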