How to efficiently calculate the gradient of the diagonal of matrix multiplication

I have a diagonal matrix $\Gamma = \text{diag}(K_{NM}K_M^{-1}K_{MN})$, where

  • $K_{NM}$ is an $N\times M$ matrix,
  • $K_{MN} = K_{NM}^T$,
  • $K_M$ is an $M \times M$ symmetric positive definite (SPD) matrix, and so is its inverse $K_M^{-1}$.

The matrices $K_{NM}$ and $K_M$ are functions of $\theta$. I want to calculate the gradient of $\Gamma$ with respect to $\theta$:

$$ \frac{\partial \Gamma}{\partial \theta} = \text{diag}\left(K_{NM} K_{M}^{-1} \frac{\partial K_{MN}}{\partial \theta} + \frac{\partial K_{NM}}{\partial \theta} K_{M}^{-1} K_{MN} - K_{NM}K_M^{-1}\frac{\partial K_M}{\partial \theta} K_M^{-1}K_{MN}\right) $$

The matrices $K_{NM} K_M^{-1}$ and $K_M^{-1}$ have already been precomputed. Is there an efficient way to calculate $\displaystyle \frac{\partial \Gamma}{\partial \theta}$ in $O(MN)$ complexity?

There is 1 best solution below

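Let $\Gamma=\text{diag}(XY^{-1}X^T)$, where $X = K_{NM}$ and $Y = K_M$ is symmetric, and let a prime denote the derivative with respect to $\theta$. Then $\Gamma'=\text{diag}(2(XY^{-1})X'^T-(XY^{-1})Y'(XY^{-1})^T)$.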
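
The intermediate steps are not spelled out here, so the following is a reconstruction of how the formula can be obtained, using only the product rule, the standard identity $(Y^{-1})' = -Y^{-1}Y'Y^{-1}$, and the symmetry of $Y$ (which gives $Y^{-1}X^T = (XY^{-1})^T$):

$$ (XY^{-1}X^T)' = X'Y^{-1}X^T + X(Y^{-1})'X^T + XY^{-1}X'^T = X'(XY^{-1})^T + (XY^{-1})X'^T - (XY^{-1})Y'(XY^{-1})^T $$

The first two terms are transposes of each other, so they have the same diagonal:

$$ \text{diag}\left(X'(XY^{-1})^T + (XY^{-1})X'^T\right) = 2\,\text{diag}\left((XY^{-1})X'^T\right) $$

which accounts for the factor $2$.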

The matrix $U=XY^{-1}\in M_{N,M}$ is already known, since it has been precomputed.

Then, for every $i\leq N$, computing $\Gamma_i'=\sum_j 2U_{i,j}X_{i,j}'-\sum_{j,k}U_{i,j}Y_{j,k}'U_{i,k}$ costs $O(M^2)$ operations: the first sum is $O(M)$ and the double sum is $O(M^2)$. Computing all $N$ entries of $\Gamma'$ therefore has complexity $O(M^2N)$.
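
The entrywise formula can be sketched in NumPy as below. This is a minimal illustration, not a tuned implementation: the function name `diag_gradient` and the argument names `U`, `dX`, `dY` are made up here, and it assumes $\theta$ is a scalar, that `dX` and `dY` hold $X'$ and $Y'$, and that `dY` is symmetric.

```python
import numpy as np

def diag_gradient(U, dX, dY):
    """Derivative of diag(X Y^{-1} X^T) with respect to a scalar theta.

    U  : (N, M) array, the precomputed product X Y^{-1}
    dX : (N, M) array, derivative of X
    dY : (M, M) array, derivative of Y (assumed symmetric)
    Returns an (N,) array holding the diagonal of the derivative.
    """
    # First term: 2 * sum_j U[i, j] * dX[i, j], a row-wise dot product, O(NM).
    first = 2.0 * np.sum(U * dX, axis=1)
    # Second term: sum_{j,k} U[i, j] * dY[j, k] * U[i, k].
    # The product U @ dY costs O(N M^2); the row-wise dot with U is O(NM).
    second = np.sum((U @ dY) * U, axis=1)
    return first - second
```

Only the diagonal is ever formed, so no $N \times N$ matrix is built. The product `U @ dY` is the step that dominates the cost, which matches the $O(M^2N)$ count above.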