Derivative of $\mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)\boldsymbol{x}$


The task is to compute the derivative \begin{equation} \frac{\partial\mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)\boldsymbol{x}}{\partial \boldsymbol{x}} \;\;\text{with} \;\; \mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)= \begin{pmatrix} \left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)_1 & 0 & 0 & \dots \\ 0 & \left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)_2 & 0 & \dots \\ 0 & 0 & \left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)_3 & \\ \vdots & \vdots & &\ddots \end{pmatrix}. \end{equation} To this end, we can use $\mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)\boldsymbol{x}=\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right) \circ \boldsymbol{x}$, where $\circ$ denotes the component-wise (Hadamard) product.

Differentiating the $i$-th component of this expression with respect to the $p$-th component of $\boldsymbol{x}\in\mathbb{R}^n$ gives \begin{align} \frac{\partial \left(\sum_{j=1}^{n} A_{ij}x_j+b_i\right)x_i}{\partial x_p}&= \frac{\partial \left(\sum_{j=1}^{n} A_{ij}x_j+b_i\right)}{\partial x_p}x_i+ \left(\sum_{j=1}^{n} A_{ij}x_j+b_i\right)\frac{\partial x_i}{\partial x_p} \\ &= A_{ip} x_i + \sum_{j=1}^n A_{pj} x_j + b_p. \end{align}

Hence, with $\boldsymbol{a}_p^\top$ denoting the $p$-th row of $\boldsymbol{A}$, the full derivative is $$ \frac{\partial \left(\mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)\boldsymbol{x}\right)}{\partial \boldsymbol{x}}=\begin{pmatrix} A_{11}x_1+\boldsymbol{a}_1^\top\boldsymbol{x}+b_1 & A_{12} x_1+\boldsymbol{a}_2^\top \boldsymbol{x}+b_2 & \dots \\ A_{21}x_2+\boldsymbol{a}_1^\top\boldsymbol{x}+b_1 & A_{22} x_2+ \boldsymbol{a}_2^\top \boldsymbol{x}+b_2 & \dots \\ \vdots & \vdots \end{pmatrix}. $$

Is this correct?
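One way to check the proposed matrix numerically is sketched below in plain Python (no libraries). The matrix `A`, vector `b` and point `x` are arbitrary integer test values chosen for illustration, not taken from the question, and the function names are hypothetical. Because each component of $\mathrm{diag}(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b})\boldsymbol{x}$ is quadratic in $\boldsymbol{x}$, a central difference with step 1 is exact, so the comparison has no truncation error.

```python
def f(A, b, x):
    """y = diag(Ax + b) x, computed as the Hadamard product (Ax + b) ∘ x."""
    n = len(x)
    return [(sum(A[i][j] * x[j] for j in range(n)) + b[i]) * x[i] for i in range(n)]


def jacobian_central_difference(A, b, x):
    """Column p is (f(x + e_p) - f(x - e_p)) / 2, exact because f is quadratic in x."""
    n = len(x)
    J = [[0.0] * n for _ in range(n)]
    for p in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[p] += 1
        x_minus[p] -= 1
        f_plus, f_minus = f(A, b, x_plus), f(A, b, x_minus)
        for i in range(n):
            J[i][p] = (f_plus[i] - f_minus[i]) / 2
    return J


def proposed_jacobian(A, b, x):
    """Matrix proposed in the question: entry (i, p) = A_ip x_i + a_p^T x + b_p."""
    n = len(x)
    Ax = [sum(A[p][j] * x[j] for j in range(n)) for p in range(n)]
    return [[A[i][p] * x[i] + Ax[p] + b[p] for p in range(n)] for i in range(n)]


# Arbitrary integer test data (not from the question).
A = [[1, 2, 0], [3, -1, 4], [0, 5, 2]]
b = [1, -2, 3]
x = [2, 1, -1]

J_fd = jacobian_central_difference(A, b, x)
J_prop = proposed_jacobian(A, b, x)
for i in range(len(x)):
    # Zero on the diagonal; nonzero off the diagonal for this data.
    print([J_prop[i][p] - J_fd[i][p] for p in range(len(x))])
```

The diagonals agree, but each off-diagonal entry $(i,p)$ of the proposed matrix is off by exactly $\boldsymbol{a}_p^\top\boldsymbol{x}+b_p$.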

======== Correction ========

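Comparing with the answers: the first term, $A_{ip}x_i$, corresponds to $\mathrm{diag}\left(\boldsymbol{x}\right)\boldsymbol{A}$. However, the second term $$ \left(\sum_{j=1}^{n} A_{ij}x_j+b_i\right)\frac{\partial x_i}{\partial x_p} = \delta_{ip}\left(\sum_{j=1}^{n} A_{ij}x_j+b_i\right) $$ only contributes when $i=p$, i.e., on the diagonal, where it gives $\mathrm{diag}\left(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b}\right)$. My mistake was to replace $A_{ij}\delta_{ip}$ by $A_{pj}$ and $b_i\delta_{ip}$ by $b_p$ and then drop the Kronecker delta: this removed the index $i$ and produced spurious off-diagonal terms.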
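The corrected elementwise formula, with the Kronecker delta kept, can be compared directly with the matrix form $\mathrm{diag}(\boldsymbol{x})\boldsymbol{A}+\mathrm{diag}(\boldsymbol{A}\boldsymbol{x}+\boldsymbol{b})$ from the answers. This is a minimal plain-Python sketch; the helper names (`diag`, `matmul`, `matvec`, `add`) and the test data are illustrative assumptions, not from the post.

```python
def diag(v):
    """Square matrix with v on the diagonal."""
    n = len(v)
    return [[v[i] if i == j else 0 for j in range(n)] for i in range(n)]


def matmul(M, N):
    return [[sum(M[i][k] * N[k][j] for k in range(len(N))) for j in range(len(N[0]))]
            for i in range(len(M))]


def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def add(M, N):
    return [[M[i][j] + N[i][j] for j in range(len(M[0]))] for i in range(len(M))]


def jacobian_elementwise(A, b, x):
    """Entry (i, p) = A_ip x_i + delta_ip (a_i^T x + b_i), keeping the Kronecker delta."""
    n = len(x)
    J = [[0] * n for _ in range(n)]
    for i in range(n):
        s_i = sum(A[i][j] * x[j] for j in range(n)) + b[i]  # (Ax + b)_i
        for p in range(n):
            delta_ip = 1 if i == p else 0
            J[i][p] = A[i][p] * x[i] + delta_ip * s_i
    return J


def jacobian_matrix_form(A, b, x):
    """diag(x) A + diag(Ax + b)."""
    s = [si + bi for si, bi in zip(matvec(A, x), b)]  # Ax + b
    return add(matmul(diag(x), A), diag(s))


# Arbitrary integer test data (not from the question).
A = [[1, 2, 0], [3, -1, 4], [0, 5, 2]]
b = [1, -2, 3]
x = [2, 1, -1]

J = jacobian_elementwise(A, b, x)
print(J)  # [[7, 4, 0], [3, -2, 4], [0, -5, 4]]
assert J == jacobian_matrix_form(A, b, x)
```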

======== Best answer ========

A matrix calculus approach using differentials. (In my humble opinion, such approaches are less error-prone than elementwise ones, so I would suggest considering them.)


Let $y = \operatorname{Diag}\left(A x + b \right) x$, where $\operatorname{Diag}$ creates a diagonal matrix.

Using differentials, \begin{align} dy &= \operatorname{Diag}\left(A dx \right) x + \operatorname{Diag}\left(A x + b \right) dx \\ &= \operatorname{Diag}\left(x\right) A dx + \operatorname{Diag}\left(A x + b \right) dx \end{align}

Hence the Jacobian (gradient) is $$\frac{\partial y}{\partial x} = \operatorname{Diag}\left(x\right) A + \operatorname{Diag}\left(A x + b \right). $$
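Because $y$ is quadratic in $x$, the differential argument can be checked exactly: $y(x+dx)-y(x)$ equals the linear part $\left(\operatorname{Diag}(x)A+\operatorname{Diag}(Ax+b)\right)dx$ plus the second-order remainder $\operatorname{Diag}(A\,dx)\,dx$, which is the term the differential discards. The sketch below uses plain Python with exact integer arithmetic; the test data and function names are arbitrary choices for illustration.

```python
def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def y(A, b, x):
    """y = Diag(Ax + b) x = (Ax + b) ⊙ x."""
    return [(s + bi) * xi for s, bi, xi in zip(matvec(A, x), b, x)]


def jacobian(A, b, x):
    """Diag(x) A + Diag(Ax + b)."""
    n = len(x)
    s = [si + bi for si, bi in zip(matvec(A, x), b)]
    return [[x[i] * A[i][p] + (s[i] if i == p else 0) for p in range(n)] for i in range(n)]


# Arbitrary integer test data (not from the answer).
A = [[1, 2, 0], [3, -1, 4], [0, 5, 2]]
b = [1, -2, 3]
x = [2, 1, -1]
dx = [1, -3, 2]

x_new = [xi + di for xi, di in zip(x, dx)]
increment = [u - v for u, v in zip(y(A, b, x_new), y(A, b, x))]
linear_part = matvec(jacobian(A, b, x), dx)
remainder = [inc - lin for inc, lin in zip(increment, linear_part)]

# The remainder is second order in dx: exactly Diag(A dx) dx = (A dx) ⊙ dx.
assert remainder == [a * d for a, d in zip(matvec(A, dx), dx)]
print(remainder)  # [-5, -42, -22]
```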

Now you can cross-check this against your answer.

--- ADDENDUM ---

Let $a$ and $b$ be vectors of the same dimension. Then it is straightforward to show that $$\operatorname{Diag}\left( a \right) b = a \odot b = b \odot a = \operatorname{Diag}\left( b \right) a .$$
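A tiny numerical illustration of this identity in plain Python; the vectors are arbitrary, and `b` here is the addendum's generic vector, not the offset from the question.

```python
def diag(v):
    """Square matrix with v on the diagonal."""
    n = len(v)
    return [[v[i] if i == j else 0 for j in range(n)] for i in range(n)]


def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def hadamard(u, v):
    """Component-wise product u ⊙ v."""
    return [ui * vi for ui, vi in zip(u, v)]


a = [2, -1, 3]
b = [4, 5, -2]

print(matvec(diag(a), b))  # [8, -5, -6]
assert matvec(diag(a), b) == hadamard(a, b) == hadamard(b, a) == matvec(diag(b), a)
```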

======== Second answer ========

Using the substitutions $u=Ax+b$ and $v=x$, the chain rule gives

$$\newcommand{\diag}{\operatorname{diag}}\begin{aligned} \frac{\partial\, \diag(Ax+b)x}{\partial x} &= \frac{\partial\, \diag(u)v}{\partial (u,v)} \circ \frac{\partial (u,v)}{\partial x} \\ &= \Bigg[\begin{pmatrix} \Delta u\\ \Delta v\end{pmatrix} \longmapsto \diag(\Delta u)v + \diag(u)\Delta v \Bigg] \circ \Bigg[\Delta x\longmapsto\begin{pmatrix} A\Delta x \\ \Delta x\end{pmatrix}\Bigg] \\ &= \Bigg[\Delta x\longmapsto\diag(A\Delta x)x + \diag(Ax+b)\Delta x\Bigg] \end{aligned}$$

This is the derivative in functional form, i.e., as a linear map. To get it in matrix (tensor) form, we need to express it as $\Delta x\longmapsto T\cdot\Delta x$, which we can do by noting that

$$\diag(A\Delta x)x = (A\Delta x)\odot x = x\odot(A\Delta x) = \diag(x)A\Delta x $$

Hence the derivative is

$$\begin{aligned} \frac{\partial\, \diag(Ax+b)x}{\partial x} &= \Big[\Delta x\longmapsto \big(\diag(x)A + \diag(Ax+b)\big)\Delta x\Big] \\ &\cong \diag(x)A + \diag(Ax+b) \end{aligned}$$
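The passage from the functional form to the matrix form can also be carried out mechanically: the matrix $T$ of a linear map $\Delta x\longmapsto T\cdot\Delta x$ has the image of the $k$-th standard basis vector as its $k$-th column. A sketch in plain Python, with hypothetical helper names and arbitrary test data:

```python
def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def hadamard(u, v):
    return [ui * vi for ui, vi in zip(u, v)]


def derivative_map(A, b, x):
    """Functional form: the linear map dx -> diag(A dx) x + diag(Ax + b) dx."""
    u = [si + bi for si, bi in zip(matvec(A, x), b)]  # u = Ax + b

    def D(dx):
        du = matvec(A, dx)  # Δu = A Δx (and Δv = Δx)
        return [p + q for p, q in zip(hadamard(du, x), hadamard(u, dx))]

    return D


def matrix_of(D, n):
    """Column k of T is D(e_k), so that D(dx) = T dx."""
    cols = [D([1 if j == k else 0 for j in range(n)]) for k in range(n)]
    return [[cols[k][i] for k in range(n)] for i in range(n)]


# Arbitrary integer test data (not from the answer).
A = [[1, 2, 0], [3, -1, 4], [0, 5, 2]]
b = [1, -2, 3]
x = [2, 1, -1]

T = matrix_of(derivative_map(A, b, x), len(x))
print(T)  # [[7, 4, 0], [3, -2, 4], [0, -5, 4]]
```

The recovered `T` matches $\diag(x)A+\diag(Ax+b)$ entry by entry for this data.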