Why is the matrix derivative of the trace of $AB$ with respect to $B$ not a constant, but $A^T$?


Why is this true?

$$\frac{d}{dB} Tr[A B]= A^\top$$

The trace is the sum of the diagonal elements, so I'm expecting a number, not a matrix! What's going on here?

Example: let $A$ be an $m \times n$ matrix and $B$ an $n \times p$ matrix. In particular, take $m = p = 2$ and $n = 3$:

$$A= \begin{bmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \end{bmatrix}, B= \begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \\ b_{31} & b_{32} \end{bmatrix} $$

To get the answer, we multiply, take the derivative, and then take the trace.

First, we multiply: $$AB= \begin{bmatrix} a_{11}b_{11} + a_{12}b_{21} + a_{13}b_{31} & a_{11}b_{12} + a_{12}b_{22} + a_{13}b_{32} \\ a_{21}b_{11} + a_{22}b_{21} + a_{23}b_{31} & a_{21}b_{12} + a_{22}b_{22} + a_{23}b_{32} \end{bmatrix} $$
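As a quick numeric sanity check of this product, here is a minimal pure-Python sketch. The entries of `A` and `B` are arbitrary values chosen only for illustration, and `matmul` and `trace` are small helpers written here, not library functions. It also shows that $Tr[AB] = \sum_{i}\sum_{j} a_{ij} b_{ji}$ is a single number.

```python
# Arbitrary sample entries for the 2x3 matrix A and the 3x2 matrix B.
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
B = [[7.0, 8.0],
     [9.0, 10.0],
     [11.0, 12.0]]


def matmul(X, Y):
    """Product of two matrices stored as nested lists."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))]
            for i in range(len(X))]


def trace(M):
    """Sum of the diagonal entries of a square matrix."""
    return sum(M[i][i] for i in range(len(M)))


AB = matmul(A, B)
# Only the diagonal entries (1,1) and (2,2) enter the trace,
# so tr(AB) = sum over i, j of a_ij * b_ji.
entrywise = sum(A[i][j] * B[j][i] for i in range(2) for j in range(3))

print(AB)                    # [[58.0, 64.0], [139.0, 154.0]]
print(trace(AB), entrywise)  # 212.0 212.0
```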

Second, we take the derivative, using the fact that $\frac{d}{dB} Tr[A B]= Tr[A \frac{d}{dB}B]$: $$A\frac{d}{d B}B= \begin{bmatrix} a_{11}+ a_{12} + a_{13} & a_{11} + a_{12} + a_{13} \\ a_{21} + a_{22}+ a_{23} & a_{21} + a_{22} + a_{23} \end{bmatrix} $$

Third, we take the trace: $$Tr\left[ A\frac{d}{d B}B\right] = a_{11} + a_{12} + a_{13} + a_{21} + a_{22} + a_{23} = k$$

The result is $\frac{d}{dB} Tr[A B] = k$, which is a constant, not the matrix $A^T$ claimed in the first equation.


There are 2 best solutions below


First, if $A$ is an $m\times n$ matrix, then $B$ has to be an $n\times m$ matrix; otherwise $AB$ is not square and it doesn't make sense to talk about $tr(AB)$.

Now, you can view $B\mapsto tr(AB)$ as a function $f:\mathbb{R}^{n\times m}\to \mathbb{R}$, and $\frac{d}{dB}[tr(AB)]$ is then the usual gradient of $f$. This gradient is a "vector" in $\mathbb{R}^{n\times m}$, that is, an $n\times m$ matrix with the same shape as $B$, so it is perfectly possible for it to be $A^T$.
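One way to see this concretely is a pure-Python sketch with arbitrary sample values: estimate each partial derivative $\partial f / \partial b_{kl}$ by central finite differences and collect the results into a matrix shaped like $B$. The helper `numerical_gradient` is hypothetical, written only for this illustration.

```python
# Arbitrary sample values: A is 2x3, so B must be 3x2 for AB to be square.
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
B = [[0.5, -1.0],
     [2.0, 0.0],
     [1.5, 3.0]]


def f(B):
    """f(B) = tr(AB) = sum over i, j of a_ij * b_ji."""
    return sum(A[i][j] * B[j][i] for i in range(len(A)) for j in range(len(B)))


def numerical_gradient(f, B, h=1e-6):
    """Central-difference estimate of df/db_pq for every entry of B."""
    grad = []
    for p in range(len(B)):
        row = []
        for q in range(len(B[0])):
            B_plus = [r[:] for r in B]
            B_minus = [r[:] for r in B]
            B_plus[p][q] += h
            B_minus[p][q] -= h
            row.append((f(B_plus) - f(B_minus)) / (2 * h))
        grad.append(row)
    return grad


G = numerical_gradient(f, B)
A_T = [list(col) for col in zip(*A)]  # transpose of A, also 3x2

# G is a 3x2 matrix (same shape as B), not a single number,
# and each entry matches A^T up to floating-point rounding.
for g_row, a_row in zip(G, A_T):
    print([round(g, 6) for g in g_row], a_row)
```

Because $f$ is linear in $B$, central differences are exact apart from rounding, so the choice of `h` matters little here.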

The mistake is the claim that $\frac{d}{dB}[tr(AB)]=tr[A\frac{d}{dB}B]$. This doesn't make sense: the left-hand side is a "vector" (a matrix), while the right-hand side is a constant, as you noted.
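It may help to see what the scalar $k$ from the question actually measures. In the sketch below (arbitrary sample values again), $k = \sum_{i,j} a_{ij}$ turns out to be the rate of change of $f$ when every entry of $B$ is increased at once, i.e. the derivative of $f$ along the all-ones matrix $J$, which equals the sum of the entries of $A^T$. That is one directional derivative, whereas the gradient keeps a separate partial derivative for each entry of $B$.

```python
# Arbitrary sample values, as before: A is 2x3 and B is 3x2.
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
B = [[0.5, -1.0],
     [2.0, 0.0],
     [1.5, 3.0]]


def f(B):
    """f(B) = tr(AB) = sum over i, j of a_ij * b_ji."""
    return sum(A[i][j] * B[j][i] for i in range(len(A)) for j in range(len(B)))


J = [[1.0, 1.0] for _ in range(3)]  # all-ones direction, same shape as B
t = 1e-6
B_shifted = [[b + t * d for b, d in zip(b_row, d_row)]
             for b_row, d_row in zip(B, J)]

# Rate of change of f when every entry of B moves by t at once.
directional = (f(B_shifted) - f(B)) / t

# The question's scalar k: the sum of all entries of A (= sum of entries of A^T).
k = sum(sum(row) for row in A)

print(k, round(directional, 6))
```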


I think the key point is to understand the meaning of the derivative in this context where $f:\mathbb R^{m \times n} \to \mathbb R$.

My favorite way to think about the gradient of a function $f:\mathbb R^{m \times n} \to \mathbb R$ is

$$\tag{1} f(B + \Delta B) \approx f(B) + \langle \nabla f(B), \Delta B\rangle.$$ In this equation, $\Delta B$ is a matrix (it is added to $B$, after all), and $\nabla f(B)$ is also a matrix (otherwise we could not take the inner product of $\nabla f(B)$ and $\Delta B$).

By the way, what is the inner product we are using here? It is the usual matrix inner product $$\langle C,B \rangle = \text{Tr}(C^T B),$$ which is equivalent to just reshaping the matrices $B$ and $C$ into vectors and then taking the dot product of the resulting vectors.
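A minimal sketch of that equivalence, using arbitrary $3\times 2$ sample matrices and small helpers written only for this illustration: $\text{Tr}(C^T B)$ and the dot product of the row-by-row flattened matrices give the same number.

```python
# Arbitrary 3x2 sample matrices.
C = [[1.0, -2.0],
     [0.5, 3.0],
     [4.0, 0.0]]
B = [[2.0, 1.0],
     [-1.0, 0.5],
     [0.25, 2.0]]


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))]
            for i in range(len(X))]


def trace(M):
    return sum(M[i][i] for i in range(len(M)))


def flatten(M):
    """Reshape a matrix into a single list, row by row."""
    return [x for row in M for x in row]


via_trace = trace(matmul(transpose(C), B))                    # Tr(C^T B)
via_dot = sum(c * b for c, b in zip(flatten(C), flatten(B)))  # dot product of reshaped vectors

print(via_trace, via_dot)  # 2.0 2.0
```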

In this specific problem, we have $$f(B) = \text{Tr}(AB) = \langle A^T, B \rangle.$$ Note that \begin{align} f(B + \Delta B) &= \langle A^T, B + \Delta B \rangle \\ &= \underbrace{\langle A^T, B \rangle}_{f(B)} + \langle A^T, \Delta B \rangle. \end{align} Comparing this with equation (1), we see that $$ \nabla f(B) = A^T. $$
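The same result can be checked entry by entry. Taking $A$ to be $m\times n$ and $B$ to be $n\times m$ (as in the first answer), and using $\partial b_{ji}/\partial b_{kl} = \delta_{jk}\,\delta_{il}$:

$$\begin{aligned} f(B) &= \text{Tr}(AB) = \sum_{i=1}^{m} (AB)_{ii} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}\, b_{ji}, \\ \frac{\partial f}{\partial b_{kl}} &= \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}\,\delta_{jk}\,\delta_{il} = a_{lk} = (A^T)_{kl}. \end{aligned}$$

So the $(k,l)$ entry of the gradient is $a_{lk}$, and collecting these $n\times m$ partial derivatives into a matrix gives $A^T$.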