Derivative of $\operatorname{trace}(XWW^{T}X^{T})$ with respect to $W$


Compute

$$\frac{d}{dW}\operatorname{trace}(XWW^T X^T)$$

where $X$, $W$ are $n\times n$ real matrices.


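To make the variable of differentiation $\mathrm X$, write $\mathrm A$ for the question's $X$ and $\mathrm X$ for its $W$. Let $f : \mathbb R^{n \times n} \to \mathbb R$ be defined by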

$$f (\mathrm X) = \mbox{tr} (\mathrm A \mathrm X \mathrm X^T \mathrm A^T)$$

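By the product rule, followed by the cyclic invariance of the trace and the Frobenius inner product $\langle \mathrm M, \mathrm N \rangle = \mbox{tr} (\mathrm M^T \mathrm N)$, the directional derivative of $f$ in the direction of $\mathrm V$ at $\mathrm X$ is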

$$\begin{array}{rl} D_{\mathrm V} f (\mathrm X) &= \mbox{tr} (\mathrm A \mathrm V \mathrm X^T \mathrm A^T) + \mbox{tr} (\mathrm A \mathrm X \mathrm V^T \mathrm A^T)\\ &= \mbox{tr} (\mathrm X^T \mathrm A^T \mathrm A \mathrm V) + \mbox{tr} (\mathrm V^T \mathrm A^T \mathrm A \mathrm X)\\ &= \mbox{tr} ((\mathrm A^T \mathrm A \mathrm X)^T \mathrm V) + \mbox{tr} (\mathrm V^T \mathrm A^T \mathrm A \mathrm X)\\ &= \langle \mathrm A^T \mathrm A \mathrm X, \mathrm V \rangle + \langle \mathrm V, \mathrm A^T \mathrm A \mathrm X \rangle\\ &= \langle 2 \mathrm A^T \mathrm A \mathrm X, \mathrm V \rangle\end{array}$$
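A quick numerical sanity check of the last line, as a sketch: for random matrices (an arbitrary choice; any values should work), the central difference of $f$ along $\mathrm V$ should equal $\langle 2 \mathrm A^T \mathrm A \mathrm X, \mathrm V \rangle$. Because $f$ is quadratic in $\mathrm X$, the central difference is exact up to rounding. Plain Python lists are used to avoid dependencies, and the helper names (`matmul`, `frobenius`, `shift`, ...) are illustrative only.

```python
import random

def matmul(P, Q):
    """Product of two matrices stored as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(P):
    return [list(row) for row in zip(*P)]

def trace(P):
    return sum(P[i][i] for i in range(len(P)))

def f(A, X):
    """f(X) = tr(A X X^T A^T)."""
    AX = matmul(A, X)
    return trace(matmul(AX, transpose(AX)))

def frobenius(M, N):
    """Frobenius inner product <M, N> = tr(M^T N) = sum of entrywise products."""
    return sum(a * b for row_m, row_n in zip(M, N) for a, b in zip(row_m, row_n))

def shift(X, V, t):
    """Entrywise X + t V."""
    return [[x + t * v for x, v in zip(rx, rv)] for rx, rv in zip(X, V)]

rng = random.Random(0)
n = 3

def rand_matrix(n):
    return [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]

A, X, V = rand_matrix(n), rand_matrix(n), rand_matrix(n)

# Central difference along V: exact up to rounding, since f is quadratic in X.
t = 1e-3
directional = (f(A, shift(X, V, t)) - f(A, shift(X, V, -t))) / (2 * t)

# Closed form from the derivation: <2 A^T A X, V>.
grad = [[2 * g for g in row] for row in matmul(matmul(transpose(A), A), X)]
closed = frobenius(grad, V)

print(abs(directional - closed) < 1e-9)
```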

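Hence,

$$\nabla f (\mathrm X) = 2 \mathrm A^T \mathrm A \mathrm X$$

which, translated back to the question's matrices, reads

$$\frac{d}{dW}\operatorname{trace}(XWW^T X^T) = 2X^TXW.$$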
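For $n = 1$ the formula reduces to ordinary single-variable calculus, which gives a quick consistency check: with $\mathrm A = [a]$ and $\mathrm X = [x]$,

$$f(x) = \mbox{tr}(a\,x\,x\,a) = a^2x^2, \qquad f'(x) = 2a^2x = 2\,\mathrm A^T \mathrm A \mathrm X.$$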



If you are not too familiar with matrix calculus, an alternative to the solution I proposed using equation $(116)$ of the Matrix Cookbook is to introduce intermediate products and write them out in index notation:

$$V=WW^T$$ $$A=XV$$ $$B=AX^T$$

Hence, writing $w^T_{kj}$ for the $(k,j)$ entry of $W^T$ (so $w^T_{kj}=w_{jk}$), and likewise for $x^T$: $$v_{ij}=\sum_kw_{ik}w^T_{kj}$$ $$a_{mj}=\sum_ix_{mi}v_{ij}$$ $$b_{mp}=\sum_ja_{mj}x^T_{jp}=\sum_{i,j,k}x_{mi}w_{ik}w_{kj}^Tx_{jp}^T$$

$$\operatorname{trace}(B)=\sum_{m}b_{mm}=\sum_{i,j,k,m}x_{mi}w_{ik}w_{kj}^Tx_{jm}^T$$
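The index expansion can be checked numerically. The sketch below (random matrices and helper names are assumptions for illustration) computes $\operatorname{trace}(B)$ once through the products $V$, $A$, $B$ and once through the quadruple sum over $i,j,k,m$, using $w^T_{kj}=w_{jk}$ and $x^T_{jm}=x_{mj}$.

```python
import random

def matmul(P, Q):
    """Product of two matrices stored as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(P):
    return [list(row) for row in zip(*P)]

rng = random.Random(1)
n = 3
X = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]
W = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]

# Matrix route: V = W W^T, A = X V, B = A X^T, then sum the diagonal of B.
V = matmul(W, transpose(W))
A = matmul(X, V)
B = matmul(A, transpose(X))
trace_matrix = sum(B[m][m] for m in range(n))

# Index route: sum_{i,j,k,m} x_{mi} w_{ik} (W^T)_{kj} (X^T)_{jm},
# with (W^T)_{kj} = w_{jk} and (X^T)_{jm} = x_{mj}.
trace_index = sum(X[m][i] * W[i][k] * W[j][k] * X[m][j]
                  for i in range(n) for j in range(n)
                  for k in range(n) for m in range(n))

print(abs(trace_matrix - trace_index) < 1e-10)
```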

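Suppose we want the $(r,s)$ entry of the derivative matrix, i.e. $\partial \operatorname{trace}(B)/\partial w_{rs}$. The variable $w_{rs}$ enters the sum only through the factor $w_{ik}$ (when $i=r$, $k=s$) and through the factor $w^T_{kj}=w_{jk}$ (when $j=r$, $k=s$), so only two groups of terms contain it: $$\sum_{j,m}x_{mr}w_{rs}w_{sj}^Tx_{jm}^T \qquad \text{and}\qquad \sum_{i,m}x_{mi}w_{is}w_{sr}^Tx_{rm}^T$$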

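Every other term is independent of $w_{rs}$, so its derivative is zero. Differentiating each group removes its factor $w_{rs}$ (the term with $i=j=r$, $k=s$ contains $w_{rs}^2$ and lies in both groups, matching the two contributions the product rule gives for a squared factor). After renaming $j$ to $i$, the two resulting sums are identical, hence: $$\frac{\partial \operatorname{trace}(B)}{\partial w_{rs}}=2\sum_{i,m}x^T_{rm}x_{mi}w_{is}=2\sum_i \left(\sum_m x_{rm}^Tx_{mi}\right)w_{is}$$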
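A sketch checking both claims of this step for a single entry (the choice $r=1$, $s=2$ with zero-based indices and the random matrices are arbitrary assumptions): the two groups, each with its factor $w_{rs}$ removed, give the same value, and twice that value matches a central difference in $w_{rs}$ alone.

```python
import random

rng = random.Random(2)
n = 3
X = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]
W = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]
r, s = 1, 2  # arbitrary (zero-based) entry of W

def trace_xwwx(M):
    """trace(X M M^T X^T) via the quadruple index sum."""
    return sum(X[m][i] * M[i][k] * M[j][k] * X[m][j]
               for i in range(n) for j in range(n)
               for k in range(n) for m in range(n))

# First group with its factor w_{rs} removed: sum_{j,m} x_{mr} w^T_{sj} x^T_{jm}
first = sum(X[m][r] * W[j][s] * X[m][j] for j in range(n) for m in range(n))
# Second group with its factor w_{rs} removed: sum_{i,m} x_{mi} w_{is} x^T_{rm}
second = sum(X[m][i] * W[i][s] * X[m][r] for i in range(n) for m in range(n))

# Central difference in the single entry w_{rs} (exact up to rounding: quadratic).
h = 1e-3
W_plus = [row[:] for row in W]
W_minus = [row[:] for row in W]
W_plus[r][s] += h
W_minus[r][s] -= h
numeric = (trace_xwwx(W_plus) - trace_xwwx(W_minus)) / (2 * h)

print(abs(first - second) < 1e-12, abs(numeric - 2 * first) < 1e-9)
```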

The inner sum over $m$ is the $(r,i)$ entry of $X^TX$, so the sum over $i$ is the $(r,s)$ entry of $X^TXW$. Since this holds for every $(r,s)$: $$\frac{d \operatorname{trace}(XWW^TX^T)}{dW}=2X^TXW$$
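A finite-difference check of the whole gradient, again a sketch with random matrices and illustrative helper names. It evaluates the objective through the identity $\operatorname{trace}(XWW^TX^T)=\operatorname{trace}\big((XW)(XW)^T\big)$, i.e. the sum of the squared entries of $XW$, and compares every entry with $2X^TXW$.

```python
import random

def matmul(P, Q):
    """Product of two matrices stored as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(P):
    return [list(row) for row in zip(*P)]

def objective(X, W):
    """trace(X W W^T X^T) = trace((XW)(XW)^T) = sum of squared entries of XW."""
    return sum(v * v for row in matmul(X, W) for v in row)

rng = random.Random(3)
n = 4
X = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]
W = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]

# Closed form 2 X^T X W.
closed = [[2 * v for v in row] for row in matmul(matmul(transpose(X), X), W)]

# Entrywise central differences in W (exact up to rounding: quadratic in W).
h = 1e-3
numeric = [[0.0] * n for _ in range(n)]
for r in range(n):
    for s in range(n):
        W_plus = [row[:] for row in W]
        W_minus = [row[:] for row in W]
        W_plus[r][s] += h
        W_minus[r][s] -= h
        numeric[r][s] = (objective(X, W_plus) - objective(X, W_minus)) / (2 * h)

max_err = max(abs(numeric[r][s] - closed[r][s]) for r in range(n) for s in range(n))
print(max_err < 1e-9)
```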