How to use the chain rule to calculate the gradient in a flow chart?


I have a data flow chart as follows:

image

Here $a$ and $x_1, x_2, x_3$ are vectors, and $W$ is a matrix.

The output is

$$y = ((aW+x_1)W + x_2)W + x_3$$

How can I use the chain rule to compute

$\frac{dy}{dW}$?

My attempt is as follows:

$$ t_1= aW+x_1$$

$$ t_2= t_1W +x_2$$

$$ t_3= t_2W +x_3$$

$$y = t_3$$

The problem is how to apply the chain rule, since $t_1, t_2, t_3$ all depend on $W$.

Best answer

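Take differentials of the three equations that you derived.

$$\eqalign{ dt_1 &= a\,dW \cr dt_2 &= t_1\,dW + dt_1\,W \cr &= t_1\,dW + (a\,dW)\,W \cr dt_3 &= t_2\,dW + dt_2\,W \cr &= t_2\,dW + \big(t_1\,dW + a\,dW\,W\big)\,W \cr &= t_2\,dW + t_1\,dW\,W + a\,dW\,W^{2} \cr }$$

The desired gradient is a 3rd-order tensor, so use the vec operation to flatten it into a matrix. Since $y = t_3$, we have $dy = dt_3$. Applying the identity ${\rm vec}(AXB) = (B^T\otimes A)\,{\rm vec}(X)$ to each of the three terms, with $a, t_1, t_2$ treated as row vectors, $dy$ stacked as a column, and $w = {\rm vec}(W)$, gives

$$\eqalign{ dy &= (I\otimes t_2^T + W\otimes t_1^T + W^{2}\otimes a^T)^T\,{\rm vec}(dW) \cr \frac{\partial y}{\partial w} &= (I\otimes t_2^T + W\otimes t_1^T + W^{2}\otimes a^T)^T \cr }$$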
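A quick way to gain confidence in the Kronecker formula is to compare it with a finite-difference Jacobian. The following is only a sketch under stated assumptions, not part of the original answer: $a, x_1, x_2, x_3$ are taken to be row vectors of shape `(1, n)`, $W$ is square, and ${\rm vec}$ stacks columns (column-major order). Transposing the bracketed sum term by term gives $I\otimes t_2 + W^T\otimes t_1 + (W^T)^2\otimes a$, which is what the code builds. The function names and the random test data are made up for illustration.

```python
import numpy as np

def forward(a, W, x1, x2, x3):
    """Return (t1, t2, y) for row vectors a, x1, x2, x3 of shape (1, n)."""
    t1 = a @ W + x1
    t2 = t1 @ W + x2
    y = t2 @ W + x3
    return t1, t2, y

def jacobian_closed_form(a, W, x1, x2, x3):
    """Jacobian dy/dvec(W), shape (n, n*n), from the Kronecker formula."""
    n = W.shape[0]
    t1, t2, _ = forward(a, W, x1, x2, x3)
    WT = W.T
    # (I (x) t2^T + W (x) t1^T + W^2 (x) a^T)^T, transposed term by term
    return np.kron(np.eye(n), t2) + np.kron(WT, t1) + np.kron(WT @ WT, a)

def jacobian_finite_diff(a, W, x1, x2, x3, h=1e-5):
    """Central-difference Jacobian with columns in column-major vec order."""
    n = W.shape[0]
    J = np.zeros((n, n * n))
    for j in range(n):          # column index of W
        for i in range(n):      # row index of W
            E = np.zeros_like(W)
            E[i, j] = h
            y_plus = forward(a, W + E, x1, x2, x3)[2]
            y_minus = forward(a, W - E, x1, x2, x3)[2]
            # entry W[i, j] sits at position j*n + i of vec(W)
            J[:, j * n + i] = ((y_plus - y_minus) / (2 * h)).ravel()
    return J

# Hypothetical test data: random inputs with a fixed seed
rng = np.random.default_rng(0)
n = 4
a, x1, x2, x3 = (rng.standard_normal((1, n)) for _ in range(4))
W = rng.standard_normal((n, n))

J_exact = jacobian_closed_form(a, W, x1, x2, x3)
J_num = jacobian_finite_diff(a, W, x1, x2, x3)
print("max abs difference:", np.max(np.abs(J_exact - J_num)))
```

If the two Jacobians agree up to finite-difference error, the three terms of $dt_3$ have been vectorized consistently. A row-major vec convention would instead swap the order of the Kronecker factors.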