Gradient of broadcast plus


Assume $$f(\vec b)=WX\,\tilde{+}\,\vec b,$$ where $W$ and $X$ are two matrices, $\vec b$ is a vector, and the symbol $\tilde{+}$ denotes the so-called broadcast plus:

$$ \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} \tilde{+} \begin{pmatrix} 5 \\ 6 \end{pmatrix} = \begin{pmatrix} 6 & 7 \\ 9 & 10 \end{pmatrix} $$
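For concreteness, here is what this operation looks like in NumPy (a minimal sketch; the array names are just illustrative):

```python
import numpy as np

A = np.array([[1., 2.],
              [3., 4.]])
b = np.array([5., 6.])

# Broadcast plus: add the column vector b to every column of A.
print(A + b[:, None])
# [[ 6.  7.]
#  [ 9. 10.]]
```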

How do I calculate the gradient matrix of $f(\vec b)$?


There are 2 answers below.

BEST ANSWER

As far as I can see, you can replace the broadcast operation by adding the matrix

$$B = \begin{bmatrix}\vec b & \vec b & \dots & \vec b \end{bmatrix},$$

whose columns are all equal to $\vec b$. Then the gradient of the matrix-valued function $f$ with respect to $\vec b$ is given by

$$\dfrac {\partial f_i}{\partial b_j} = \begin{bmatrix}\delta_{ij}&\delta_{ij}&\ldots &\delta_{ij} \end{bmatrix}.$$

where $\delta_{ij}=1$ if $i=j$ and $\delta_{ij}=0$ if $i\neq j$ (the Kronecker delta).
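A quick numerical sanity check of this delta structure (a minimal sketch with arbitrary shapes and random data; the function `f` below just reimplements the broadcast):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
X = rng.normal(size=(4, 5))
b = rng.normal(size=3)

def f(b):
    # Broadcast plus: add b to every column of W @ X.
    return W @ X + b[:, None]

# Finite differences: d f[i, k] / d b[j] should equal delta_{ij} for every column k.
eps = 1e-6
for j in range(3):
    e = np.zeros(3)
    e[j] = eps
    diff = (f(b + e) - f(b)) / eps          # shape (3, 5)
    expected = np.zeros((3, 5))
    expected[j, :] = 1.0                    # row j is all ones, every other row is zero
    assert np.allclose(diff, expected, atol=1e-4)
```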

ANSWER

"Broadcasting plus" can be rewritten as matrix multiplication:

$$f(\vec b)=WX + \vec b\,M^\top,$$ where $M = \begin{bmatrix}1\\1\\1\\\vdots \\1 \end{bmatrix}$ is an $N\times 1$ matrix of ones ($N$ being the number of columns of $WX$), so that $\vec b\,M^\top$ is a matrix whose every column equals $\vec b$.
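In NumPy terms the rewriting looks like this (a sketch with arbitrary shapes); the broadcast and the outer-product form produce the same matrix:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(3, 4))
X = rng.normal(size=(4, 5))
b = rng.normal(size=3)

M = np.ones((X.shape[1], 1))              # N x 1 column of ones, N = number of columns of W @ X

broadcast = W @ X + b[:, None]            # broadcast plus
outer     = W @ X + b[:, None] @ M.T      # the same bias term written as an outer product
assert np.allclose(broadcast, outer)
```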

Then differentiating $\vec b\,M^\top$ is literally the same exercise as differentiating the ordinary matrix product $WX$, so the gradient of $f$ with respect to $\vec b$ is governed entirely by $M$. When used in backpropagation, this means multiplying the upstream gradient by $M$ (or $M^\top$, depending on your layout convention), which gives:

$$\frac{\partial Loss}{\partial \vec b} = \frac{\partial Loss}{\partial f}\,M $$

In practice this multiplication simply sums the entries of $\partial Loss/\partial f$ across the columns, giving one number per component of $\vec b$. The exact notation depends on the task at hand, since you are taking the gradient of a matrix-valued function.
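As a concrete sketch (with an arbitrary quadratic loss standing in for whatever loss you actually use), the backward pass through the bias and a finite-difference check look like this:

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.normal(size=(3, 4))
X = rng.normal(size=(4, 5))
b = rng.normal(size=3)
M = np.ones((X.shape[1], 1))

def loss(b):
    f = W @ X + b[:, None]
    return np.sum(f ** 2)                  # arbitrary scalar loss, just for illustration

dL_df = 2 * (W @ X + b[:, None])           # dLoss/df for this particular loss
dL_db = (dL_df @ M).ravel()                # chain rule through the bias: multiply by M ...
assert np.allclose(dL_db, dL_df.sum(axis=1))   # ... which is just a sum across the columns

# Finite-difference check of dLoss/db
eps = 1e-6
fd = np.array([(loss(b + eps * np.eye(3)[j]) - loss(b - eps * np.eye(3)[j])) / (2 * eps)
               for j in range(3)])
assert np.allclose(dL_db, fd, atol=1e-4)
```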

I've also just written a lengthy post on this on my blog.