Derivative of trace functions using chain rule


Let us consider $\text{trace}(f(AA^\top))$, where $f$ is some smooth function and $A \in \mathbb{R}^{n \times m}$. Here, $f$ is a matrix function (cf. Higham's books). If $f(x) = x$, then $\text{trace}(f(AA^\top)) = \|A\|_F^2$ (squared Frobenius norm), and if $f(x) = x^{p/2}$, then $\text{trace}(f(AA^\top)) = \|A\|_{S_p}^p$, the $p$-th power of the Schatten-$p$ norm.
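As a quick numerical sanity check of these two identities, here is a small sketch in Python with NumPy. It evaluates $\text{trace}(f(AA^\top))$ through the eigenvalues of the symmetric matrix $AA^\top$; the helper name `trace_of_matrix_function`, the sizes, the seed and the choice $p = 3$ are illustrative assumptions, not part of the question.

```python
import numpy as np


def trace_of_matrix_function(A, f):
    """Compute tr(f(A A^T)) through the eigenvalues of the symmetric matrix A A^T."""
    M = A @ A.T
    eigvals = np.linalg.eigvalsh(M)        # real eigenvalues of the symmetric PSD matrix
    eigvals = np.clip(eigvals, 0.0, None)  # clip tiny negative round-off before fractional powers
    return float(np.sum(f(eigvals)))


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))  # n = 3, m = 5 (arbitrary test sizes)

# f(x) = x gives the squared Frobenius norm.
frob_sq = trace_of_matrix_function(A, lambda x: x)

# f(x) = x**(p/2) gives the p-th power of the Schatten-p norm (sum of sigma_i**p).
p = 3.0
schatten_p = trace_of_matrix_function(A, lambda x: x ** (p / 2))
sigma = np.linalg.svd(A, compute_uv=False)

print(np.isclose(frob_sq, np.linalg.norm(A, "fro") ** 2))  # True
print(np.isclose(schatten_p, np.sum(sigma ** p)))          # True
```

Both checks rely on the eigenvalues of $AA^\top$ being the squared singular values of $A$.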

I would like to differentiate $\text{trace}(f(AA^\top))$ with respect to the matrix $A$. Let $g: X\in \mathbb{R}^{n \times n} \rightarrow \text{trace}(f(X)) \in \mathbb{R}$ and $h: A \in \mathbb{R}^{n \times m} \rightarrow A A^\top \in \mathbb{R}^{n \times n}$. By the chain rule, it holds that $$ \frac{\partial g(h(A))}{\partial A} = \frac{\partial g(h(A))}{\partial h} \frac{\partial h(A)}{\partial A}. $$ To the best of my knowledge, $\frac{\partial g(h(A))}{\partial h} \in \mathbb{R}^{n \times n}$ and $\frac{\partial h(A)}{\partial A} \in \mathbb{R}^{n\times n \times n \times m}$, so I am confused about whether these two terms can be multiplied at all. However, when $p=2$ (i.e. $f(x) = x$), this is the squared Frobenius norm, whose derivative is known to be $2A$.
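One way to see that the two factors can be combined is to read the "product" as a contraction over the two $n$-indices of $h$. The NumPy sketch below builds the $n \times n \times n \times m$ Jacobian of $h(A) = AA^\top$ with the (assumed) index convention $J_{ijkl} = \partial (AA^\top)_{ij} / \partial A_{kl} = \delta_{ik}A_{jl} + A_{il}\delta_{jk}$, and contracts it with $\partial g/\partial X = I$, which is the gradient of $\mathrm{tr}(X)$ for $f(x) = x$. Sizes and seed are arbitrary.

```python
import numpy as np

n, m = 3, 4
rng = np.random.default_rng(1)
A = rng.standard_normal((n, m))
I = np.eye(n)

# Jacobian of h(A) = A A^T as an (n, n, n, m) array:
# J[i, j, k, l] = d(A A^T)[i, j] / dA[k, l] = delta_ik A[j, l] + A[i, l] delta_jk
J = np.einsum("ik,jl->ijkl", I, A) + np.einsum("il,jk->ijkl", A, I)

# Gradient of g(X) = tr(X) (i.e. f(x) = x) with respect to X is the identity.
G = I

# Chain rule: contract the (n, n) gradient with the first two axes of the Jacobian.
grad_A = np.tensordot(G, J, axes=([0, 1], [0, 1]))  # shape (n, m)

print(grad_A.shape)                # (3, 4)
print(np.allclose(grad_A, 2 * A))  # True
```

The contraction recovers $2A$, so in this case the chain rule works once the product is interpreted as a sum over the indices of $h$, i.e. a `tensordot` over the first two axes of the Jacobian.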

Any comments or advice would be highly appreciated! Thank you.




Let $\mathrm{Mat}(n,\mathbb{R})$ be the vector space of real $n \times n$ matrices. With $f \, : \, \mathrm{Mat}(n,\mathbb{R}) \, \rightarrow \, \mathrm{Mat}(n,\mathbb{R})$ a function which is differentiable on $\mathrm{Mat}(n,\mathbb{R})$, let $g \, : \, X \in \mathrm{Mat}(n,\mathbb{R}) \, \mapsto \, \mathrm{tr}\big( f(X) \big)$.

  1. Because $\mathrm{tr}$ is linear, its differential at any point is $\mathrm{tr}$ itself, so the differential of $g$ at $X$ is the composition of $\mathrm{tr}$ and the differential of $f$ at $X$: for $H \in \mathrm{Mat}(n,\mathbb{R})$, it follows from the chain rule that:

$$ D_{X}g \cdot H = \mathrm{D}_{f(X)} \mathrm{tr} \cdot \big( D_{X}f \cdot H \big) = \mathrm{tr}\big( D_{X}f \cdot H \big). $$

As a result: $D_{X}g = \mathrm{tr} \circ \mathrm{D}_{X}f$.
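The identity $D_{X}g = \mathrm{tr} \circ \mathrm{D}_{X}f$ can be checked numerically. The sketch below assumes the hypothetical example $f(X) = X^3$, whose differential by the product rule is $D_X f \cdot H = HX^2 + XHX + X^2H$, and compares $\mathrm{tr}(D_X f \cdot H)$ with a central finite difference of $g$ along a random direction $H$ (the step `t` and the sizes are arbitrary choices).

```python
import numpy as np


def f(X):
    # Hypothetical example matrix function: X -> X^3.
    return X @ X @ X


def Df(X, H):
    # Differential of X -> X^3 in direction H (product rule).
    return H @ X @ X + X @ H @ X + X @ X @ H


def g(X):
    return np.trace(f(X))


rng = np.random.default_rng(2)
X = rng.standard_normal((4, 4))
H = rng.standard_normal((4, 4))

t = 1e-6
fd = (g(X + t * H) - g(X - t * H)) / (2 * t)  # central difference of g along H
chain = np.trace(Df(X, H))                    # tr(D_X f . H)

print(np.isclose(fd, chain, rtol=1e-5, atol=1e-6))  # True
```

By cyclicity of the trace, $\mathrm{tr}(D_X f \cdot H)$ here also equals $3\,\mathrm{tr}(X^2 H)$.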

  2. Now take $f \, : \, X \in \mathrm{Mat}(n,\mathbb{R}) \, \mapsto \, X X^{\top} \in \mathrm{Mat}(n,\mathbb{R})$, so that $g(X) = \mathrm{tr}(XX^{\top}) = \|X\|_F^2$. The differential of $f$ at $X$ is computed as follows: for $H \in \mathrm{Mat}(n,\mathbb{R})$,

$$ f(X+H) = (X+H)(X+H)^{\top} = XX^{\top} + XH^{\top} + HX^{\top} + HH^{\top}. $$

Since $HH^{\top}$ is of second order in $H$, it follows from the definition of the differential that:

$$ D_{X}f \cdot H = XH^{\top} + HX^{\top}. $$
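A small check of this differential, as a sketch in Python with NumPy (random $X$ and $H$ are arbitrary test data): the remainder $f(X+tH) - f(X) - D_X f \cdot (tH)$ should be exactly the second-order term $(tH)(tH)^{\top}$, which shrinks like $t^2$.

```python
import numpy as np


def f(X):
    return X @ X.T


def Df(X, H):
    # Linear part of f(X + H) - f(X): X H^T + H X^T
    return X @ H.T + H @ X.T


rng = np.random.default_rng(3)
X = rng.standard_normal((4, 4))
H = rng.standard_normal((4, 4))

for t in (1e-1, 1e-2, 1e-3):
    remainder = f(X + t * H) - f(X) - Df(X, t * H)
    # The remainder is exactly (tH)(tH)^T, so it is quadratic in t.
    print(t, np.allclose(remainder, (t * H) @ (t * H).T))
```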

Finally:

$$ D_{X}g \cdot H = \mathrm{tr}\big( XH^{\top} + HX^{\top} \big) = 2\mathrm{tr}(XH^{\top}). $$

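This also gives you the gradient of $g$ at $X$: since $\mathrm{tr}(XH^{\top}) = \langle X, H \rangle_F$ is the Frobenius inner product of $X$ and $H$, the gradient is $\nabla_{X}g = 2X$.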
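The gradient $2X$ can also be confirmed entry by entry with central finite differences; a minimal sketch, where the helper `numerical_gradient` and the step size `eps` are illustrative choices:

```python
import numpy as np


def g(X):
    return np.trace(X @ X.T)  # equals the squared Frobenius norm of X


def numerical_gradient(func, X, eps=1e-6):
    # Central differences, one matrix entry at a time.
    grad = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        grad[idx] = (func(X + E) - func(X - E)) / (2 * eps)
    return grad


rng = np.random.default_rng(4)
X = rng.standard_normal((3, 3))
print(np.allclose(numerical_gradient(g, X), 2 * X, atol=1e-6))  # True
```

Because $g$ is quadratic, the central difference is exact up to floating-point round-off.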


Let's denote the derivative of the scalar function with a prime:
$$ f'(x) = \frac{df(x)}{dx}. $$
The differential of the trace of the function applied to a matrix argument is
$$ \begin{aligned} \lambda &= {\rm tr}\big(f(M)\big), \\ d\lambda &= f'(M)^T:dM, \end{aligned} $$
where the colon denotes the double-contraction (Frobenius) product, i.e. $\,A:B={\rm tr}(A^TB)$.
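For a polynomial $f$ this differential formula holds even when $M$ is not symmetric. Here is a sketch of a numerical check, assuming the hypothetical polynomial $f(x) = x^3 + 2x$ (so $f'(x) = 3x^2 + 2$, and both can be evaluated at a matrix with plain matrix products). The helper `colon` implements $A:B = {\rm tr}(A^TB)$ as a sum of entrywise products.

```python
import numpy as np


def f_mat(M):
    # Hypothetical example f(x) = x**3 + 2x applied to a square matrix.
    return M @ M @ M + 2 * M


def fprime_mat(M):
    # f'(x) = 3x**2 + 2 applied to the same matrix.
    return 3 * M @ M + 2 * np.eye(M.shape[0])


def colon(A, B):
    # Double-contraction product A : B = tr(A^T B) = sum of entrywise products.
    return np.sum(A * B)


rng = np.random.default_rng(5)
M = rng.standard_normal((4, 4))  # deliberately non-symmetric
dM = rng.standard_normal((4, 4))

t = 1e-6
fd = (np.trace(f_mat(M + t * dM)) - np.trace(f_mat(M - t * dM))) / (2 * t)
formula = colon(fprime_mat(M).T, dM)

print(np.isclose(fd, formula, rtol=1e-5, atol=1e-6))  # True
```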

In the case that $\,M=AA^T$, note that $M$ is symmetric, so $f'(M)^T=f'(M)$. Using the differential above, together with the identities $X:Y = X^T:Y^T$ and $X:(YZ) = (XZ^T):Y$, we can find the gradient with respect to $A$:
$$ \begin{aligned} d\lambda &= f'(M):(dA\,A^T+A\,dA^T) \\ &= f'(M):dA\,A^T + f'(M):A\,dA^T \\ &= f'(M):dA\,A^T + f'(M)^T:dA\,A^T \\ &= 2f'(M):dA\,A^T \\ &= 2f'(AA^T)A:dA, \\[4pt] \frac{\partial\lambda}{\partial A} &= 2f'(AA^T)A. \end{aligned} $$
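As a final sanity check, the sketch below compares $2f'(AA^T)A$ with a finite-difference gradient for the Schatten case $f(x) = x^{p/2}$ from the question, where the formula specializes to $\partial \|A\|_{S_p}^p / \partial A = p\,(AA^T)^{p/2-1}A$. The matrix functions are applied through an eigendecomposition of the symmetric matrix (the helper name `apply_symmetric` is illustrative), and $n \le m$ is assumed so that $AA^T$ is positive definite and $x^{p/2-1}$ is well defined for the arbitrary choice $p = 3$.

```python
import numpy as np


def apply_symmetric(M, func):
    # Apply a scalar function to a symmetric matrix: V diag(func(w)) V^T.
    w, V = np.linalg.eigh(M)
    return (V * func(w)) @ V.T


p = 3.0


def f(x):
    return x ** (p / 2)  # tr(f(A A^T)) = ||A||_{S_p}^p


def fprime(x):
    return (p / 2) * x ** (p / 2 - 1)


def lam(A):
    return np.trace(apply_symmetric(A @ A.T, f))


def grad_formula(A):
    # dlambda/dA = 2 f'(A A^T) A
    return 2 * apply_symmetric(A @ A.T, fprime) @ A


rng = np.random.default_rng(6)
A = rng.standard_normal((3, 5))  # n <= m, so A A^T is positive definite almost surely

eps = 1e-6
num = np.zeros_like(A)
for idx in np.ndindex(A.shape):
    E = np.zeros_like(A)
    E[idx] = eps
    num[idx] = (lam(A + E) - lam(A - E)) / (2 * eps)

print(np.allclose(num, grad_formula(A), rtol=1e-5, atol=1e-6))  # True
```

With $f(x) = x$ the same formula reduces to $2A$, matching the Frobenius case.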