I'm reading *Matrix Differential Calculus* by Magnus and Neudecker, and I'm having some difficulty understanding the chain rule. My goal is to write an arbitrary machine learning model in matrix form and then use differentials to find the gradient w.r.t. the weights, but I don't understand the connection between differentials and gradients, or how to move between them.
Let $A\in \mathbb{R}^{n \times m}$ be some matrix, $x \in \mathbb{R}^{n}$ a column vector, and $g:\mathbb{R}^{m} \rightarrow \mathbb{R}^{m}$ some differentiable function. Then define the following function.
$$f: \mathbb{R}^n \rightarrow \mathbb{R}^{m} \ \ \text{given by} \ \ f(x) = g(A^Tx) $$
Here is my attempt at finding the differential of $f$:
- $\ \textbf{d}f = \textbf{d}g(A^T x)$
- $\ \ \ \ \ \ = \big[\textbf{D}g(y) \big]\textbf{d} \big(A^Tx \big)$
- $\ \ \ \ \ \ = \big[\textbf{D}g(y) \big]\bigg((\textbf{d}A)^Tx + A^T(\textbf{d}x) \bigg)$
where $\big[\textbf{D}g(y) \big]$ is the Jacobian of $g$ at the point $y = A^Tx$. I think this last line is correct, but when I try to move to gradients I get lost. My understanding is that to find the gradients I should manipulate the differential into the forms
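To convince myself this differential is right, I ran a quick NumPy sanity check (my own, not from the book), taking $g$ to be elementwise $\tanh$ so that $\textbf{D}g(y) = \operatorname{diag}(1 - \tanh^2(y))$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = rng.standard_normal((n, m))
x = rng.standard_normal(n)

# Concrete choice of g: elementwise tanh, R^m -> R^m,
# whose Jacobian at y is Dg(y) = diag(1 - tanh(y)^2).
def Dg(y):
    return np.diag(1.0 - np.tanh(y) ** 2)

# Small perturbations playing the roles of dA and dx.
eps = 1e-6
dA = eps * rng.standard_normal((n, m))
dx = eps * rng.standard_normal(n)

y = A.T @ x
# Predicted first-order change: Dg(y) ( (dA)^T x + A^T dx )
df_pred = Dg(y) @ (dA.T @ x + A.T @ dx)
# Actual change in f = g(A^T x) under the perturbation
df_true = np.tanh((A + dA).T @ (x + dx)) - np.tanh(A.T @ x)

# The two agree up to second-order (O(eps^2)) terms.
print(np.max(np.abs(df_true - df_pred)))
```

The discrepancy shrinks like $\varepsilon^2$, as a first-order approximation should.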
- $ \ \textbf{d} f = (J) \textbf{d}x $
- $ \ \textbf{d} f = (K) \textbf{d}A $
Then $J$ would be $\frac{\partial f}{\partial x}$ and $K$ would be $\frac{\partial f}{\partial A}$ (I think). How can I do this? If I try to find $\frac{\partial f}{\partial x}$, I am tempted to treat $A$ as a constant in the differential, so that $(\textbf{d}A)^T = 0$, and then get
$$\textbf{d} f = \big[\textbf{D}g(y) \big]A^T(\textbf{d}x) $$ $$\Rightarrow \frac{\partial f}{\partial x} = \big[\textbf{D}g(y) \big]A^T $$ $$\Rightarrow \nabla_x f = A \big[\textbf{D}g(y) \big]^T$$
These dimensions don't seem right to me. If I consider $x$ to be a parameter of a loss function $f$ that I want to minimize, then I should find $\nabla_x f$ and update $x$ via gradient descent. However, the above implies that $\nabla_x f$ is $n \times m$; how can I use that to update $x$, which is $n \times 1$?
What am I doing wrong here?
First, you need to understand what $[dg(A^Tx)]\,d(A^Tx)$ really is. Here $dg(A^Tx)$ denotes the differential of $g$ at the point $A^Tx$: a continuous linear operator from $\mathbb{R}^m$ to $\mathbb{R}^m$ (matching the domain and codomain of $g$), represented by an $m\times m$ matrix. This operator is then applied to the input $d(A^Tx)$, a vector of dimension $m$. Since $f$ is a function of $x$ only, you are right to treat $A$ as a constant, and then the answer is indeed $$df=dg(A^Tx)A^Tdx$$ where $dg(A^Tx)$ is $m\times m$, because it is the differential of $g(y)$ at the point $y=A^Tx$, and not the differential of the composite map $x \mapsto g(A^Tx)$.
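For concreteness (my own addition, with elementwise $\tanh$ standing in for $g$), here is a NumPy check that the Jacobian of $x \mapsto g(A^Tx)$ really is the $m\times m$ matrix $dg(A^Tx)$ times $A^T$, giving an $m\times n$ matrix overall:

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 5, 3
A = rng.standard_normal((n, m))
x = rng.standard_normal(n)
y = A.T @ x

# g = elementwise tanh, so its differential at y is represented by
# the m x m matrix diag(1 - tanh(y)^2).
Dg_y = np.diag(1.0 - np.tanh(y) ** 2)

# Analytic Jacobian of f(x) = g(A^T x), A held constant: Dg(y) A^T, shape m x n.
J_analytic = Dg_y @ A.T

# Central finite-difference Jacobian, one column per coordinate of x.
eps = 1e-6
J_fd = np.empty((m, n))
for j in range(n):
    e = np.zeros(n)
    e[j] = eps
    J_fd[:, j] = (np.tanh(A.T @ (x + e)) - np.tanh(A.T @ (x - e))) / (2 * eps)

print(J_analytic.shape)  # (m, n), i.e. (3, 5)
print(np.max(np.abs(J_fd - J_analytic)))
```

Note that `Dg_y` is $m\times m$ exactly as described above; the $m\times n$ shape only appears after multiplying by $A^T$.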