Understanding the chain rule for matrix calculus


I am trying to understand why $ \frac{\partial \mathrm{tr}((A\cdot x+b)^\top \cdot (A\cdot x+b))}{\partial x} = 2\cdot A^\top \cdot (b+A\cdot x)$, where $A$ is a matrix and $x$ and $b$ are vectors.

The chain rule tells us that $\frac{\partial g(u)}{\partial x} = \frac{\partial g(u)}{\partial u} \frac{\partial u}{\partial x} $

Here $u(x) = Ax+b$ and $g(x) = \operatorname{tr}(x^\top x) $

$\frac{\partial u}{\partial x} = A^\top $

$\frac{\partial g(u)}{\partial u} =2(Ax+b) $

So I was expecting the result to be $2(Ax+b)\cdot A^\top$. I know that in terms of dimensions this cannot be the correct answer. Where is my mistake?



$ \def\o{{\tt1}}\def\p{\partial} \def\L{\left}\def\R{\right}\def\LR#1{\L(#1\R)} \def\trace#1{\operatorname{Tr}\LR{#1}} \def\grad#1#2{\frac{\p #1}{\p #2}} \def\qiq{\quad\implies\quad} \def\c#1{\color{red}{#1}} $The chain rule works great for scalar equations, or if you're using index notation. But it sucks when you combine it with matrix/vector notation. It leads to all sorts of errors involving transpositions and mismatched dimensions.

The most reliable way to calculate gradients in matrix/vector notation is via differentials. $$\eqalign{ u &= Ax+b \qiq \c{du=A\,dx} \\ g &= u:u \\ dg &= 2u:\c{du} = \LR{2u:\c{A\,dx}} = 2A^Tu:dx \\ \grad{g}{x} &= 2A^Tu = 2A^T(Ax+b) \\ }$$ where $(:)$ denotes the Frobenius product, which is a concise notation for the trace $$\eqalign{ A:B &= \sum_{i=1}^m\sum_{j=1}^n A_{ij}B_{ij} \;=\; \trace{A^TB} \\ A:A &= \big\|A\big\|^2_F \\ }$$ It is sometimes called the double-dot product or double contraction product.
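The gradient obtained above can be checked numerically against central finite differences. The sketch below (with arbitrary illustrative shapes and random data) compares $2A^\top(Ax+b)$ to a coordinate-wise difference quotient of $g(x)=\|Ax+b\|^2$:

```python
import numpy as np

# Numerical sanity check of dg/dx = 2 A^T (A x + b) for g(x) = ||Ax + b||^2.
# Shapes and values here are arbitrary illustrative choices.
rng = np.random.default_rng(0)
m, n = 4, 3
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)
x = rng.standard_normal(n)

def g(x):
    r = A @ x + b
    return r @ r  # tr(r^T r) is just the squared norm when r is a vector

analytic = 2 * A.T @ (A @ x + b)

# Central finite differences, one coordinate at a time.
eps = 1e-6
numeric = np.array([
    (g(x + eps * e) - g(x - eps * e)) / (2 * eps)
    for e in np.eye(n)
])

print(np.allclose(analytic, numeric, atol=1e-5))  # True
```

Note that the analytic gradient has the shape of $x$ (an $n$-vector), which is exactly the dimension check that fails for the ordering $2(Ax+b)A^\top$ proposed in the question.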
When applied to vectors $({\rm i.e.}\,\;n\!=\!1)\,$ it corresponds to the usual dot product.

The properties of the underlying trace function allow the terms in a Frobenius product to be rearranged in several different ways, e.g. $$\eqalign{ A:B &= B:A \\ A:B &= A^T:B^T \\ C:AB &= CB^T:A = A^TC:B \\ \\ }$$
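These rearrangement rules all follow from the cyclic property of the trace, and are easy to confirm numerically. A quick check with random matrices of compatible (arbitrarily chosen) shapes:

```python
import numpy as np

# Verify the Frobenius-product identities A:B = tr(A^T B) listed above.
rng = np.random.default_rng(1)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((3, 4))
C = rng.standard_normal((3, 5))
M = rng.standard_normal((3, 4))   # plays the role of A in C:AB
N = rng.standard_normal((4, 5))   # plays the role of B in C:AB

def frob(X, Y):
    return np.trace(X.T @ Y)

assert np.isclose(frob(A, B), frob(B, A))            # A:B = B:A
assert np.isclose(frob(A, B), frob(A.T, B.T))        # A:B = A^T:B^T
assert np.isclose(frob(C, M @ N), frob(C @ N.T, M))  # C:AB = CB^T:A
assert np.isclose(frob(C, M @ N), frob(M.T @ C, N))  # C:AB = A^T C:B
print("all identities hold")
```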


You can also approach the problem using index notation. The only gradient that you need to know is the following $$\eqalign{ \grad{x_i}{x_j} = \delta_{ij} = \begin{cases} \o\quad{\rm if}\;i=j \\ 0\quad{\rm otherwise} \\ \end{cases} \\ }$$ Then using the Einstein summation convention (wherein a repeated index implies summation over that index), write the equation and calculate its gradient. $$\eqalign{ u_i &= A_{ij}x_j + b_i \\ \grad{u_i}{x_k} &= A_{ij}\delta_{jk} \;=\; A_{ik} \\ \\ g &= u_i u_i \\ \grad{g}{x_k} &= \LR{\grad{u_i}{x_k}}u_i + u_i\LR{\grad{u_i}{x_k}} \\ &= 2\LR{\grad{u_i}{x_k}}u_i \\ &= 2A_{ik} u_i \\ &= 2A_{ki}^T u_i \\ &= 2A_{ki}^T \LR{A_{ij}x_j + b_i} \\ }$$ which, unsurprisingly, is the same as the previous result.
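The index-notation derivation maps directly onto `np.einsum`, which uses the same Einstein summation convention. A minimal sketch with arbitrary shapes:

```python
import numpy as np

# Replay the index-notation derivation with np.einsum:
# u_i = A_ij x_j + b_i,  g = u_i u_i,  dg/dx_k = 2 A_ik u_i = (2 A^T u)_k.
rng = np.random.default_rng(2)
A = rng.standard_normal((4, 3))
x = rng.standard_normal(3)
b = rng.standard_normal(4)

u = np.einsum('ij,j->i', A, x) + b     # u_i = A_ij x_j + b_i
grad = 2 * np.einsum('ik,i->k', A, u)  # dg/dx_k = 2 A_ik u_i

assert np.allclose(grad, 2 * A.T @ (A @ x + b))
print("index-notation gradient matches 2 A^T (A x + b)")
```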


Background info: If $F: \mathbb R^n \to \mathbb R^m$ is differentiable at a point $x \in \mathbb R^n$, then $F'(x)$ is an $m \times n$ matrix.

Since $A$ is a matrix and $x$ and $b$ are vectors, $(Ax + b)^T (Ax + b)$ is a scalar, and there is no need to take the trace. You are taking the derivative of the function $f(x) = \|Ax + b\|^2$. Notice that $f(x) = g(h(x))$, where $$ h(x) = Ax + b \quad \text{and} \quad g(u) = \|u\|^2. $$ The derivatives of $h$ and $g$ are $h'(x) = A$ and $g'(u) = 2 u^T$. By the chain rule, $$ f'(x) = g'(h(x)) h'(x) = 2(Ax + b)^T A. $$
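Under this answer's convention $f'(x)$ is a $1 \times n$ row vector rather than a column vector, which is exactly why its result $2(Ax+b)^TA$ is the transpose of the gradient $2A^T(Ax+b)$ from the other answer. A quick numerical check of the row-vector form (arbitrary random data):

```python
import numpy as np

# Check f'(x) = 2 (Ax + b)^T A, the 1 x n Jacobian of f(x) = ||Ax + b||^2,
# against central finite differences.
rng = np.random.default_rng(3)
A = rng.standard_normal((4, 3))
b = rng.standard_normal(4)
x = rng.standard_normal(3)

def f(x):
    r = A @ x + b
    return r @ r

fprime = 2 * (A @ x + b) @ A  # (Ax+b)^T A computed as a 1-D array

eps = 1e-6
numeric = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                    for e in np.eye(3)])
assert np.allclose(fprime, numeric, atol=1e-5)
print("row-vector Jacobian matches finite differences")
```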