Applying the chain rule in matrix form to prove Loss function's derivatives...


I want to prove $\nabla_A J = \nabla_Z J \cdot B^T$ where $Z=AB$. $A$ is an $m \times n$ matrix and $B$ is an $n \times k$ matrix. The function $J$ is not given to me.

I began this proof by first writing $B^T$ as

$$ B^T = \begin{pmatrix} b_{00} & b_{10} & b_{20} \\ b_{01} & b_{11} & b_{21} \\ b_{02} & b_{12} & b_{22} \\ \end{pmatrix} $$

Since $J$ is a loss function, it is scalar-valued. Computing its derivative with respect to $Z$, i.e. $\nabla_Z J$, I write:

$$ \nabla_{Z} J = \begin{pmatrix} \frac{\partial J}{\partial z_{00}} & \frac{\partial J}{\partial z_{01}} & \frac{\partial J}{\partial z_{02}} \\ \frac{\partial J}{\partial z_{10}} & \frac{\partial J}{\partial z_{11}} & \frac{\partial J}{\partial z_{12}} \\ \frac{\partial J}{\partial z_{20}} & \frac{\partial J}{\partial z_{21}} & \frac{\partial J}{\partial z_{22}} \\ \end{pmatrix} $$

Multiplying these two matrices (as $\nabla_Z J \cdot B^T$), I get:

$$ \begin{pmatrix}
\frac{\partial J}{\partial z_{00}}b_{00} + \frac{\partial J}{\partial z_{01}}b_{01} + \frac{\partial J}{\partial z_{02}}b_{02} & \frac{\partial J}{\partial z_{00}}b_{10} + \frac{\partial J}{\partial z_{01}}b_{11} + \frac{\partial J}{\partial z_{02}}b_{12} & \frac{\partial J}{\partial z_{00}}b_{20} + \frac{\partial J}{\partial z_{01}}b_{21} + \frac{\partial J}{\partial z_{02}}b_{22} \\
\frac{\partial J}{\partial z_{10}}b_{00} + \frac{\partial J}{\partial z_{11}}b_{01} + \frac{\partial J}{\partial z_{12}}b_{02} & \frac{\partial J}{\partial z_{10}}b_{10} + \frac{\partial J}{\partial z_{11}}b_{11} + \frac{\partial J}{\partial z_{12}}b_{12} & \frac{\partial J}{\partial z_{10}}b_{20} + \frac{\partial J}{\partial z_{11}}b_{21} + \frac{\partial J}{\partial z_{12}}b_{22} \\
\frac{\partial J}{\partial z_{20}}b_{00} + \frac{\partial J}{\partial z_{21}}b_{01} + \frac{\partial J}{\partial z_{22}}b_{02} & \frac{\partial J}{\partial z_{20}}b_{10} + \frac{\partial J}{\partial z_{21}}b_{11} + \frac{\partial J}{\partial z_{22}}b_{12} & \frac{\partial J}{\partial z_{20}}b_{20} + \frac{\partial J}{\partial z_{21}}b_{21} + \frac{\partial J}{\partial z_{22}}b_{22} \\
\end{pmatrix} $$

At this point, I am stuck and not sure how to proceed. How do I take the next step? Note that I am not well versed in matrix math, and it has been a while since I dealt with proofs. I am eager to learn, and that's why I'm asking for the next step, assuming I did it correctly so far!

Best answer

$\def\p#1#2{\frac{\partial #1}{\partial #2}}$For typing convenience, give short names to the various matrix gradients
$$\eqalign{ \nabla_ZJ &= \p{J}{Z} &\doteq G \quad&\iff\quad &dJ = G:dZ \\ \nabla_AJ &= \p{J}{A} &\doteq H \quad&\iff\quad &dJ = H:dA \\ }$$
The differential of $Z$ in terms of $A$ (with $B$ held constant)
$$\eqalign{ Z &= AB \qquad\implies\qquad dZ = dA\;B \\ }$$
can be used to effect a change of variable from $Z\to A$
$$\eqalign{ dJ &= G:dZ \\ &= G:(dA\,B) \\ &= (GB^T):dA \\ &= H:dA &\implies\quad H = GB^T \\ }$$
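As a sanity check, the identity $H = GB^T$ can be tested numerically. The sketch below is only an illustration: since $J$ is not given, it assumes the hypothetical loss $J = \tfrac12\sum_{ij} z_{ij}^2$ (for which $G = Z$), and the helper names `loss`, `grad_A_formula` and `grad_A_numeric` are made up for this example. It compares $GB^T$ against central finite differences of $J$ with respect to each entry of $A$.

```python
# Numerical check of H = G B^T for Z = A B.
# ASSUMPTION: J(Z) = 0.5 * sum(z_ij^2), chosen only because dJ/dZ = Z is easy
# to write down. The original question does not specify J.

def matmul(X, Y):
    """Product of two matrices stored as lists of rows."""
    return [[sum(X[i][r] * Y[r][j] for r in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    """Transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*X)]

def loss(A, B):
    """Hypothetical loss J = 0.5 * sum of squared entries of Z = A B."""
    Z = matmul(A, B)
    return 0.5 * sum(z * z for row in Z for z in row)

def grad_A_formula(A, B):
    """Gradient from the identity H = G B^T, where G = dJ/dZ = Z for this loss."""
    G = matmul(A, B)
    return matmul(G, transpose(B))

def grad_A_numeric(A, B, eps=1e-6):
    """Central finite differences of J with respect to each entry of A."""
    H = [[0.0] * len(A[0]) for _ in A]
    for p in range(len(A)):
        for q in range(len(A[0])):
            old = A[p][q]
            A[p][q] = old + eps
            j_plus = loss(A, B)
            A[p][q] = old - eps
            j_minus = loss(A, B)
            A[p][q] = old  # restore the entry before moving on
            H[p][q] = (j_plus - j_minus) / (2 * eps)
    return H

# A is m x n = 2 x 3, B is n x k = 3 x 4 (arbitrary test values).
A = [[1.0, 2.0, -1.0],
     [0.5, 0.0, 3.0]]
B = [[2.0, 1.0, 0.0, -1.0],
     [0.0, 1.0, 1.0, 2.0],
     [1.0, -2.0, 0.5, 0.0]]

H_formula = grad_A_formula(A, B)
H_numeric = grad_A_numeric(A, B)
max_err = max(abs(f - n)
              for row_f, row_n in zip(H_formula, H_numeric)
              for f, n in zip(row_f, row_n))
print("largest entrywise difference:", max_err)
```

The printed difference should be close to zero (limited only by floating-point rounding), and `H_formula` has the same $m \times n$ shape as $A$, as a gradient with respect to $A$ must.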


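In the derivation above, a colon denotes the Frobenius inner product, a convenient product notation for the trace function, i.e. $$A:B = {\rm Tr}(A^TB)$$ Here $A$, $B$, $C$ stand for arbitrary matrices of compatible sizes, not the specific $A$ and $B$ of the question. The properties of the trace function allow the terms in a colon product to be rearranged in a number of equivalent ways, e.g. $$\eqalign{ A:B &= B:A = B^T:A^T \\ CA:B &= A:C^TB = C:BA^T \\ }$$ The step $G:(dA\,B) = (GB^T):dA$ uses the symmetry $A:B = B:A$ together with the last rule, $CA:B = C:BA^T$.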
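
The same identity can also be reached entry by entry, which is where the matrices written out in the question were heading. A sketch, with the index names $p, q, i, j, r$ introduced here only for illustration: each entry of $Z = AB$ is
$$ z_{ij} = \sum_{r} a_{ir}\,b_{rj} \qquad\implies\qquad \frac{\partial z_{ij}}{\partial a_{pq}} = \delta_{ip}\,b_{qj} $$
where $\delta_{ip}$ is $1$ if $i=p$ and $0$ otherwise. The scalar chain rule then gives
$$ \frac{\partial J}{\partial a_{pq}} = \sum_{i,j} \frac{\partial J}{\partial z_{ij}}\,\frac{\partial z_{ij}}{\partial a_{pq}} = \sum_{j} \frac{\partial J}{\partial z_{pj}}\,b_{qj} = \sum_{j} \left(\nabla_Z J\right)_{pj}\left(B^T\right)_{jq} = \left(\nabla_Z J \; B^T\right)_{pq} $$
For example, the top-left entry of the product in the question, $\frac{\partial J}{\partial z_{00}}b_{00} + \frac{\partial J}{\partial z_{01}}b_{01} + \frac{\partial J}{\partial z_{02}}b_{02}$, is exactly $\frac{\partial J}{\partial a_{00}}$, because $a_{00}$ enters only $z_{00}, z_{01}, z_{02}$, with coefficients $b_{00}, b_{01}, b_{02}$. So the next step in the question is to recognise each entry of that product as the chain-rule expression for the corresponding $\frac{\partial J}{\partial a_{pq}}$.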