Help deriving the vector-jacobian product of some operations


I'm building a simple auto-differentiation (AD) engine for my own educational purposes. The end goal is to train a neural network. I decided to go with reverse-mode AD, so I'm currently deriving and implementing a set of vector-Jacobian product (vjp) primitives. Vjps for pointwise operations like addition and multiplication are fairly easy, but I'm struggling to come up with solutions for the vector-matrix product and the matrix product.

More precisely, say I have the following operation: $$z = W.x$$ with $W$ a matrix $\in\mathbb{R}^{n\times m}$, $x$ a vector $\in\mathbb{R}^{m}$, and thus $z$ a vector $\in\mathbb{R}^{n}$. For some vector $v \in\mathbb{R}^{n}$, the vjp $$v.J_x(z)$$ is simply $$v.W$$

But for the more complicated case where we want the vjp $$v.J_W(z)$$ I struggle to come up with an expression, since $J_W(z)$ is a tensor. From what I could gather from a simple derivation with $n=m=2$, this tensor is mostly sparse, with the $x_i$'s along its "diagonal", something like $$\begin{pmatrix} (x_1 & x_2) & (0 & 0)\\ (0 & 0) & (x_1 & x_2) \end{pmatrix}$$ I'm not sure that makes sense, since I don't know how to properly manipulate tensors or how to take the product with $v$. Does anyone know an expression for the vjp of this operation, or how it is implemented in AD engines such as PyTorch or TensorFlow?
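For reference, here is how I check the easy case numerically (a small NumPy sketch with made-up shapes; the vjp with respect to $x$ is just $v.W$, which matches finite differences):

```python
import numpy as np

# Hypothetical shapes for illustration: z = W @ x with W in R^{n x m}, x in R^m.
rng = np.random.default_rng(0)
n, m = 3, 4
W = rng.normal(size=(n, m))
x = rng.normal(size=m)
v = rng.normal(size=n)

# The Jacobian of z with respect to x is W itself, so the vjp v.J_x(z) is v @ W.
vjp_x = v @ W

# Numerical check: perturb each component of x and measure the change in v.z.
eps = 1e-6
num = np.array([
    (v @ (W @ (x + eps * np.eye(m)[j])) - v @ (W @ x)) / eps
    for j in range(m)
])
assert np.allclose(vjp_x, num, atol=1e-4)
```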


1 Answer


I think it is best to derive the result of the vjp in index notation; working with higher-order tensors in symbolic notation is tricky.

The vjp you are looking for, i.e., $$v.J_W(z)$$ is, in index notation (assuming the Einstein summation convention), $$v_i\frac{\partial z_i}{\partial W_{kl}}$$ In order to evaluate the partial derivative, we also need the forward computation in index notation, which is $$z_i = W_{ij}x_j$$ Evaluating its derivative, we get $$\frac{\partial z_i}{\partial W_{kl}} = \frac{\partial W_{ij}}{\partial W_{kl}} x_j = \delta_{ik}\delta_{jl}x_j$$ with the Kronecker delta $$\delta_{ij} = \begin{cases} 1, \quad i=j \\ 0, \quad i \ne j \end{cases} $$ In derivations using index notation, Kronecker deltas are just a tool to express the derivative of a quantity with respect to itself, but with different indices.
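Since $\delta_{ik}\delta_{jl}x_j = \delta_{ik}x_l$, the full Jacobian is a sparse rank-3 tensor of shape $(n, n, m)$: row $i$ of $z$ depends only on row $i$ of $W$, through $x$. A minimal NumPy sketch of that structure (shapes are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 2, 3
x = rng.normal(size=m)

# dz_i / dW_kl = delta_ik * x_l, materialized as an (n, n, m) tensor.
J = np.einsum("ik,l->ikl", np.eye(n), x)

# Diagonal blocks (i == k) hold x; off-diagonal blocks (i != k) are zero.
assert np.allclose(J[0, 0], x)
assert np.allclose(J[0, 1], np.zeros(m))
```

This is exactly the "mostly sparse tensor with $x_i$'s on the diagonal" from your $n=m=2$ derivation; a real AD engine never materializes it.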

We can plug the derivative back into the vjp definition: $$v_i\frac{\partial z_i}{\partial W_{kl}} = v_i \delta_{ik}\delta_{jl}x_j$$ The Kronecker deltas induce the index changes $i \leftrightarrow k$ and $j \leftrightarrow l$. Let's choose to keep indices $k$ and $l$, as those are the ones left over on the left-hand side of the equation. Then we have $$v_i\frac{\partial z_i}{\partial W_{kl}} = v_k x_l$$ The quantity we got is the outer product of the vectors $v$ and $x$.

Hence, we can also go back to symbolic notation to finally find the answer to your question $$v.J_W(z) = v.x^T$$
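A quick NumPy sanity check of this result (shapes chosen arbitrarily for illustration): the outer product $v x^T$ agrees with a finite-difference estimate of the vjp with respect to $W$.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m = 3, 4
W = rng.normal(size=(n, m))
x = rng.normal(size=m)
v = rng.normal(size=n)

# Closed-form vjp with respect to W: the outer product v x^T, same shape as W.
vjp_W = np.outer(v, x)

# Numerical check: perturb each entry W_kl and measure the change in v.z.
eps = 1e-6
num = np.zeros_like(W)
for k in range(n):
    for l in range(m):
        Wp = W.copy()
        Wp[k, l] += eps
        num[k, l] = (v @ (Wp @ x) - v @ (W @ x)) / eps
assert np.allclose(vjp_W, num, atol=1e-4)
```

This is also what reverse-mode engines accumulate for a matrix-vector product: the incoming cotangent times the transposed input, never the full rank-3 Jacobian.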

I hope that helped. Recently, I also uploaded a YouTube video explaining the steps in more detail.