Differentiating the attention block of a transformer


I have a function that I want to differentiate.

Let $x,y \in \mathbb{R}^{n \times m}$ and $A,B \in \mathbb{R}^{k \times n}$, let $\lambda \in \mathbb{R}$, and let $\operatorname{softmax}(\cdot)$ be the usual softmax function, applied to each row of its matrix argument, so

$$ \operatorname{softmax}(a)_{i,j} = \frac{\exp(a_{i,j})}{\sum\limits_{l=1}^{m} \exp(a_{i,l})}$$

and $C \in \mathbb{R}^{m \times n}$ is a matrix (the shape required for the product below to be defined). The function I want to differentiate is

$$f(x,y) = \operatorname{softmax} \left( \lambda \cdot x^T \cdot A^T \cdot B \cdot y \right) \cdot C \cdot x$$
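For concreteness, the forward pass can be written out directly. This is a minimal NumPy sketch, assuming $C\in\mathbb{R}^{m\times n}$; the helper names `row_softmax` and `attention_block` are hypothetical, not from any library. Subtracting each row's maximum before exponentiating leaves the softmax unchanged but avoids overflow.

```python
import numpy as np

def row_softmax(G):
    """Softmax applied to each row of G; the row max is subtracted for stability."""
    Z = np.exp(G - G.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def attention_block(x, y, A, B, C, lam):
    """f(x, y) = softmax(lam * x^T A^T B y) C x, softmax taken over rows.

    Shapes (C's shape is an assumption): x, y: (n, m); A, B: (k, n); C: (m, n).
    """
    G = lam * x.T @ A.T @ B @ y   # (m, m) score matrix
    F = row_softmax(G)            # every row of F sums to 1
    return F @ C @ x              # (m, m) @ (m, n) @ (n, m) -> (m, m)

rng = np.random.default_rng(0)
n, m, k = 4, 3, 5
x, y = rng.standard_normal((n, m)), rng.standard_normal((n, m))
A, B = rng.standard_normal((k, n)), rng.standard_normal((k, n))
C = rng.standard_normal((m, n))
f = attention_block(x, y, A, B, C, lam=0.5)
print(f.shape)  # (3, 3)
```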

I am looking for $\frac{\partial f}{\partial x_i}$ and $\frac{\partial f}{\partial y_j}$, where $x_i$ and $y_j$ are the $i$-th and $j$-th columns of $x$ and $y$, respectively. I haven't studied derivatives of matrix-valued functions like this one; can someone help me? Thanks!

Best answer

$ \def\R#1{{\mathbb R}^{#1}} \def\l{\lambda} \def\s{\sigma} \def\e{\varepsilon} \def\o{{\tt1}} \def\d{\delta} \def\I{{\cal I}} \def\H{{\large\cal H}} \def\G{{\cal G}} \def\BR#1{\left[#1\right]} \def\LR#1{\left(#1\right)} \def\op#1{\operatorname{#1}} \def\vc#1{\op{vec}\LR{#1}} \def\diag#1{\op{diag}\LR{#1}} \def\Diag#1{\op{Diag}\LR{#1}} \def\trace#1{\op{Tr}\LR{#1}} \def\frob#1{\left\| #1 \right\|_F} \def\qiq{\quad\implies\quad} \def\qif{\quad\iff\quad} \def\p{\partial} \def\grad#1#2{\frac{\p #1}{\p #2}} \def\c#1{\color{red}{#1}} \def\CLR#1{\c{\LR{#1}}} \def\fracLR#1#2{\LR{\frac{#1}{#2}}} \def\gradLR#1#2{\LR{\grad{#1}{#2}}} \def\m#1{\left[\begin{array}{r}#1\end{array}\right]} $
Let's use uppercase letters for matrices, lowercase letters for vectors, and Greek letters for scalars. The only effect on the problem statement is to rename $\{x,y\}\to\{X,Y\}$.

Let $\{e_k\in\R{m}\}$ denote the standard basis vectors, which can be used to access the individual columns of a matrix, e.g.
$$\eqalign{ G = \m{g_1&g_2&\cdots&g_m} \qiq g_k=Ge_k }$$
Similarly, the $k$-th row can be accessed (as a column vector) as $\:G^Te_k$.

So the matrix can be expanded either by columns or by rows
$$\eqalign{ G \;=\; \sum_{k=1}^m\LR{Ge_k}e_k^T \;=\; \sum_{j=1}^m\:e_j\LR{G^Te_j}^T \\ }$$
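The basis-vector bookkeeping can be checked numerically. A small NumPy sketch with an arbitrary random square $G$; note that NumPy indexes from 0, so `E[:, 2]` plays the role of $e_3$.

```python
import numpy as np

rng = np.random.default_rng(1)
m = 4
G = rng.standard_normal((m, m))
E = np.eye(m)   # column j of E is the basis vector e_j (0-based)

# G e_k selects column k; G^T e_k selects row k as a column vector
col_ok = np.allclose(G @ E[:, 2], G[:, 2])
row_ok = np.allclose(G.T @ E[:, 2], G[2, :])

# Expansion by columns: G = sum_k (G e_k) e_k^T
G_by_cols = sum(np.outer(G @ E[:, j], E[:, j]) for j in range(m))
# Expansion by rows: G = sum_j e_j (G^T e_j)^T
G_by_rows = sum(np.outer(E[:, j], G.T @ E[:, j]) for j in range(m))

print(col_ok, row_ok, np.allclose(G_by_cols, G), np.allclose(G_by_rows, G))
# True True True True
```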

Applying a vector function $f$ to each row of $G$ (as the row-wise softmax does) can be expressed as
$$\eqalign{ F = \sum_{j=1}^m\:e_j\:f\LR{G^Te_j}^T \\ }$$
The differential of the softmax function is well known
$$\eqalign{ s &= \op{softmax}(w) &\;\;\;\qiq S = \Diag s \\ ds &= \LR{S-ss^T}\,dw \\ }$$
Apply this to each row; since $S_j-s_js_j^T$ is symmetric, the transpose can be moved past it
$$\eqalign{ s_j &= \op{softmax}\LR{G^Te_j}\qiq S_j=\Diag{s_j} \\ ds_j &= \LR{S_j-s_js_j^T}\LR{dG^Te_j} \\ \\ dF &= \sum_{j=1}^m\:e_j\,\BR{\LR{S_j-s_js_j^T}dG^Te_j}^T \\ &= \sum_{j=1}^m\:e_je_j^T\: dG \LR{S_j-s_js_j^T} \\ }$$
First, recall the componentwise self-gradient of a matrix
(where $\:\c{\e_\ell}$ are the basis vectors for $\R{n}\,$)
$$\eqalign{ \grad{Y}{Y_{\ell k}} &= E_{\ell k} \;\doteq\; \c{\e_\ell} e_k^T \\ }$$
Then try the simpler case, i.e. the gradient with respect to $Y$, which enters only through $G$
$$\eqalign{ G &= \l X^TA^TBY \;\in\R{m\times m} \\ dG &= \l X^TA^TB\:dY \\ }$$
and substitute into the objective function
$$\eqalign{ P &= FCX \;\in\R{m\times m} \\ dP&= dF\LR{CX} \\ &= \sum_{j=1}^m\:e_je_j^T\: \c{dG} \LR{S_j-s_js_j^T}CX \\ &= \sum_{j=1}^m\:e_je_j^T\CLR{\l X^TA^TB\:dY}\LR{S_j-s_js_j^T}CX \\ \grad{P}{Y_{\ell k}} &= \sum_{j=1}^m\:e_je_j^T\LR{\l X^TA^TB\:{E_{\ell k}}}\LR{S_j-s_js_j^T}CX \\ \grad{P_{ip}}{Y_{\ell k}} &= \sum_{j=1}^m\:\c{e_i^Te_j}e_j^T\LR{\l X^TA^TB\:{E_{\ell k}}}\LR{S_j-s_js_j^T}CXe_p \\ &= \sum_{j=1}^m\:\c{\d_{ij}}e_j^T\LR{\l X^TA^TB\:{E_{\ell k}}}\LR{S_j-s_js_j^T}CXe_p \\ &= e_i^T\LR{\l X^TA^TB\:{E_{\ell k}}}\LR{S_i-s_is_i^T}CXe_p \\ &= \l \LR{x_i^TA^Tb_\ell}\;e_k^T\LR{S_i-s_is_i^T}Cx_p \\ &= \l \LR{b_\ell^TAx_i}\;e_k^T\LR{S_i-s_is_i^T}Cx_p \\ }$$
where $\{b_\ell,x_p\}$ are the columns of $\{B,X\}$ respectively, and $p$ indexes the columns of $P$.
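Both the softmax differential and the component formula for $\partial P/\partial Y_{\ell k}$ can be checked against central finite differences. This is a minimal NumPy sketch on random matrices of assumed sizes; the helper names `row_softmax`, `P_of` and `dP_dY` are hypothetical, and all indices are 0-based. If the formulas are right, both reported errors should be at the level of finite-difference noise.

```python
import numpy as np

def row_softmax(G):
    Z = np.exp(G - G.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def P_of(X, Y, A, B, C, lam):
    return row_softmax(lam * X.T @ A.T @ B @ Y) @ C @ X

def dP_dY(X, Y, A, B, C, lam, i, p, l, k):
    """Closed form: lam * (b_l^T A x_i) * e_k^T (S_i - s_i s_i^T) C x_p."""
    s_i = row_softmax(lam * X.T @ A.T @ B @ Y)[i]   # row i of F
    J_i = np.diag(s_i) - np.outer(s_i, s_i)         # S_i - s_i s_i^T
    return lam * (B[:, l] @ A @ X[:, i]) * (J_i[k] @ C @ X[:, p])

rng = np.random.default_rng(0)
n, m, kdim, lam, h = 4, 3, 5, 0.7, 1e-6
X, Y = rng.standard_normal((n, m)), rng.standard_normal((n, m))
A, B = rng.standard_normal((kdim, n)), rng.standard_normal((kdim, n))
C = rng.standard_normal((m, n))

# 1) softmax Jacobian: ds = (S - s s^T) dw; column j of J_num is ds/dw_j
w = rng.standard_normal(m)
s = row_softmax(w[None, :])[0]
J = np.diag(s) - np.outer(s, s)
J_num = np.column_stack([
    (row_softmax((w + h * e)[None, :])[0] - row_softmax((w - h * e)[None, :])[0]) / (2 * h)
    for e in np.eye(m)
])
jac_err = np.abs(J - J_num).max()

# 2) dP_{ip}/dY_{lk} for every (i, p, l, k)
grad_err = 0.0
for l in range(n):
    for k in range(m):
        D = np.zeros_like(Y)
        D[l, k] = h                                   # perturb only Y[l, k]
        num = (P_of(X, Y + D, A, B, C, lam) - P_of(X, Y - D, A, B, C, lam)) / (2 * h)
        for i in range(m):
            for p in range(m):
                grad_err = max(grad_err, abs(num[i, p] - dP_dY(X, Y, A, B, C, lam, i, p, l, k)))

print(jac_err, grad_err)   # both should be tiny (finite-difference noise)
```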

Now try the gradient with respect to $X$. It differs from the previous calculation in two ways. First, $X$ enters $G$ through its transpose
$$\eqalign{ dG = \l\,dX^T\,A^TBY \\ }$$
with the self-gradient
$$\eqalign{ \grad{X^T}{X_{\ell k}} &= E_{k\ell} \;\doteq\; e_k\,\e_\ell^T \\ }$$
Second, $X$ also appears as the right-hand factor of $P$, so the differential picks up another term
$$\eqalign{ dP &= dF\,\LR{CX} \;+\; \LR{FC}\,dX \\ }$$
Proceeding as before, and using $\grad{X}{X_{\ell k}}=E_{\ell k}$ in the second term, leads to
$$\eqalign{ \grad{P_{ip}}{X_{\ell k}} &= e_i^T\LR{\l\,e_k\e_\ell^TA^TBY}\LR{S_i-s_is_i^T}CXe_p \;+\; e_i^TFC\LR{\e_\ell e_k^T}e_p \\ &= \l \BR{a_\ell^TBY\LR{S_k-s_ks_k^T}Cx_p} \d_{ik} \;+\; \LR{e_i^TFc_\ell} \d_{pk} \\ }$$
where $\{a_\ell,c_\ell\}$ are the columns of $\{A,C\}$, and $p$ indexes the columns of $P$.
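The same finite-difference check applies to the $X$-gradient, where both terms matter: the first is nonzero only when the row index of $P$ equals $k$, and the second only when the column index of $P$ equals $k$. Again this is a NumPy sketch on random matrices of assumed sizes, with hypothetical helper names and 0-based indices, not a definitive implementation.

```python
import numpy as np

def row_softmax(G):
    Z = np.exp(G - G.max(axis=1, keepdims=True))
    return Z / Z.sum(axis=1, keepdims=True)

def P_of(X, Y, A, B, C, lam):
    return row_softmax(lam * X.T @ A.T @ B @ Y) @ C @ X

def dP_dX(X, Y, A, B, C, lam, i, p, l, k):
    """Closed form: lam [a_l^T B Y (S_k - s_k s_k^T) C x_p] delta_ik + (e_i^T F c_l) delta_pk."""
    F = row_softmax(lam * X.T @ A.T @ B @ Y)
    J_k = np.diag(F[k]) - np.outer(F[k], F[k])                 # S_k - s_k s_k^T
    first = lam * (A[:, l] @ B @ Y @ J_k @ C @ X[:, p]) if i == k else 0.0
    second = F[i] @ C[:, l] if p == k else 0.0                  # e_i^T F c_l
    return first + second

rng = np.random.default_rng(0)
n, m, kdim, lam, h = 4, 3, 5, 0.7, 1e-6
X, Y = rng.standard_normal((n, m)), rng.standard_normal((n, m))
A, B = rng.standard_normal((kdim, n)), rng.standard_normal((kdim, n))
C = rng.standard_normal((m, n))

max_err = 0.0
for l in range(n):
    for k in range(m):
        D = np.zeros_like(X)
        D[l, k] = h                                   # perturb only X[l, k]
        num = (P_of(X + D, Y, A, B, C, lam) - P_of(X - D, Y, A, B, C, lam)) / (2 * h)
        for i in range(m):
            for p in range(m):
                max_err = max(max_err, abs(num[i, p] - dP_dX(X, Y, A, B, C, lam, i, p, l, k)))

print(max_err)   # should be tiny (finite-difference noise)
```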