Understanding a case of multi-variate chain rule for matrix-vector multiplication


I am trying to understand an example for computing derivatives using matrix calculus in the context of the back-propagation algorithm used in deep learning. Suppose I have a real-valued $K \times K$ matrix $A$ and a vector $x$ of dimension $K$. Let $w = Ax$ and $L = \|w\|^2$, and consider $L$ as a function of $A$ and $x$.

Then the multi-variate chain rule says that $$\frac{ \partial L }{\partial A} = \frac{ \partial L }{\partial w} \frac { \partial w }{\partial A}$$

Now I am trying to figure out what the shape of the matrices in this equation should be. I know $$\frac{\partial L} {\partial w} = [2 w_1, \ldots, 2 w_K]$$ i.e. a tensor of shape $(1, K)$.
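(Indeed, writing $L = \sum_{k=1}^K w_k^2$ gives $\frac{\partial L}{\partial w_k} = 2 w_k$ for each component, which is exactly that row vector.)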

Then the shape of $\frac{\partial w} {\partial A}$ is confusing me. Since $w$ is a $K$-dimensional vector and $A$ is of size $K \times K$, it seems that if you flatten $A$ to a vector (in either row-major or column-major order) and index its entries as $A_i$ with $1 \leq i \leq K^2$, you could treat $\frac{\partial w} {\partial A}$ as a matrix of size $K \times K^2$, but I feel this is not the convention. What is the shape of $\frac{\partial w} {\partial A}$?

Also $L$ is a scalar-valued function, so I think $\frac{\partial L} {\partial A}$ should be a vector of size $K^2$, or a matrix of size $K \times K$, but I am not sure about this either. Any insights appreciated.
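For concreteness, here is a minimal NumPy sketch of the setup (my own illustration, not part of the original question). It builds the flattened $K \times K^2$ Jacobian described above under a row-major flattening of $A$, using the fact that $w = Ax$ can be written as $(I_K \otimes x^T)\,\mathrm{vec}_{\text{row}}(A)$:

```python
import numpy as np

K = 3
rng = np.random.default_rng(0)
A = rng.standard_normal((K, K))
x = rng.standard_normal(K)

w = A @ x        # shape (K,)
L = w @ w        # scalar, L = ||Ax||^2

# Flattened view of dw/dA: treat A as a length-K^2 vector (row-major here),
# so the Jacobian has shape (K, K^2). For w = A x it is the Kronecker
# product I_K (x) x^T.
J = np.kron(np.eye(K), x)                    # shape (K, K^2)
assert J.shape == (K, K * K)
assert np.allclose(J @ A.reshape(-1), w)     # J * vec_row(A) == A x
```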

2 Answers

Accepted answer:

The trace/Frobenius product is $$\eqalign{ A:B &= \sum_{i=1}^m\sum_{j=1}^n A_{ij}B_{ij} \;=\; {\rm Tr}(A^TB) \\ A:A &= \big\|A\big\|_F^2 \;=\; {\rm Tr}(A^TA) \\ }$$

Write the cost function using this notation, then calculate its differential and gradient: $$\eqalign{ L &= \|w\|^2 = w:w \\ dL &= 2w:dw \;=\; 2w:(dA\,x) \;=\; 2wx^T:dA \\ \frac{ \partial L }{\partial A} &= 2wx^T \;=\; 2Axx^T \\ }$$

The differential approach obviates the need to calculate higher-order tensors, such as matrix-by-vector, vector-by-matrix, or matrix-by-matrix gradients.

It also avoids common pitfalls involving transposed matrices or vectors.
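As a quick numerical sanity check of that result (my own addition; the finite-difference loop and NumPy usage are not from the answer), the closed form $2Axx^T$ agrees with a central-difference gradient entry by entry:

```python
import numpy as np

K = 3
rng = np.random.default_rng(1)
A = rng.standard_normal((K, K))
x = rng.standard_normal(K)

def loss(A):
    w = A @ x
    return w @ w                  # L = ||Ax||^2

# Closed-form gradient from the answer: dL/dA = 2 A x x^T = 2 w x^T
grad_closed = 2 * np.outer(A @ x, x)

# Central finite differences, one entry of A at a time
eps = 1e-6
grad_fd = np.zeros_like(A)
for i in range(K):
    for j in range(K):
        E = np.zeros_like(A)
        E[i, j] = eps
        grad_fd[i, j] = (loss(A + E) - loss(A - E)) / (2 * eps)

assert np.allclose(grad_closed, grad_fd, atol=1e-5)
```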

Second answer:

Assume, for simplicity, that $K=2$. Then

$$ \partial_A L = \left( \begin{array}{cc} \partial_{a_{11}} L & \partial_{a_{12}} L \\ \partial_{a_{21}} L & \partial_{a_{22}} L \end{array} \right). $$

If $L$ is a function of the vector ${w}(a_{ij})$, then

$$ \partial_{a_{ij}} L({w}(a_{ij})) = \partial_{w_1}L \; \partial_{a_{ij}} w_1 + \partial_{w_2}L \; \partial_{a_{ij}} w_2. $$

In matrix notation, this reads

$$ \partial_A L = \left( \begin{array}{cc} \partial_{w_1} L \; I_{2 \times 2} & \partial_{w_2} L \; I_{2 \times 2} \end{array} \right) \left( \begin{array}{c} \partial_A w_1 \\ \partial_A w_2 \end{array} \right). $$

Hence, stacking the $K$ blocks $\partial_A w_i$ (each of size $K \times K$) on top of each other, the general shape of $\partial_A w$ is $K^2 \times K$.
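To make the bookkeeping concrete, here is a small NumPy sketch (my own, not part of the answer) that forms each block $\partial_A w_k = e_k x^T$, stacks them into the $K^2 \times K$ array described above, and checks that $\sum_k \partial_{w_k} L \; \partial_A w_k$ reproduces $2Axx^T$:

```python
import numpy as np

K = 3
rng = np.random.default_rng(2)
A = rng.standard_normal((K, K))
x = rng.standard_normal(K)
w = A @ x

# dw_k/dA is the K x K matrix e_k x^T (only row k of A touches w_k).
dw_dA = np.stack([np.outer(np.eye(K)[k], x) for k in range(K)])  # (K, K, K)

# Stacked vertically as in the answer: shape (K^2, K)
stacked = dw_dA.reshape(K * K, K)
assert stacked.shape == (K * K, K)

# Chain rule: dL/dA = sum_k (dL/dw_k) * (dw_k/dA), with dL/dw_k = 2 w_k
grad = sum(2 * w[k] * dw_dA[k] for k in range(K))
assert np.allclose(grad, 2 * np.outer(A @ x, x))   # equals 2 A x x^T
```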