Finding the gradient w.r.t. matrix input(s)


Suppose that $A \in \mathbb{R}^{1\times \ell}$ and $B \in \mathbb{R}^{\ell \times m}$ are fixed matrices and ${\bf z} \in \mathbb{R}^d$ is a fixed vector. Let $V \in \mathbb{R}^{m\times k}$ and $W \in \mathbb{R}^{k \times d}$, and define $f \colon \mathbb{R}^\ell \to \mathbb{R}$ and $\Phi \colon \mathbb{R}^{m \times k} \times \mathbb{R}^{k \times d} \to \mathbb{R}^\ell$ by $$f({\bf y}) = A\,\sigma_2({\bf y}) \quad \text{and} \quad \Phi(V,W) = B\,V\,\sigma_1(W\,{\bf z}),$$ respectively. Here $\sigma_2 \colon x \mapsto x^2$ is the square function applied component-wise to its argument, and $\sigma_1 \colon x \mapsto \max\{x,0\}$ denotes a component-wise application of the ReLU nonlinearity. I am stuck at the exact computation of the gradient of $$f(\Phi(V,W)) = A\,\sigma_2\left(B\,V\,\sigma_1(W\,{\bf z})\right)$$ with respect to its (matrix) arguments $V$ and $W$, i.e., I would like to know $\nabla_V f(\Phi(V,W))$ and $\nabla_W f(\Phi(V,W))$.
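For concreteness, here is a minimal NumPy sketch of $f(\Phi(V,W))$; the dimensions and random data below are purely illustrative, not part of the problem:

```python
import numpy as np

# Illustrative (hypothetical) dimensions ell, m, k, d.
ell, m, k, d = 3, 4, 5, 6
rng = np.random.default_rng(0)
A = rng.standard_normal((1, ell))  # fixed row vector A
B = rng.standard_normal((ell, m))  # fixed matrix B
z = rng.standard_normal(d)         # fixed vector z

def f(V, W):
    """f(Phi(V, W)) = A sigma2(B V sigma1(W z)), a scalar."""
    y = B @ V @ np.maximum(W @ z, 0.0)  # Phi(V, W) in R^ell
    return (A @ y**2).item()           # sigma2: component-wise square

V = rng.standard_normal((m, k))
W = rng.standard_normal((k, d))
value = f(V, W)
```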


Attempt: To derive the expression for $\nabla_V f(\Phi(V,W))$ I resort to a perturbation argument. Denoting by $a \odot b$ the Hadamard product of two matrices (or vectors) $a$ and $b$, we can write \begin{align*} f(\Phi(V+\Delta V,W)) &= A\,\sigma_2\left(B\,(V + \Delta V)\,\sigma_1(W\,{\bf z})\right) \\ &= A\,\left(B\,V\,\sigma_1(W\,{\bf z}) + B\,\Delta V\,\sigma_1(W\,{\bf z})\right)\odot \left(B\,V\,\sigma_1(W\,{\bf z}) + B\,\Delta V\,\sigma_1(W\,{\bf z})\right) \\ &= f(\Phi(V,W)) + 2\,A\,\left(B\,V\,\sigma_1(W\,{\bf z})\right)\odot \left(B\,\Delta V\,\sigma_1(W\,{\bf z})\right) + \mathcal{O}(\|\Delta V\|^2). \end{align*} However, I am not sure how to read off the expression for $\nabla_V f(\Phi(V,W))$ from the previous identity. The argument needed to find $\nabla_W f(\Phi(V,W))$ seems a bit harder...


There are 2 best solutions below


$ \def\l{\lambda} \def\g{\gamma} \def\G{\Gamma} \def\o{{\tt1}} \def\BR#1{\Big(#1\Big)} \def\LR#1{\left(#1\right)} \def\op#1{\operatorname{#1}} \def\diag#1{\op{diag}\LR{#1}} \def\Diag#1{\op{Diag}\LR{#1}} \def\r#1{\op{ReLu}\LR{#1}} \def\s#1{\op{Step}\LR{#1}} \def\trace#1{\op{Tr}\LR{#1}} \def\qiq{\quad\implies\quad} \def\qif{\quad\iff\quad} \def\p{\partial} \def\grad#1#2{\frac{\p #1}{\p #2}} \def\c#1{\color{red}{#1}} \def\CLR#1{\c{\LR{#1}}} $The Frobenius $(:)$ product is a concise notation for the trace $$\eqalign{ P:Q &= \sum_{i=1}^m\sum_{j=1}^n P_{ij}Q_{ij} \;=\; \trace{P^TQ} \\ Q:Q &= \|Q\|^2_F\qquad \{ {\rm Frobenius\:norm} \} \\ }$$ When applied to vectors $(n=\o)$ it reduces to the standard dot product.
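For instance, the identity $P:Q = \operatorname{Tr}(P^TQ)$ and the norm relation are easy to confirm numerically (a small NumPy check with arbitrary random data):

```python
import numpy as np

rng = np.random.default_rng(1)
P = rng.standard_normal((3, 4))
Q = rng.standard_normal((3, 4))

frob = np.sum(P * Q)      # elementwise-sum definition of P:Q
tr = np.trace(P.T @ Q)    # trace form Tr(P^T Q)
assert np.isclose(frob, tr)
assert np.isclose(np.sum(Q * Q), np.linalg.norm(Q, "fro")**2)
```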

The properties of the underlying trace function allow the terms in a Frobenius product to be rearranged in many useful ways, e.g. $$\eqalign{ P:Q &= Q:P \\ P:Q &= P^T:Q^T \\ S:\LR{PQ} &= \LR{SQ^T}:P &= \LR{P^TS}:Q \\ }$$ We'll also need the Heaviside Step function (aka the derivative of the ReLu function) $$\eqalign{ \s{\l} = \begin{cases} \o \quad\; {\rm if}\;\;\l>0 \\ 0 \quad\; {\rm otherwise} \\ \end{cases} \\ }$$ and the vector variables $$\eqalign{ x &= Wz &\qiq &dx = dW\:z \\ h &= \s{x} \\ r &= \r{x} &\qiq &dr = h\odot dx \\ y &= BVr &\qiq &dy = B\:dV\,r + BV\,dr \\ a &= A^T &&\{ {\rm a\:column\:vector} \} \\ f &= a:\LR{y\odot y} \\ }$$ It will also be useful to have a diagonal matrix for each vector variable (denoted by the corresponding uppercase letter) to simplify Hadamard products, e.g. $$\eqalign{ H &= \Diag{h} \qiq Hx = h\odot x \\ \\ }$$
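The rearrangement rules can likewise be spot-checked with random matrices (dimensions chosen arbitrarily so the products conform):

```python
import numpy as np

rng = np.random.default_rng(2)
S = rng.standard_normal((3, 5))
P = rng.standard_normal((3, 4))
Q = rng.standard_normal((4, 5))

def frob(X, Y):
    """Frobenius inner product X:Y."""
    return np.sum(X * Y)

# S:(PQ) = (S Q^T):P = (P^T S):Q
lhs = frob(S, P @ Q)
assert np.isclose(lhs, frob(S @ Q.T, P))
assert np.isclose(lhs, frob(P.T @ S, Q))
```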


Finally, we're ready to calculate the gradients of $f$
$$\eqalign{ df &= a:\LR{2y\odot dy} \\ &= 2Ya:dy \\ &= 2Ya:\LR{B\:dV\,r + BV\,dr} \\ &= 2B^TYar^T:dV \;+\; 2V^TB^TYa:\c{dr} \\ &= 2B^TYar^T:dV \;+\; 2V^TB^TYa:\c{H\,dW\,z} \\ &= 2B^TYar^T:dV \;+\; 2HV^TB^TYaz^T:dW \\ \\ \grad fV &= 2B^TYar^T \qquad\qquad \grad fW = 2HV^TB^TYaz^T \\ }$$
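These closed-form gradients can be verified against central finite differences. The NumPy sketch below uses arbitrary illustrative dimensions and random data (which, with probability one, stay away from the ReLU kink):

```python
import numpy as np

rng = np.random.default_rng(3)
ell, m, k, d = 3, 4, 5, 6
A = rng.standard_normal((1, ell))
B = rng.standard_normal((ell, m))
z = rng.standard_normal(d)
V = rng.standard_normal((m, k))
W = rng.standard_normal((k, d))

def f(V, W):
    y = B @ V @ np.maximum(W @ z, 0.0)
    return (A @ y**2).item()

# Closed-form gradients from the derivation above.
x = W @ z
h = (x > 0).astype(float)      # Step(x), the diagonal of H
r = np.maximum(x, 0.0)         # ReLU(x)
y = B @ V @ r
Ya = y * A.ravel()             # Y a with Y = Diag(y), a = A^T
gV = 2 * np.outer(B.T @ Ya, r)                # 2 B^T Y a r^T
gW = 2 * np.outer(h * (V.T @ (B.T @ Ya)), z)  # 2 H V^T B^T Y a z^T

# Central finite-difference check of both gradients.
eps = 1e-6
gV_num = np.zeros_like(V)
for i, j in np.ndindex(m, k):
    E = np.zeros_like(V); E[i, j] = eps
    gV_num[i, j] = (f(V + E, W) - f(V - E, W)) / (2 * eps)

gW_num = np.zeros_like(W)
for i, j in np.ndindex(k, d):
    E = np.zeros_like(W); E[i, j] = eps
    gW_num[i, j] = (f(V, W + E) - f(V, W - E)) / (2 * eps)

assert np.allclose(gV, gV_num, rtol=1e-4, atol=1e-6)
assert np.allclose(gW, gW_num, rtol=1e-4, atol=1e-6)
```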


I myself took the following approach: To obtain the expression for the adjoint operator $\nabla \Phi^*(V,W) \colon \mathbb{R}^{\ell} \to \mathbb{R}^{m \times k} \times \mathbb{R}^{k \times d}$, we let $\Delta_W \in \mathbb{R}^{k \times d}$, $\Delta_V \in \mathbb{R}^{m \times k}$, and $\Delta \in \mathbb{R}^{\ell}$. We can expand $\Phi$ as follows: \begin{equation*} \begin{aligned} \Phi(V + \Delta_V, W) &\approx \Phi(V, W) + B\,\Delta_V\,\sigma_1(W\,{\bf z}),\\ \Phi(V, W + \Delta_W) &\approx \Phi(V, W) + B\,V\,\left(\dot{\sigma}_1(W\,{\bf z}) \odot \Delta_W\,{\bf z}\right), \end{aligned} \end{equation*} where $\odot$ stands for the Hadamard (entry-wise) product. Therefore, the operator $\nabla \Phi(V,W)$ is given by $$(\Delta_V, \Delta_W) \mapsto B\,\Delta_V\,\sigma_1(W\,{\bf z}) + B\,V\,\left(\dot{\sigma}_1(W\,{\bf z}) \odot \Delta_W\,{\bf z}\right).$$ Taking inner products and rearranging, we then have \begin{equation} \begin{aligned} \langle \Delta, B\,\Delta_V\,\sigma_1(W\,{\bf z}) \rangle &= \langle \Delta_V, B^T\,\Delta\,\sigma_1({\bf z}^T\,W^T) \rangle,\\ \langle \Delta, B\,V\,\left(\dot{\sigma}_1(W\,{\bf z}) \odot \Delta_W\,{\bf z}\right) \rangle &= \langle \Delta_W, \left(\dot{\sigma}_1(W\,{\bf z}) \odot V^T\,B^T\,\Delta\right){\bf z}^T\rangle. \end{aligned} \end{equation} By the previous identity, the adjoint operator is given by \begin{equation}\label{eq:adjoint} \nabla \Phi^*(V,W) \colon \Delta \mapsto \left(B^T\,\Delta\,\sigma_1({\bf z}^T\,W^T), \left(\dot{\sigma}_1(W\,{\bf z}) \odot V^T\,B^T\,\Delta\right){\bf z}^T \right). \end{equation} Now let ${\bf y} = \Phi(V,W) = B\,V\,\sigma_1(W\,{\bf z})$; then $\nabla f({\bf y}) = \textrm{diag}(A)\,\dot{\sigma}_2({\bf y})$, where $\textrm{diag}(A) \in \mathbb{R}^{\ell \times \ell}$ is the diagonal matrix formed by the entries of $A \in \mathbb{R}^{1 \times \ell}$.
Thus, \begin{equation}\label{eq:gradient} \begin{aligned} \nabla f(\Phi(V,W)) &= \nabla \Phi^*(V,W)\{\nabla f({\bf y})\} \\ &= \left(B^T\,\textrm{diag}(A)\,\dot{\sigma}_2({\bf y})\,\sigma_1({\bf z}^T\,W^T), \left(\dot{\sigma}_1(W\,{\bf z}) \odot V^T\,B^T\,\textrm{diag}(A)\,\dot{\sigma}_2({\bf y})\right){\bf z}^T \right). \end{aligned} \end{equation} Since $\dot{\sigma}_2({\bf y}) = 2\,{\bf y}$, this agrees with the gradients obtained in the other answer.
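As a sanity check, the defining adjoint identity $\langle \Delta, \nabla\Phi(V,W)[\Delta_V,\Delta_W]\rangle = \langle \nabla\Phi^*(V,W)\Delta, (\Delta_V,\Delta_W)\rangle$ can be verified numerically (a NumPy sketch with illustrative dimensions and random perturbations):

```python
import numpy as np

rng = np.random.default_rng(4)
ell, m, k, d = 3, 4, 5, 6
B = rng.standard_normal((ell, m))
z = rng.standard_normal(d)
V = rng.standard_normal((m, k))
W = rng.standard_normal((k, d))
Delta = rng.standard_normal(ell)
DV = rng.standard_normal((m, k))
DW = rng.standard_normal((k, d))

x = W @ z
s1 = np.maximum(x, 0.0)        # sigma_1(W z)
ds1 = (x > 0).astype(float)    # sigma_1-dot(W z)

# Forward map: nabla Phi(V, W)[DV, DW]
fwd = B @ DV @ s1 + B @ V @ (ds1 * (DW @ z))

# Adjoint: nabla Phi*(V, W) applied to Delta
adjV = np.outer(B.T @ Delta, s1)                 # B^T Delta sigma_1(z^T W^T)
adjW = np.outer(ds1 * (V.T @ (B.T @ Delta)), z)  # (sigma_1-dot o V^T B^T Delta) z^T

lhs = Delta @ fwd
rhs = np.sum(adjV * DV) + np.sum(adjW * DW)
assert np.isclose(lhs, rhs)
```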