Why is the gradient summed in a computational graph for an operation with split output?


I was looking at backpropagating a gradient through a computational graph, and everything makes sense except for when a node has multiple outputs. Take the following function:

$$f(x) = (x+3)(x+2)$$

which obviously has the derivative:

$$f'(x) = 2x+5$$

Now, take the following segment of a computational graph:

[Figure: graph 1 — segment of the computational graph]

The numbers above a line represent the value on the forward pass; the numbers beneath it represent the backpropagated gradient.

When moving the gradient backwards through the $f(x)$ node, I took the sum of $0.34$ and $-0.2$, then multiplied this sum by the derivative of $f(x)$ at the input $2$ (taken from the forward pass). So: $(0.34-0.2)\times f'(2) = 1.26$. I understand I had to multiply by $f'(x)$ (following the chain rule), but I do not understand why I had to sum the $0.34$ and $-0.2$. I only did so because I know that's what I'm meant to do.
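For concreteness, the backward step described above can be sketched numerically (the graph structure and the two incoming gradients $0.34$ and $-0.2$ are taken from the question; the exact figure is assumed):

```python
# Numeric sketch of the backward step through the f(x) node.
def f(x):
    return (x + 3) * (x + 2)

def f_prime(x):
    # Analytic derivative: f'(x) = 2x + 5
    return 2 * x + 5

x = 2.0                  # forward-pass input to the f(x) node
g1, g2 = 0.34, -0.2      # gradients arriving on the two output branches

# Sum the incoming gradients, then apply the chain rule through f:
grad_x = (g1 + g2) * f_prime(x)
print(grad_x)            # approximately 1.26  (0.14 * 9)
```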

Any help is greatly appreciated.


In a computation like
$$ f(x)=a(u(x))\cdot b(u(x)) $$
you can derive the gradient rules by starting with the universal relation of automatic/algorithmic differentiation between gradients $\bar f,\bar b,\dots$ and directional derivatives $\dot x,\dot u,\dots$, which is
$$ \bar f\dot f = \bar a\dot a+\bar b\dot b=\bar u\dot u=\bar x\dot x. $$
Now insert the chain rule and evaluate the identity chain in both directions: $\dot f=\dot a\,b+a\,\dot b$, $\dot a = a'(u)\dot u$, and $\dot b = b'(u)\dot u$, so that
$$ \bar f\dot a\,b+\bar f a\,\dot b= \bar a\dot a+\bar b\dot b, $$
which has to be true for all direction vectors, implying $\bar a=\bar f b$ and $\bar b=\bar f a$. Likewise
$$ \bar a\,a'(u)\dot u+\bar b\,b'(u)\dot u=\bar u\dot u $$
implies $\bar u = \bar a\,a'(u)+\bar b\,b'(u)$.
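The adjoint rules derived above can be checked with a small reverse-mode sketch. The concrete functions $a(u)=u+3$, $b(u)=u+2$, $u(x)=x$ are my choice to match the question's $f(x)=(x+3)(x+2)$:

```python
# Reverse-mode sketch of f(x) = a(u) * b(u) with a(u) = u + 3, b(u) = u + 2.
def reverse_mode(x, f_bar=1.0):
    # forward pass
    u = x
    a = u + 3.0
    b = u + 2.0
    # reverse pass, using the adjoint rules derived above
    a_bar = f_bar * b                    # \bar a = \bar f * b
    b_bar = f_bar * a                    # \bar b = \bar f * a
    u_bar = a_bar * 1.0 + b_bar * 1.0    # \bar u = \bar a a'(u) + \bar b b'(u); a' = b' = 1
    return u_bar

print(reverse_mode(2.0))   # 9.0, matching the analytic f'(2) = 2*2 + 5
```

Note that the sum $\bar a\,a'(u)+\bar b\,b'(u)$ is exactly the gradient summation the question asks about: $u$ fans out to two consumers, so its adjoint collects a contribution from each.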


A 1-input, 2-output gate is equivalent to two 1-input, 1-output gates whose output functions are identical [$f(x)=(x+3)(x+2)$ in this case].

Now, let the gradients flowing into the two gates be $g_1$ and $g_2$. The gradient updates for the inputs will be $f'(x) \cdot g_1$ and $f'(x) \cdot g_2$ respectively.

However, the two inputs are the same variable. Hence it makes no difference whether you apply $$x = x + f'(x) \cdot g_1$$ $$x = x + f'(x) \cdot g_2$$ in sequence, or just $$x = x + f'(x) \cdot (g_1 + g_2)$$ at once.
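The accumulation argument above can be sketched directly, using the question's numbers for the two branch gradients (in both variants $f'(x)$ is evaluated at the same forward-pass point, as in the equations above):

```python
# Two separate gradient accumulations vs. one combined accumulation.
def f_prime(x):
    return 2 * x + 5   # derivative of f(x) = (x+3)(x+2)

x0 = 2.0
g1, g2 = 0.34, -0.2    # gradients from the two output branches

# accumulate the two contributions one at a time
x_two_steps = x0
x_two_steps += f_prime(x0) * g1
x_two_steps += f_prime(x0) * g2

# accumulate the summed gradient in a single step
x_one_step = x0 + f_prime(x0) * (g1 + g2)

print(abs(x_two_steps - x_one_step) < 1e-12)   # True: the results agree
```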