Chain rule where intermediate variable is a matrix

How does one calculate the derivative of a scalar with respect to a matrix using the chain rule where the intermediate variable is a matrix? For example:

$$\frac{\partial L}{\partial \mathbf W} = \frac{\partial L}{\partial \mathbf Y} \frac{\partial \mathbf Y}{\partial \mathbf W}$$

If $\mathbf Y$ were a vector ($\mathbf y$), the chain rule would suggest that we need to sum across all the individual elements of $\mathbf y$, i.e.

$$\frac{\partial L}{\partial \mathbf W} = \sum_i \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \mathbf W} \text{, where $y_i$ is an element of vector $\mathbf y$}$$

Is it OK to assume that the extension of that rule to the case where $\mathbf Y$ is a matrix is as follows?

$$\frac{\partial L}{\partial \mathbf W} = \sum_{i,j} \frac{\partial L}{\partial \mathbf Y_{i,j}} \frac{\partial \mathbf Y_{i,j}}{\partial \mathbf W}\text{, where $\mathbf Y_{i,j}$ is an element of matrix $\mathbf Y$}$$
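One way to sanity-check the proposed double sum numerically is the pure-Python sketch below. It is not part of the original question. It assumes a hypothetical model $\mathbf Y = \mathbf X^T\mathbf W$ with $L = \sum_{i,j} Y_{ij}^2$. For every entry $W_{ab}$ it evaluates $\sum_{i,j} \frac{\partial L}{\partial Y_{ij}}\frac{\partial Y_{ij}}{\partial W_{ab}}$ over all elements of $\mathbf Y$ and compares the result with a central finite-difference gradient. The helper names (`xt_w`, `grad_by_index_sum`, `grad_finite_diff`) and the matrix sizes are made up for illustration.

```python
import random

def xt_w(X, W):
    """Return Y = X^T W for matrices stored as lists of rows."""
    n, p, m = len(X), len(X[0]), len(W[0])
    return [[sum(X[k][i] * W[k][j] for k in range(n)) for j in range(m)]
            for i in range(p)]

def loss(X, W):
    """Hypothetical scalar L = sum_{i,j} Y_ij^2 with Y = X^T W."""
    Y = xt_w(X, W)
    return sum(y * y for row in Y for y in row)

def grad_by_index_sum(X, W):
    """dL/dW_ab = sum over every element (i, j) of Y of (dL/dY_ij) * (dY_ij/dW_ab)."""
    Y = xt_w(X, W)
    p, m = len(Y), len(Y[0])
    G = [[0.0] * len(W[0]) for _ in W]
    for a in range(len(W)):
        for b in range(len(W[0])):
            total = 0.0
            for i in range(p):
                for j in range(m):
                    dL_dY = 2.0 * Y[i][j]
                    # Y_ij = sum_k X_ki W_kj, so dY_ij/dW_ab = X_ai if j == b, else 0.
                    dY_dW = X[a][i] if j == b else 0.0
                    total += dL_dY * dY_dW
            G[a][b] = total
    return G

def grad_finite_diff(X, W, h=1e-6):
    """Central-difference approximation of dL/dW, one entry at a time."""
    G = [[0.0] * len(W[0]) for _ in W]
    for a in range(len(W)):
        for b in range(len(W[0])):
            Wp = [row[:] for row in W]
            Wm = [row[:] for row in W]
            Wp[a][b] += h
            Wm[a][b] -= h
            G[a][b] = (loss(X, Wp) - loss(X, Wm)) / (2 * h)
    return G

random.seed(0)
X = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(3)]  # 3 x 2
W = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)]  # 3 x 4, so Y is 2 x 4
G_sum = grad_by_index_sum(X, W)
G_fd = grad_finite_diff(X, W)
max_err = max(abs(G_sum[a][b] - G_fd[a][b]) for a in range(3) for b in range(4))
print(max_err)
```

Because this $L$ is quadratic in $\mathbf W$, the central difference is exact up to floating-point rounding. Any agreement therefore reflects the index sum itself, not a lucky approximation.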


Here is a concrete example of the differential approach.

Assume $Y$ is a matrix and the cost function is given by
$$L=\|Y\|^2_F = Y:Y$$
where the colon is a convenient product notation for the trace, i.e.
$$A:B={\rm Tr}(A^TB)$$
Let's further assume that the relationship to the matrix $W$ is
$$Y = X^TW$$
Calculate the differential of the cost function, and then its gradient. The third line below uses $Y^TX^T=(XY)^T$, so that $Y:X^TdW = {\rm Tr}(Y^TX^TdW) = {\rm Tr}\big((XY)^TdW\big) = XY:dW$.
$$\eqalign{
dL &= 2Y:dY \cr
&= 2Y:X^TdW \cr
&= 2XY:dW \cr
\frac{\partial L}{\partial W} &= 2XY \cr
}$$
If you're working with vectors $(w,y)$ instead of matrices, the derivation is basically unchanged.
$$\eqalign{
y &= X^Tw \cr
L &= \|y\|^2_F = y:y \cr
dL &= 2y:dy \cr
&= 2y:X^Tdw \cr
&= 2Xy:dw \cr
\frac{\partial L}{\partial w} &= 2Xy \cr
}$$
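The closed form $\partial L/\partial W = 2XY$ can be checked numerically with the minimal pure-Python sketch below. It uses small random matrices of hypothetical sizes ($X$ is 3×2 and $W$ is 3×4, so $Y=X^TW$ is 2×4 and $XY$ has the same shape as $W$) and compares the formula against central finite differences of $L = Y:Y$. Passing a $W$ with a single column reproduces the vector case $\partial L/\partial w = 2Xy$. The function names are illustrative only.

```python
import random

def xt_w(X, W):
    """Return X^T W for matrices stored as lists of rows."""
    n, p, m = len(X), len(X[0]), len(W[0])
    return [[sum(X[k][i] * W[k][j] for k in range(n)) for j in range(m)]
            for i in range(p)]

def matmul(A, B):
    """Return the ordinary product A B."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def loss(X, W):
    """L = Y:Y = ||Y||_F^2 with Y = X^T W."""
    Y = xt_w(X, W)
    return sum(y * y for row in Y for y in row)

def grad_closed_form(X, W):
    """dL/dW = 2 X Y, as obtained from the differential derivation."""
    XY = matmul(X, xt_w(X, W))
    return [[2.0 * v for v in row] for row in XY]

def grad_finite_diff(X, W, h=1e-6):
    """Central-difference approximation of dL/dW, one entry at a time."""
    G = [[0.0] * len(W[0]) for _ in W]
    for a in range(len(W)):
        for b in range(len(W[0])):
            Wp = [row[:] for row in W]
            Wm = [row[:] for row in W]
            Wp[a][b] += h
            Wm[a][b] -= h
            G[a][b] = (loss(X, Wp) - loss(X, Wm)) / (2 * h)
    return G

def max_abs_diff(A, B):
    """Largest entrywise absolute difference between two equally shaped matrices."""
    return max(abs(x - y) for ra, rb in zip(A, B) for x, y in zip(ra, rb))

random.seed(1)
X = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(3)]  # 3 x 2
W = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)]  # 3 x 4 (matrix case)
w = [[random.uniform(-1, 1)] for _ in range(3)]                    # 3 x 1 (vector case)

err_matrix = max_abs_diff(grad_closed_form(X, W), grad_finite_diff(X, W))
err_vector = max_abs_diff(grad_closed_form(X, w), grad_finite_diff(X, w))
print(err_matrix, err_vector)
```

Treating the vector as an $n\times 1$ matrix is what lets the same code cover both derivations. This mirrors the remark that the vector derivation is basically unchanged.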