Matrix derivative confusion


$f:\mathbb{R}^n\to\mathbb{R}$
$a=z$
$z=Wa'$
where $a$, $z$, $a'$ are vectors of shapes $n\times 1$, $n\times 1$, $m\times 1$ respectively, and $W$ is a matrix of shape $n\times m$.
I need $\frac{\partial f(a)}{\partial W}$.

Applying the chain rule, $$\frac{\partial f(a)}{\partial W}=\frac{\partial f}{\partial a}\frac{\partial a}{\partial z}\frac{\partial z}{\partial W}$$ Now $\frac{\partial z}{\partial W}=a'$. The second factor, $\frac{\partial a}{\partial z}$, has shape $n\times n$, but $a'$ has shape $m\times 1$, so the shapes do not line up. I must have made a mistake somewhere.
Can someone point me in the right direction? Ultimately I want to compute the derivative above, $\frac{\partial f}{\partial W}$.
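
One way to sanity-check any candidate answer is to work entry by entry. Since $z_i=\sum_j W_{ij}a'_j$, we get $\frac{\partial z_i}{\partial W_{kl}}=\delta_{ik}a'_l$, which is a three-index object rather than the vector $a'$. With $a=z$ this suggests $\frac{\partial f}{\partial W_{kl}}=\frac{\partial f}{\partial a_k}a'_l$, i.e. the $n\times m$ outer product $\frac{\partial f}{\partial a}a'^{T}$. Below is a minimal sketch that compares this candidate against finite differences. The particular `f` is a hypothetical test function (the question leaves $f$ unspecified), and the helper names `matvec`, `analytic_grad_W`, and `numeric_grad_W` are made up for this illustration.

```python
import math
import random

def f(a):
    """Hypothetical scalar function f: R^n -> R, chosen only for illustration."""
    return sum(math.sin(x) + 0.5 * x * x for x in a)

def grad_f(a):
    """Gradient of the hypothetical f above with respect to a."""
    return [math.cos(x) + x for x in a]

def matvec(W, v):
    """Compute W v for an n x m matrix W (list of rows) and a length-m vector v."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]

def analytic_grad_W(W, a_prev):
    """Candidate formula: the outer product (df/da) a'^T, an n x m matrix."""
    g = grad_f(matvec(W, a_prev))  # a = z = W a'
    return [[g_k * a_l for a_l in a_prev] for g_k in g]

def numeric_grad_W(W, a_prev, eps=1e-6):
    """Central finite differences, perturbing one entry W[k][l] at a time."""
    G = [[0.0] * len(W[0]) for _ in W]
    for k in range(len(W)):
        for l in range(len(W[0])):
            original = W[k][l]
            W[k][l] = original + eps
            f_plus = f(matvec(W, a_prev))
            W[k][l] = original - eps
            f_minus = f(matvec(W, a_prev))
            W[k][l] = original  # restore the entry before moving on
            G[k][l] = (f_plus - f_minus) / (2 * eps)
    return G

random.seed(0)
n, m = 3, 4  # arbitrary small sizes for the check
W = [[random.gauss(0, 1) for _ in range(m)] for _ in range(n)]
a_prev = [random.gauss(0, 1) for _ in range(m)]  # plays the role of a'

G_analytic = analytic_grad_W(W, a_prev)
G_numeric = numeric_grad_W(W, a_prev)
max_err = max(
    abs(x - y)
    for row_a, row_n in zip(G_analytic, G_numeric)
    for x, y in zip(row_a, row_n)
)
print(len(G_analytic), len(G_analytic[0]))  # rows and columns of the gradient
print(max_err < 1e-6)  # whether the two gradients agree to within 1e-6
```

If the two gradients agree, that supports the outer-product form and confirms that the result has the same $n\times m$ shape as $W$. It is only a numerical check for one choice of $f$, not a proof.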