Gram matrix differential


I have the matrix-valued function $M=AXX^\top A^\top$, and I am looking for an expression for its derivative with respect to each element of $X$. Here $A$ is $n\times n$ and $X$ is $n\times m$ with $m\leq n$.

So I am looking for something of the form $\frac{\partial M}{\partial X_{ij}} = D^{ij}$ with $D^{ij}_{kl} = \frac{\partial M_{kl}}{\partial X_{ij}}$.

Digging through the Matrix Cookbook, I was not able to find this or any more general expressions that seemed like they would be helpful.

The closest I could find was the identity $$\frac{\partial}{\partial X} b^\top X^\top X c = X (bc^\top + cb^\top),$$ where $b$ and $c$ are vectors. I feel like this may contain the solution, but I am unsure how to generalize it.
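As a quick sanity check on that Cookbook identity (this check is mine, not part of the question), one can compare the closed form $X(bc^\top + cb^\top)$ against an elementwise finite-difference gradient of $f(X) = b^\top X^\top X c$:

```python
import numpy as np

# Verify d/dX (b^T X^T X c) = X (b c^T + c b^T) by finite differences.
# X is n x m, so b and c live in R^m and the gradient is n x m.
rng = np.random.default_rng(3)
n, m = 4, 3
X = rng.standard_normal((n, m))
b = rng.standard_normal(m)
c = rng.standard_normal(m)

def f(X):
    return b @ X.T @ X @ c

grad = X @ (np.outer(b, c) + np.outer(c, b))  # closed-form gradient

# Forward-difference approximation, one entry of X at a time.
h = 1e-6
num = np.zeros_like(X)
for i in range(n):
    for j in range(m):
        Xp = X.copy()
        Xp[i, j] += h
        num[i, j] = (f(Xp) - f(X)) / h

print(np.max(np.abs(grad - num)))  # O(h), i.e. tiny
```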

Accepted answer

We can compute the derivative $\frac{dM}{dX}$ as a directional derivative: \begin{align*} \frac{dM}{dX}(Y) &= \lim_{t \to 0} \frac{M(X + tY) - M(X)}{t} = \lim_{t \to 0} \frac{A(X+tY) (X+tY)^\top A^\top - AXX^\top A^\top}{t} \, . \end{align*} Expanding the numerator, we find \begin{align*} A(X+tY) (X^\top+tY^\top) A^\top - AXX^\top A^\top &= A (X X^\top + tY X^\top + tXY^\top + t^2 Y Y^\top) A^\top - AXX^\top A^\top\\ &= tA(Y X^\top + XY^\top)A^\top + t^2 A Y Y^\top A^\top \, . \end{align*} Thus \begin{align*} \lim_{t \to 0} \frac{A(X+tY) (X+tY)^\top A^\top - AXX^\top A^\top}{t} &= \lim_{t \to 0} \left[ A(Y X^\top + XY^\top)A^\top + t A Y Y^\top A^\top \right]\\ &= A(Y X^\top + XY^\top)A^\top \, . \end{align*}
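This directional derivative is easy to check numerically (a check I have added, not part of the answer): for a random direction $Y$, the finite-difference quotient $\bigl(M(X+tY)-M(X)\bigr)/t$ should agree with $A(YX^\top + XY^\top)A^\top$ up to an $O(t)$ error coming from the $tAYY^\top A^\top$ term:

```python
import numpy as np

# Compare the closed-form directional derivative A (Y X^T + X Y^T) A^T
# against a finite-difference quotient of M(X) = A X X^T A^T.
rng = np.random.default_rng(0)
n, m = 5, 3
A = rng.standard_normal((n, n))
X = rng.standard_normal((n, m))
Y = rng.standard_normal((n, m))  # arbitrary direction

def M(X):
    return A @ X @ X.T @ A.T

closed_form = A @ (Y @ X.T + X @ Y.T) @ A.T

t = 1e-6
finite_diff = (M(X + t * Y) - M(X)) / t

print(np.max(np.abs(closed_form - finite_diff)))  # O(t), i.e. tiny
```

Since $M$ is quadratic in $X$, the leftover error is exactly $t\,AYY^\top A^\top$, so it shrinks linearly as $t \to 0$.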

As mentioned in the comments, the partial $\frac{\partial M}{\partial X_{ij}}$ can then be found by substituting $Y = E_{ij}$, the matrix with a $1$ in the $i,j$ entry and zeroes elsewhere. So \begin{align*} \frac{\partial M}{\partial X_{ij}} &= \frac{dM}{dX}(E_{ij}) = A(E_{ij} X^\top + XE_{ij}^\top)A^\top = A(E_{ij} X^\top + X E_{ji})A^\top \, . \end{align*}
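The entrywise formula can be verified the same way (again my own check, with arbitrary test dimensions): perturb a single entry $X_{ij}$ and compare against $A(E_{ij}X^\top + XE_{ji})A^\top$:

```python
import numpy as np

# Verify dM/dX_ij = A (E_ij X^T + X E_ji) A^T by perturbing one entry of X.
rng = np.random.default_rng(1)
n, m = 4, 2
A = rng.standard_normal((n, n))
X = rng.standard_normal((n, m))

def M(X):
    return A @ X @ X.T @ A.T

i, j = 2, 1
E = np.zeros((n, m))  # E_ij: 1 in the (i, j) entry, zeros elsewhere
E[i, j] = 1.0

closed_form = A @ (E @ X.T + X @ E.T) @ A.T  # E.T is E_ji

h = 1e-6
Xp = X.copy()
Xp[i, j] += h
finite_diff = (M(Xp) - M(X)) / h

print(np.max(np.abs(closed_form - finite_diff)))  # O(h), i.e. tiny
```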

Second answer

$ \def\p{\partial} \def\qiq{\quad\implies\quad} \def\grad#1#2{\frac{\p #1}{\p #2}} \def\c#1{\color{red}{#1}} $As hinted in the comments, the gradient of a matrix with respect to its own components is $$\eqalign{ \grad{X}{X_{ij}} = \c{E_{ij}} \\ }$$ Applied to your $M$ function this yields $$\eqalign{ M &= AXX^TA^T \\ \grad{M}{X_{ij}} &= A\c{E_{ij}}X^TA^T + AX\c{E_{ij}}^TA^T \\\\ }$$


The $\{\c{E_{ij}}\}$ form the standard basis for $m\times n$ matrices, just as the $\{\c{e_k}\}$ form the standard basis for vectors. Such bases allow you to write any matrix (or vector) as a summation over its scalar components, e.g. $$ B = \sum_{i=1}^m\sum_{j=1}^n B_{ij}\c{E_{ij}} \quad\iff\quad v = \sum_{k=1}^n v_{k}\c{e_{k}} $$ for an arbitrary $m\times n$ matrix $B$. Such expansions lead immediately to the component-wise gradient at the top of this post.