Consider the self-attention layer of a transformer with a single attention head, which performs the computation $\textbf{Z} = \text{Softmax}(\textbf{QK}^\top)\textbf{V}$, where $\textbf{Q} \in \mathbb{R}^{L \times D}$, $\textbf{K} \in \mathbb{R}^{L \times D}$, and $\textbf{V} \in \mathbb{R}^{L \times D}$, for a sequence length $L$ and hidden size $D$. We will show that, by suitably approximating the softmax transformation, the computational complexity above can be reduced. Consider the Maclaurin series expansion of the exponential function,
\begin{equation} e^t = \sum_{n=0}^{\infty} \frac{t^n}{n!} = 1 + t + \frac{t^2}{2!} + \frac{t^3}{3!} + \ldots \end{equation}
Using only the first three terms in this series, we obtain $\exp(t) \approx 1 + t + \frac{t^2}{2}$. Use this approximation to obtain a feature map $\phi: \mathbb{R}^D \rightarrow \mathbb{R}^M$ such that, for arbitrary $\textbf{q} \in \mathbb{R}^D$ and $\textbf{k} \in \mathbb{R}^D$, we have $\exp(\textbf{q}^\top \textbf{k}) \approx \phi(\textbf{q})^\top \phi(\textbf{k})$.
I already obtained the following feature map:
\begin{equation} \phi(\mathbf{x}) = \left(1, x_1, \ldots, x_D, x_1^2/\sqrt{2}, \ldots, x_D^2/\sqrt{2}, x_1x_2,x_1x_3, \ldots,x_2x_3,x_2x_4,\ldots, x_{D-1}x_D\right). \end{equation}
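As a quick numerical sanity check (a sketch, not part of the derivation), the feature map above can be implemented and compared against the truncated Taylor series; `phi` here is one possible implementation of the map just derived:

```python
import numpy as np

def phi(x):
    # Feature map: (1, x_i, x_i^2/sqrt(2), x_i x_j for i < j)
    D = len(x)
    cross = [x[i] * x[j] for i in range(D) for j in range(i + 1, D)]
    return np.concatenate(([1.0], x, x**2 / np.sqrt(2), cross))

rng = np.random.default_rng(0)
q = rng.normal(size=4)
k = rng.normal(size=4)
t = q @ k

# phi(q)^T phi(k) should reproduce 1 + t + t^2/2 exactly (up to float error)
assert np.isclose(phi(q) @ phi(k), 1 + t + t**2 / 2)
```

The identity holds exactly (not just approximately), since expanding $(\textbf{q}^\top\textbf{k})^2/2$ produces precisely the squared and cross terms that the feature map encodes.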
Now I am asked to do the following:
Using the approximation $\exp(\textbf{q}^{\top}k) \approx \phi(\textbf{q})^\top\phi(\textbf{k})$, and denoting by $\Phi(\textbf{Q}) \in \mathbb{R}^{L \times M}$ and $\Phi(\textbf{K}) \in \mathbb{R}^{L \times M}$ the matrices whose rows are (respectively) $\phi(\textbf{q}_i)$ and $\phi(\textbf{k}_i)$, where $\textbf{q}_i$ and $\textbf{k}_i$ denote (also respectively) the $i$-th rows of the original matrices $\textbf{Q}$ and $\textbf{K}$, show that the self-attention operation can be approximated as $\textbf{Z} \approx \textbf{D}^{-1} \Phi(\textbf{Q}) \Phi(\textbf{K})^{\top} \textbf{V}$, where $\textbf{D} = \textbf{Diag}(\Phi(\textbf{Q}) \Phi(\textbf{K}) \mathbf{\textbf{1}}_L)$ (here, $\textbf{Diag}(\textbf{v})$ denotes a matrix with the entries of vector $\textbf{v}$ in the diagonal, and $\mathbf{\textbf{1}}_L$ denotes a vector of ones with size $L$).
How can I approach this problem in order to solve it analytically? I started by doing the matrix multiplications one by one in $\textbf{D}^{-1} \Phi(\textbf{Q}) \Phi(\textbf{K})^{\top} \textbf{V}$, but the math becomes really difficult, so I don't know if that's the correct approach.
We begin by noting that the input to the softmax is the matrix
\begin{equation} \textbf{Q}\textbf{K}^T = \begin{bmatrix} q_1^Tk_1 & \dots & q_1^Tk_L \\ \vdots & \ddots & \vdots \\ q_L^Tk_1 & \dots & q_L^Tk_L \end{bmatrix} \label{eq:q_k} \end{equation}
Since the softmax is applied row-wise and is given, for $\textbf{z} \in \mathbb{R}^L$, by the expression:
\begin{equation} \text{Softmax}(\textbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^L e^{z_j}} \label{softmax_expression} \end{equation}
we can infer that, for each row $i$ and column $j$:
\begin{equation} \big[\text{Softmax}(\textbf{Q}\textbf{K}^T)\big]_{ij} = \frac{\exp(q_i^Tk_j)}{\sum_{m=1}^L\exp(q_i^Tk_m)} \end{equation}
By using the approximation $\exp(\textbf{q}^T\textbf{k}) \approx \phi(\textbf{q})^T\phi(\textbf{k})$, we get that:
\begin{equation} \frac{\exp(q_i^Tk_j)}{\sum_{m=1}^L\exp(q_i^Tk_m)} \approx \frac{\Phi(q_i)^T\Phi(k_j)}{\sum_{m=1}^L\Phi(q_i)^T\Phi(k_m)} \label{eq:approximation} \end{equation}
Now, using the matrices $\Phi(\textbf{Q})$ and $\Phi(\textbf{K})$ defined above, the numerators of the approximation for all pairs $(i, j)$ are collected in the matrix
\begin{equation} \Phi(\textbf{Q})\Phi(\textbf{K})^T \end{equation}
whose $(i, j)$ entry is exactly $\phi(q_i)^T\phi(k_j)$.
For the denominator of the approximation, notice that the normalizer of row $i$ is $\sum_{m=1}^L\phi(q_i)^T\phi(k_m)$, the same for every column $j$, so dividing each row by its normalizer amounts to left-multiplying by the inverse of the diagonal matrix:
\begin{equation} \textbf{D} = \text{Diag}\Bigg( \sum_{m=1}^L\phi(q_1)^T\phi(k_m), \; \dots, \; \sum_{m=1}^L\phi(q_L)^T\phi(k_m)\Bigg) \end{equation}
Using the provided matrices $\Phi(\textbf{Q})$ and $\Phi(\textbf{K})$, we can simplify this expression to:
\begin{equation} \textbf{D} = \text{Diag}\Big(\Phi(\textbf{Q})\Phi(\textbf{K})^T \textbf{1}_L\Big) \end{equation}
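This simplification can be checked numerically (a sketch, reusing the `phi` implementation of the feature map from the first part): the row sums $\sum_{m=1}^L \phi(q_i)^T\phi(k_m)$, computed directly, coincide with the entries of $\Phi(\textbf{Q})\Phi(\textbf{K})^T \textbf{1}_L$.

```python
import numpy as np

def phi(x):
    # Feature map from the first part: (1, x_i, x_i^2/sqrt(2), x_i x_j for i < j)
    D = len(x)
    cross = [x[i] * x[j] for i in range(D) for j in range(i + 1, D)]
    return np.concatenate(([1.0], x, x**2 / np.sqrt(2), cross))

L, D = 5, 3
rng = np.random.default_rng(1)
Q = rng.normal(size=(L, D))
K = rng.normal(size=(L, D))
PhiQ = np.stack([phi(q) for q in Q])  # rows are phi(q_i)
PhiK = np.stack([phi(k) for k in K])  # rows are phi(k_i)

# Diagonal entries computed directly as sums ...
row_sums = np.array([sum(phi(Q[i]) @ phi(K[m]) for m in range(L)) for i in range(L)])
# ... and via the matrix expression Phi(Q) Phi(K)^T 1_L
d = PhiQ @ PhiK.T @ np.ones(L)
assert np.allclose(row_sums, d)
```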
Therefore, applying $\textbf{D}^{-1}$ to normalize the rows and multiplying by $\textbf{V}$, we can conclude that:
\begin{equation} \textbf{Z} \approx \textbf{D}^{-1}\Phi(\textbf{Q})\Phi(\textbf{K})^T\textbf{V} \end{equation}
Note that grouping the product as $\Phi(\textbf{Q})\big(\Phi(\textbf{K})^T\textbf{V}\big)$ avoids ever forming the $L \times L$ attention matrix, so the cost becomes linear in $L$ (roughly $O(LMD)$) rather than quadratic, which is the promised complexity reduction.
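Finally, the whole approximation can be verified end to end. The sketch below (with `phi` again implementing the derived feature map) compares exact softmax attention against the linearized form $\textbf{D}^{-1}\Phi(\textbf{Q})\Phi(\textbf{K})^T\textbf{V}$; small-norm inputs keep the truncated Taylor series accurate:

```python
import numpy as np

def phi(x):
    # Feature map: (1, x_i, x_i^2/sqrt(2), x_i x_j for i < j)
    D = len(x)
    cross = [x[i] * x[j] for i in range(D) for j in range(i + 1, D)]
    return np.concatenate(([1.0], x, x**2 / np.sqrt(2), cross))

def softmax_attention(Q, K, V):
    # Exact: Z = Softmax(Q K^T) V, softmax applied row-wise
    S = np.exp(Q @ K.T)
    return (S / S.sum(axis=1, keepdims=True)) @ V

def linear_attention(Q, K, V):
    # Approximate: Z = D^{-1} Phi(Q) Phi(K)^T V
    PhiQ = np.stack([phi(q) for q in Q])
    PhiK = np.stack([phi(k) for k in K])
    A = PhiQ @ PhiK.T
    return (A / A.sum(axis=1, keepdims=True)) @ V  # D^{-1} as row normalization

L, D = 6, 4
rng = np.random.default_rng(2)
# Small entries so that q^T k is close to 0, where the 3-term Taylor series is accurate
Q, K, V = (rng.normal(size=(L, D)) * 0.1 for _ in range(3))
assert np.allclose(softmax_attention(Q, K, V), linear_attention(Q, K, V), atol=1e-4)
```

For inputs with larger norms the quadratic truncation degrades, which is why practical linear-attention schemes use feature maps with stronger guarantees; here the check only illustrates the algebra of the derivation.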