How to add the derivative of a matrix to the chain rule?


In machine learning, I'm optimizing a parameter matrix $W$.

The loss function is $$L=f(y),$$ where $L$ is a scalar, $y=Wx$, $x\in \mathbb{R}^n$, $y\in \mathbb{R}^m$, and $W$ is an $m\times n$ matrix.

In most math textbooks one finds $$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial L}{\partial y}W,$$ where $\dfrac{\partial L}{\partial y}$ is a $1\times m$ row vector. This is quite easy to understand.

However, in machine learning $x$ is the input and $W$ is the parameter matrix to optimize, so it should be $$\frac{\partial L}{\partial W}=\frac{\partial L}{\partial y}\frac{\partial y}{\partial W}.$$ But what is $\dfrac{\partial y}{\partial W}$? Is it $x$? Is that correct?

According to Wikipedia, the derivative of a scalar with respect to a matrix is a matrix:

\begin{equation*} \frac{\partial L}{\partial W} = \begin{pmatrix} \frac{\partial L}{\partial W_{11}} & \frac{\partial L}{\partial W_{21}} & \cdots & \frac{\partial L}{\partial W_{m1}} \\ \frac{\partial L}{\partial W_{12}} & \frac{\partial L}{\partial W_{22}} & \cdots & \frac{\partial L}{\partial W_{m2}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial W_{1n}} & \frac{\partial L}{\partial W_{2n}} & \cdots & \frac{\partial L}{\partial W_{mn}} \end{pmatrix} \end{equation*}

where $$\frac{\partial L}{\partial W_{ji}}=\frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial W_{ji}}=\frac{\partial L}{\partial y_j}x_i$$

therefore

\begin{equation*} \frac{\partial L}{\partial W} = \begin{pmatrix} \frac{\partial L}{\partial y_1}x_1 & \frac{\partial L}{\partial y_2}x_1 & \cdots & \frac{\partial L}{\partial y_m}x_1 \\ \frac{\partial L}{\partial y_1}x_2 & \frac{\partial L}{\partial y_2}x_2 & \cdots & \frac{\partial L}{\partial y_m}x_2 \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial L}{\partial y_1}x_n & \frac{\partial L}{\partial y_2}x_n & \cdots & \frac{\partial L}{\partial y_m}x_n \\ \end{pmatrix} \end{equation*}
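The entries of this matrix can be sanity-checked with finite differences. A quick NumPy sketch, using a toy loss $f(y)=\lVert y\rVert^2$ (an assumption for concreteness; any smooth $f$ works the same way):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

# Toy scalar loss L = f(y) with y = W x; f(y) = sum(y**2) is an assumption.
def loss(W):
    y = W @ x
    return np.sum(y ** 2)

y = W @ x
g = 2 * y                      # dL/dy for f(y) = sum(y**2), an (m,) vector

# Central-difference gradient of L in the n x m (transposed) layout above
eps = 1e-6
fd = np.zeros((n, m))
for i in range(n):             # row index i runs over x-components
    for j in range(m):         # column index j runs over y-components
        E = np.zeros((m, n))
        E[j, i] = eps
        fd[i, j] = (loss(W + E) - loss(W - E)) / (2 * eps)

# Closed form from the matrix above: entry (i, j) = (dL/dy_j) * x_i
closed = np.outer(x, g)
assert np.allclose(fd, closed, atol=1e-5)
```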

Does this even fit the chain rule?

To fit the chain rule $$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial W},$$ $\dfrac{\partial L}{\partial W}$ is an $n\times m$ matrix and $\dfrac{\partial L}{\partial y}$ is a $1\times m$ vector; how can the dimensions possibly match?

PS: I just found that there is an operation called the Kronecker product, and $\dfrac{\partial L}{\partial W}$ can be written as $\dfrac{\partial L}{\partial y}\otimes x$, but this is still beyond me. First, why does the chain rule lead to a Kronecker product? Isn't the chain rule about matrix multiplication?

Second, does this mean $\dfrac{\partial y}{\partial W} = x$? I didn't find a definition of the derivative of a vector with respect to a matrix on Wikipedia.

The third and most important question: even if I know the derivative $\dfrac{\partial L}{\partial W}$, how should I update my parameter matrix? We all know gradient descent works because of the directional derivative $$\nabla_v f = \frac{\partial f}{\partial v}v,$$ so we should step in the negative gradient direction to lower $f$.

Does an analogue even exist for the derivative with respect to a matrix? I mean, multiplying $\dfrac{\partial L}{\partial W}$ by $\Delta W$ won't reproduce $\Delta L$ in any obvious way.
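For what it's worth, both the Kronecker form from the PS and the reason it appears can be checked numerically. A NumPy sketch, where `g` stands in for an arbitrary $\partial L/\partial y$ (the specific loss is left abstract):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 3, 4
W = rng.normal(size=(m, n))
x = rng.normal(size=n)
g = rng.normal(size=m)                 # stands in for dL/dy (a 1 x m row)

# The PS formula: (dL/dy) kron x, with dL/dy as a 1 x m row and x as an
# n x 1 column, gives exactly the n x m matrix with entries (dL/dy_j) x_i.
grad = np.kron(g.reshape(1, m), x.reshape(n, 1))   # shape (n, m)
assert np.allclose(grad, np.outer(x, g))

# Why a Kronecker product appears: flattening W turns y = W x into an
# ordinary matrix-vector product, vec(W x) = (x^T kron I_m) vec(W)
# (column-major vec), and then the chain rule IS matrix multiplication:
K = np.kron(x.reshape(1, n), np.eye(m))            # m x (m*n)
assert np.allclose(K @ W.flatten(order="F"), W @ x)
grad_vec = g @ K                                   # dL/dvec(W), length m*n
grad_mat = grad_vec.reshape(m, n, order="F")       # m x n layout: outer(g, x)
assert np.allclose(grad_mat, np.outer(g, x))
```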


There are 2 answers below.

BEST ANSWER

First, it is important to keep track of which functions you are differentiating. You have an $m$-vector $y$ with component functions $y_i$ ($i=1,\dots,m$), and you are differentiating by an $(m\times n)$-matrix $W$. Therefore, your derivative should carry three separate indices, i.e., it should be a third-order tensor $$ \left(\frac{\partial y}{\partial W}\right)_{ijk} = \frac{\partial y_i}{\partial W_{jk}}, \qquad 1\leq i,j\leq m,\quad 1\leq k \leq n. $$

So, this tells you immediately that the derivative on the left cannot possibly be $x$. To compute what it is, you should look for the linear mapping satisfying $$ dy_i = \sum_{j,k}\left(\frac{\partial y}{\partial W}\right)_{ijk} dW_{jk}\qquad \mathrm{for\,all\,} 1\leq i\leq m,$$ where $d$ denotes the exterior derivative.

This can be accomplished explicitly by differentiating in the following way: $$ dy = d(Wx) = (dW)x = (I\otimes x):dW,$$ where we have used that for all $1\leq i \leq m$, $$ \left((dW)x\right)_i = \sum_k dW_{ik}x_k = \sum_{j,k} \delta_{ij}x_k dW_{jk} = \left((I\otimes x):dW\right)_{\,i}\,.$$ Therefore, the derivative you are looking for is simply the tensor product $I\otimes x.$
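The tensor $I\otimes x$ and the contraction identity can be verified numerically. A NumPy sketch (since $y(W)=Wx$ is linear in $W$, the relation $dy=(dW)x$ holds exactly for any perturbation, and `einsum` implements the double contraction):

```python
import numpy as np

rng = np.random.default_rng(3)
m, n = 3, 4
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

# Build the third-order tensor T[i, j, k] = dy_i / dW_jk = delta_ij * x_k
T = np.einsum("ij,k->ijk", np.eye(m), x)

# Since y(W) = W x is linear in W, dy = (dW) x holds exactly for any dW
dW = rng.normal(size=(m, n))
dy_direct = dW @ x
dy_tensor = np.einsum("ijk,jk->i", T, dW)   # the contraction (I kron x) : dW
assert np.allclose(dy_direct, dy_tensor)
```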

To answer your related question about the chain rule: it still works, but you have to be careful. Using your example, in components we have $$\left(f'(y)\,y'(W)\right)_{jk} = \sum_i \frac{\partial f}{\partial y_i} \frac{\partial y_i}{\partial W_{jk}} = \sum_{i}\frac{\partial f}{\partial y_i}\delta_{ij}x_k = \frac{\partial f}{\partial y_j}x_k = \left(\frac{\partial f}{\partial y} \otimes x\right)_{jk},$$ so the total derivative has $m\times n$ components. I leave it to you to check that this makes sense given that $f$ is a function of $W$.

With this, updating the parameters $W$ through gradient descent is straightforward. For each component, simply compute $$W_{jk} \gets W_{jk} - \eta \frac{\partial f}{\partial W_{jk}} = W_{jk} - \eta \frac{\partial f}{\partial y_j} x_k, $$ where $\eta>0$ is your step-size.
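A minimal sketch of this update in NumPy, using a toy quadratic loss $f(y)=\lVert y-t\rVert^2$ and fixed example data (both assumptions; the componentwise rule is exactly the one above, written as an outer product):

```python
import numpy as np

m, n = 3, 4
W = np.zeros((m, n))                    # initial parameters
x = np.array([1.0, 2.0, -1.0, 0.5])     # example input (assumed)
target = np.array([0.5, -1.0, 2.0])     # example target (assumed)

def loss(W):
    return np.sum((W @ x - target) ** 2)

eta = 0.05                              # step-size
for _ in range(200):
    g = 2 * (W @ x - target)            # df/dy for the toy quadratic loss
    W -= eta * np.outer(g, x)           # W_jk <- W_jk - eta*(df/dy_j)*x_k

assert loss(W) < 1e-9                   # the loss has been driven to ~0
```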


For this computation, a coordinate-based approach is not difficult. You can simply think of $\frac{\partial{y}}{\partial W}$ as being represented by the coordinates $\frac{\partial y_k}{\partial w_{ij}}$ for $k \in [m], i \in [m], j \in [n]$. We have $$y_k = \sum_{j = 1}^{n} w_{kj}x_j,$$ so $$\frac{\partial y_{k}}{\partial w_{ij}} = \delta_{ik}x_j.$$

We can also obtain this same result using a more abstract approach involving the Fréchet derivative. You can write $y(W) = Wx$. Then the Fréchet derivative at a point $W \in M(m \times n, \mathbb{R})$ is the linear map $Dy(W) : M(m \times n, \mathbb{R}) \to \mathbb{R}^m$ given by $$Dy(W)H = Hx.$$ In other words, $Dy(W)$ is the linear operator of right multiplication by $x$. Hence the coordinates of $Dy(W)$ with respect to the standard bases for $M(m \times n, \mathbb{R})$ and $\mathbb{R}^m$ are $\frac{\partial{y}}{\partial w_{ij}} = Dy(W)e_ie_j^T = e_{i}e_j^Tx = x_je_i$. That is, $\frac{\partial y_{k}}{\partial w_{ij}} = x_{j}\delta_{ik}$, in agreement with what was obtained earlier.
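Because $y$ is linear in $W$, the Fréchet derivative is exact with no remainder term, so both claims can be checked directly in a short NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(4)
m, n = 3, 4
W = rng.normal(size=(m, n))
x = rng.normal(size=n)

def y(W):
    return W @ x

# Fréchet derivative: y is linear, so y(W + H) - y(W) = H x exactly
H = rng.normal(size=(m, n))
assert np.allclose(y(W + H) - y(W), H @ x)

# Coordinates: Dy(W)(e_i e_j^T) = e_i e_j^T x = x_j e_i
i, j = 1, 2
E = np.outer(np.eye(m)[i], np.eye(n)[j])    # the basis matrix e_i e_j^T
assert np.allclose(y(W + E) - y(W), x[j] * np.eye(m)[i])
```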