Gradient of linear scalar field $X \mapsto \operatorname{tr}(AXB)$


Could someone explain the following?

$$ \nabla_X \operatorname{tr}(AXB) = BA $$

I understand that

$$ {\rm d} \operatorname{tr}(AXB) = \operatorname{tr}(BA \; {\rm d} X) $$

but I don't quite understand how to move ${\rm d} X$ out of the trace.


Answer 1

The notation is quite misleading (at least for me).

Hint:

Does it make sense that $$\frac{\partial}{\partial X_{mn}} \mathop{\rm tr} (A X B) = (B A)_{nm}?$$

In more detail, writing the trace in index notation: $$\frac{\partial}{\partial X_{mn}} \mathop{\rm tr} (A X B) = \frac{\partial}{\partial X_{mn}} \sum_{jkl} A_{jk} X_{kl} B_{lj} = \sum_{jkl} A_{jk} \delta_{km} \delta_{nl} B_{lj} = \sum_{j} A_{jm} B_{nj} =(B A)_{nm}. $$
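A quick numerical sanity check of this index computation, sketched in plain Python with matrices stored as lists of rows. The matrices are arbitrary illustrative values (not from the question) and the helper names `matmul`, `trace`, `f` and `partial` are hypothetical. Because $\operatorname{tr}(AXB)$ is linear in $X$, bumping a single entry $X_{mn}$ by 1 changes the trace by exactly $\partial/\partial X_{mn}$, so integer arithmetic gives an exact comparison with no tolerance.

```python
# Check d tr(AXB) / dX_{mn} = (BA)_{nm} using plain Python lists of rows.

def matmul(P, Q):
    """Matrix product of P (r x s) and Q (s x t), both given as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def trace(M):
    """Sum of the diagonal entries of a square matrix."""
    return sum(M[i][i] for i in range(len(M)))

def f(A, X, B):
    """The scalar tr(AXB)."""
    return trace(matmul(matmul(A, X), B))

def partial(A, X, B, m, n):
    """Partial derivative of tr(AXB) with respect to X[m][n].

    tr(AXB) is linear in X, so bumping X[m][n] by 1 changes it by exactly
    the partial derivative (no truncation error)."""
    X_bumped = [row[:] for row in X]
    X_bumped[m][n] += 1
    return f(A, X_bumped, B) - f(A, X, B)

# Arbitrary illustrative matrices: A is 2x3, X is 3x4, B is 4x2.
A = [[1, 2, 0], [-1, 3, 4]]
X = [[2, 0, 1, 5], [1, -2, 3, 0], [0, 1, 1, 2]]
B = [[1, 0], [2, -1], [0, 3], [4, 1]]

BA = matmul(B, A)  # 4x3: the transpose of the shape of X
for m in range(3):
    for n in range(4):
        assert partial(A, X, B, m, n) == BA[n][m]
```

Note that $BA$ has the transposed shape of $X$ here, which is another way to see that "the gradient is $BA$" is a statement about how the partial derivatives are arranged.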

Answer 2

Try expanding to first order in $dX$; this usually makes the structure clearer:

$$\operatorname{tr}(A (X+dX)B)=A_{ij} (X_{jk}+dX_{jk})B_{ki}$$

where Einstein's summation convention is used (repeated indices are summed over). Subtracting $\operatorname{tr}(AXB)$ gives

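$$\begin{align} d\operatorname{tr}(AXB)&=\operatorname{tr}(A(X+dX)B)-\operatorname{tr}(AXB)\\&=A_{ij} dX_{jk}B_{ki}=\underbrace{B_{ki}A_{ij}}_{=(BA)_{kj}} \; dX_{jk} = \operatorname{tr}(BA\,dX) \end{align}$$

There are no higher-order terms because the trace is linear in $X$. Reading off the coefficient of each $dX_{jk}$ gives $\partial \operatorname{tr}(AXB)/\partial X_{jk} = (BA)_{kj}$: arranged with indices $(k,j)$ these coefficients form $BA$, as in the question, while arranged with the same shape as $X$ they form $(BA)^T = A^T B^T$.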
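The first-order expansion can be checked directly, sketched here in plain Python. The matrices (including the perturbation `dX`) are arbitrary illustrative values and the helpers `matmul`, `trace` and `madd` are hypothetical names. Since the map is linear, the identity holds exactly even for a perturbation that is not small, and both the matrix form $\operatorname{tr}(BA\,dX)$ and the index form $(BA)_{kj}\,dX_{jk}$ should agree with the difference of traces.

```python
# Check tr(A(X + dX)B) - tr(AXB) = tr(BA dX) = sum_{j,k} (BA)_{kj} dX_{jk} exactly.

def matmul(P, Q):
    """Matrix product of P (r x s) and Q (s x t), both given as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def trace(M):
    """Sum of the diagonal entries of a square matrix."""
    return sum(M[i][i] for i in range(len(M)))

def madd(P, Q):
    """Elementwise sum of two matrices of the same shape."""
    return [[p + q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

# Arbitrary illustrative matrices: A is 2x3, X and dX are 3x4, B is 4x2.
A = [[1, 2, 0], [-1, 3, 4]]
X = [[2, 0, 1, 5], [1, -2, 3, 0], [0, 1, 1, 2]]
dX = [[1, -1, 0, 2], [0, 3, -2, 1], [4, 0, 1, -1]]
B = [[1, 0], [2, -1], [0, 3], [4, 1]]

# Left-hand side: difference of traces (exact, the map X -> tr(AXB) is linear).
lhs = trace(matmul(matmul(A, madd(X, dX)), B)) - trace(matmul(matmul(A, X), B))

# Right-hand side in matrix form: tr(BA dX).
BA = matmul(B, A)
rhs = trace(matmul(BA, dX))

# Right-hand side in index form: sum over j, k of (BA)_{kj} dX_{jk}.
index_form = sum(BA[k][j] * dX[j][k] for j in range(3) for k in range(4))

assert lhs == rhs == index_form
```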

Answer 3

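These are the main identities to remember. Here the derivative with respect to $\mathbf{X}$ is arranged with the same shape as $\mathbf{X}$, i.e. $\left[\frac{d}{d\mathbf{X}} f\right]_{i,j} = \frac{\partial f}{\partial x_{i,j}}$: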

  1. Let $\mathbf{A} \in \mathbb{R}^{n\times m}$, $\mathbf{X} \in \mathbb{R}^{m\times n}$. Then

\begin{equation} \frac{d}{d\mathbf{X}}\text{Tr}(\mathbf{AX}) = \frac{d}{d\mathbf{X}}\text{Tr}(\mathbf{XA}) = \mathbf{A}^T \end{equation}

  2. Let $\mathbf{A} \in \mathbb{R}^{n\times m}$, $\mathbf{X} \in \mathbb{R}^{n\times m}$. Then

\begin{equation} \frac{d}{d\mathbf{X}}\text{Tr}(\mathbf{AX^T}) = \frac{d}{d\mathbf{X}}\text{Tr}(\mathbf{X^TA}) = \mathbf{A} \end{equation}
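Both identities can be spot-checked numerically. The sketch below, in plain Python, uses arbitrary illustrative matrices with $n = 2$, $m = 3$, and hypothetical helper names (`matmul`, `transpose`, `trace`, `grad_linear`). It builds the gradient entry by entry in the same-shape-as-$\mathbf{X}$ layout; each function is linear in $\mathbf{X}$, so bumping one entry by 1 gives the partial derivative exactly.

```python
# Check identities 1 and 2, with the derivative arranged in the shape of X.

def matmul(P, Q):
    """Matrix product of P (r x s) and Q (s x t), both given as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(M):
    """Transpose of a matrix given as a list of rows."""
    return [list(row) for row in zip(*M)]

def trace(M):
    """Sum of the diagonal entries of a square matrix."""
    return sum(M[i][i] for i in range(len(M)))

def grad_linear(f, X):
    """Gradient of a *linear* scalar function f at X, with the same shape as X.

    For linear f, f(X + E_ij) - f(X) equals df/dx_ij exactly."""
    G = []
    for i in range(len(X)):
        row = []
        for j in range(len(X[0])):
            X_bumped = [r[:] for r in X]
            X_bumped[i][j] += 1
            row.append(f(X_bumped) - f(X))
        G.append(row)
    return G

A = [[1, 2, 0], [-1, 3, 4]]     # n x m with n = 2, m = 3 (illustrative values)
X1 = [[2, 0], [1, -2], [0, 5]]  # m x n, used for identity 1
X2 = [[2, 0, 1], [1, -2, 3]]    # n x m, used for identity 2

# Identity 1: d/dX Tr(AX) = d/dX Tr(XA) = A^T
assert grad_linear(lambda M: trace(matmul(A, M)), X1) == transpose(A)
assert grad_linear(lambda M: trace(matmul(M, A)), X1) == transpose(A)

# Identity 2: d/dX Tr(AX^T) = d/dX Tr(X^T A) = A
assert grad_linear(lambda M: trace(matmul(A, transpose(M))), X2) == A
assert grad_linear(lambda M: trace(matmul(transpose(M), A)), X2) == A
```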

Proof 1.

\begin{equation} \left[ \frac{d}{d\mathbf{X}} \text{Tr}(\mathbf{AX}) \right]_{i,j} = \frac{d}{dx_{i,j}} \text{Tr}(\mathbf{AX}) = \frac{d}{dx_{i,j}} \sum_{k,l} a_{k,l} x_{l,k} = a_{j,i} = \left[\mathbf{A}^T\right]_{i,j} \end{equation}

Proof 2.

\begin{equation} \left[ \frac{d}{d\mathbf{X}} \text{Tr}(\mathbf{AX^T}) \right]_{i,j} = \frac{d}{dx_{i,j}} \text{Tr}(\mathbf{AX^T}) = \frac{d}{dx_{i,j}} \sum_{k,l} a_{k,l} x_{k,l} = a_{i,j} = \left[\mathbf{A}\right]_{i,j} \end{equation}
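Applying identity 1 to the function from the question should recover the earlier answers, up to the layout convention. Assuming $A \in \mathbb{R}^{p\times q}$, $X \in \mathbb{R}^{q\times r}$ and $B \in \mathbb{R}^{r\times p}$, the cyclic property of the trace gives $\operatorname{tr}(AXB) = \operatorname{tr}(BAX)$ with $BA \in \mathbb{R}^{r\times q}$, so $BA$ plays the role of $\mathbf{A}$:

\begin{equation} \frac{d}{dX}\operatorname{tr}(AXB) = \frac{d}{dX}\operatorname{tr}(BAX) = (BA)^T = A^T B^T \end{equation}

In this same-shape-as-$X$ layout the result is the transpose of the $BA$ in the question; entrywise both say $\partial \operatorname{tr}(AXB)/\partial X_{mn} = (BA)_{nm}$, and they differ only in how the partial derivatives are arranged into a matrix.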

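Once you have these, you can combine them with the product rule and the cyclic property of the trace to differentiate much more complicated expressions, such as the following: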

Example 1. Let $\mathbf{A} \in \mathbb{R}^{m\times m}$, $\mathbf{X} \in \mathbb{R}^{m\times n}$. Then

\begin{equation} \frac{d}{d\mathbf{X}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X}) = (\mathbf{A} + \mathbf{A}^T) \mathbf{X} \end{equation}

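We can derive it with the product rule: replace one occurrence of $\mathbf{X}$ at a time by $\mathbf{Y}$, differentiate with respect to $\mathbf{Y}$ while holding the other occurrence fixed, evaluate at $\mathbf{Y} = \mathbf{X}$, and add the results. Identity 2 handles the first term and identity 1 the second: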

\begin{equation} \begin{split} \frac{d}{d\mathbf{X}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X}) =& \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{Y}^T \mathbf{A} \mathbf{X}) + \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{Y})\\ =& \mathbf{A} \mathbf{X} + (\mathbf{X}^T \mathbf{A})^T\\ =& \mathbf{A} \mathbf{X} + \mathbf{A}^T \mathbf{X} = (\mathbf{A} + \mathbf{A}^T) \mathbf{X} \end{split} \end{equation}
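A numerical check of Example 1, sketched in plain Python under the same layout convention. The matrices are arbitrary illustrative values, with $\mathbf{A}$ deliberately non-symmetric so that $\mathbf{A} + \mathbf{A}^T \neq 2\mathbf{A}$, and the helper names (`matmul`, `transpose`, `trace`, `grad_central`) are hypothetical. Along a single entry, $t \mapsto \text{Tr}((\mathbf{X}+t\mathbf{E}_{ij})^T \mathbf{A} (\mathbf{X}+t\mathbf{E}_{ij}))$ is a polynomial of degree at most 2, so a central difference with step 1 is exact and rational arithmetic avoids any tolerance.

```python
from fractions import Fraction

def matmul(P, Q):
    """Matrix product of P (r x s) and Q (s x t), both given as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(M):
    """Transpose of a matrix given as a list of rows."""
    return [list(row) for row in zip(*M)]

def trace(M):
    """Sum of the diagonal entries of a square matrix."""
    return sum(M[i][i] for i in range(len(M)))

def grad_central(f, X):
    """Gradient of f at X (same shape as X) by central differences with step 1.

    For p(t) = f(X + t E_ij) of degree <= 2, (p(1) - p(-1)) / 2 equals p'(0) exactly."""
    G = []
    for i in range(len(X)):
        row = []
        for j in range(len(X[0])):
            X_plus = [r[:] for r in X]
            X_minus = [r[:] for r in X]
            X_plus[i][j] += 1
            X_minus[i][j] -= 1
            row.append(Fraction(f(X_plus) - f(X_minus), 2))
        G.append(row)
    return G

A = [[1, 2, 0], [-1, 3, 4], [2, 0, 1]]  # m x m with m = 3, non-symmetric
X = [[2, 0], [1, -2], [0, 5]]           # m x n with n = 2

def f(M):
    """The scalar Tr(M^T A M)."""
    return trace(matmul(matmul(transpose(M), A), M))

A_plus_AT = [[A[i][j] + A[j][i] for j in range(3)] for i in range(3)]
assert grad_central(f, X) == matmul(A_plus_AT, X)
```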

Example 2. Now consider the function

\begin{equation} f(\mathbf{X}) = \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C}) \end{equation}

where $\mathbf{X} \in \mathbb{R}^{n\times m}$, $\mathbf{A} \in \mathbb{R}^{n\times n}$, $\mathbf{B} \in \mathbb{R}^{m\times m}$, $\mathbf{C} \in \mathbb{R}^{n\times m}$.

\begin{equation} \begin{split} \frac{d}{d\mathbf{X}} f(\mathbf{X}) =& \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{Y}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C})\\ +& \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{Y} \mathbf{B} \mathbf{X}^T \mathbf{C})\\ +& \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{Y}^T \mathbf{C}) \end{split} \end{equation}

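Calculating each term, using the cyclic property of the trace to move $\mathbf{Y}$ (or $\mathbf{Y}^T$) to the front where needed and then applying identity 1 or 2: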

\begin{equation} \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{Y}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C}) = \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C} \end{equation}

\begin{equation} \begin{split} \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{Y} \mathbf{B} \mathbf{X}^T \mathbf{C}) =& \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{Y} \mathbf{B} \mathbf{X}^T \mathbf{C} \mathbf{X}^T \mathbf{A})\\ =& (\mathbf{B} \mathbf{X}^T \mathbf{C} \mathbf{X}^T \mathbf{A})^T\\ =& \mathbf{A}^T \mathbf{X} \mathbf{C}^T \mathbf{X} \mathbf{B}^T \end{split} \end{equation}

\begin{equation} \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{Y}^T \mathbf{C}) = \frac{d}{d\mathbf{Y}} \text{Tr}(\mathbf{Y}^T \mathbf{C} \mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B}) = \mathbf{C} \mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \end{equation}

So the result is:

\begin{equation} \frac{d}{d\mathbf{X}} \text{Tr}(\mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C}) = \mathbf{A} \mathbf{X} \mathbf{B} \mathbf{X}^T \mathbf{C} + \mathbf{A}^T \mathbf{X} \mathbf{C}^T \mathbf{X} \mathbf{B}^T + \mathbf{C} \mathbf{X}^T \mathbf{A} \mathbf{X} \mathbf{B} \end{equation}
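The final formula can be checked numerically as well. This plain-Python sketch uses arbitrary illustrative matrices with $n = 3$, $m = 2$ and hypothetical helper names (`matmul`, `mm`, `transpose`, `trace`, `madd`, `grad_five_point`). Since $\mathbf{X}$ appears three times, $p(t) = f(\mathbf{X} + t\mathbf{E}_{ij})$ is a cubic in $t$; the standard five-point stencil $p'(0) = \frac{8(p(1)-p(-1)) - (p(2)-p(-2))}{12}$ is exact for polynomials of degree at most 4, so the comparison is exact.

```python
from fractions import Fraction
from functools import reduce

def matmul(P, Q):
    """Matrix product of P (r x s) and Q (s x t), both given as lists of rows."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))]
            for i in range(len(P))]

def mm(*Ms):
    """Left-to-right product of a chain of matrices."""
    return reduce(matmul, Ms)

def transpose(M):
    return [list(row) for row in zip(*M)]

def trace(M):
    return sum(M[i][i] for i in range(len(M)))

def madd(*Ms):
    """Elementwise sum of equally shaped matrices."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*Ms)]

def grad_five_point(f, X):
    """Gradient of f at X (same shape as X) via the five-point stencil,
    exact when p(t) = f(X + t E_ij) is a polynomial of degree <= 4."""
    def p(i, j, t):
        Y = [row[:] for row in X]
        Y[i][j] += t
        return f(Y)
    return [[Fraction(8 * (p(i, j, 1) - p(i, j, -1)) - (p(i, j, 2) - p(i, j, -2)), 12)
             for j in range(len(X[0]))]
            for i in range(len(X))]

A = [[1, 2, 0], [-1, 3, 4], [2, 0, 1]]  # n x n
B = [[2, -1], [0, 3]]                    # m x m
C = [[1, 0], [2, -1], [0, 3]]            # n x m
X = [[2, 0], [1, -2], [0, 1]]            # n x m

def f(M):
    """The scalar Tr(M^T A M B M^T C)."""
    return trace(mm(transpose(M), A, M, B, transpose(M), C))

# A X B X^T C + A^T X C^T X B^T + C X^T A X B
expected = madd(mm(A, X, B, transpose(X), C),
                mm(transpose(A), X, transpose(C), X, transpose(B)),
                mm(C, transpose(X), A, X, B))
assert grad_five_point(f, X) == expected
```

With floating-point arrays one would instead compare against a finite-difference gradient with a tolerance; the exact stencil is used here only so that the check needs no tolerance.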