Optimizing the computational complexity of linear-algebraic operations


This question is from Stanford CS221 Autumn 2019-2020, Foundations, Problem 2.d

Consider the scalar-valued function $f(\mathbf{w})$ defined as follows:

$$ f(\mathbf{w})=\sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w}−\mathbf{b}_{j}^\intercal \mathbf{w})^2+\lambda \lVert \mathbf{w} \rVert_2^2 $$

where $\mathbf{w} \in \mathbb{R}^d$ is a column vector, the constants $\mathbf{a}_i, \mathbf{b}_j \in \mathbb{R}^d$ are also column vectors, and $\lambda \in \mathbb{R}$.

Devise a strategy that first does preprocessing in $O(nd^2)$ time, and then, for any given vector $\mathbf{w}$, takes only $O(d^2)$ time to compute $f(\mathbf{w})$. [Hint: Refactor the algebraic expression; this is a classic trick used in machine learning. Again, you may find it helpful to work out the scalar case first.]

My Solution: I do not know what the "classic trick" is, but I did the following:

$$ \begin{align*} \sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w}−\mathbf{b}_{j}^\intercal \mathbf{w})^2 & = \sum_{i=1}^{n}\sum_{j=1}^{n}\biggl[(\mathbf{a}_{i}^\intercal \mathbf{w})^2 + (\mathbf{b}_{j}^\intercal \mathbf{w})^2 - 2 (\mathbf{a}_{i}^\intercal \mathbf{w})(\mathbf{b}_{j}^\intercal \mathbf{w})\biggr] \\ & =\sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w})^2 + \sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{b}_{j}^\intercal \mathbf{w})^2 - 2\sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w})(\mathbf{b}_{j}^\intercal \mathbf{w})\\ &=n\sum_{i=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w})^2 + n\sum_{j=1}^{n}(\mathbf{b}_{j}^\intercal \mathbf{w})^2 - 2\sum_{i=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w})\sum_{j=1}^{n}(\mathbf{b}_{j}^\intercal \mathbf{w}) & \text{eq.1}\\ \end{align*} $$

In eq. 1 above:

  • Computing $\mathbf{a}_{i}^\intercal \mathbf{w}$ takes $O(d)$ time for each $i=1,2,\dots,n$; therefore, computing the first term takes $O(nd)$ time.
  • Computing $\mathbf{b}_{j}^\intercal \mathbf{w}$ takes $O(d)$ time for each $j=1,2,\dots,n$; therefore, computing the second term takes $O(nd)$ time.
  • By the same reasoning, the third term also takes $O(nd)$ time: compute each summation individually and then take the product of the results.
  • Expressing $\lVert \mathbf{w} \rVert_2^2$ as $\mathbf{w}^\intercal\mathbf{w}$, we can compute the squared L2-norm in $O(d)$ time.

So the total running time is $O(nd)$! Am I missing something? And how do I solve the problem as stated?
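For concreteness, here is a minimal NumPy sketch (function names are my own) of the refactoring in eq. 1, checked against the brute-force double sum; each evaluation costs $O(nd)$:

```python
import numpy as np

def f_naive(A, B, w, lam):
    # Brute-force double sum: O(n^2 d) per evaluation.
    n = A.shape[0]
    total = sum((A[i] @ w - B[j] @ w) ** 2
                for i in range(n) for j in range(n))
    return total + lam * (w @ w)

def f_refactored(A, B, w, lam):
    # eq. 1: n*sum_i (a_i^T w)^2 + n*sum_j (b_j^T w)^2
    #        - 2*(sum_i a_i^T w)*(sum_j b_j^T w) + lam*||w||^2.
    # A @ w and B @ w each cost O(nd), so this is O(nd) per evaluation.
    n = A.shape[0]
    p, q = A @ w, B @ w
    return n * (p @ p) + n * (q @ q) - 2.0 * p.sum() * q.sum() + lam * (w @ w)

rng = np.random.default_rng(0)
n, d = 50, 4
A = rng.standard_normal((n, d))  # rows are a_i^T
B = rng.standard_normal((n, d))  # rows are b_j^T
w = rng.standard_normal(d)
assert np.isclose(f_naive(A, B, w, 1.5), f_refactored(A, B, w, 1.5))
```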



Best Answer:

In your solution, $O(nd)$ is the price you pay for every new vector $\mathbf{w}$: if you want to evaluate $f$ on $m \gg n$ inputs, your cost scales as $O(mnd)$ instead of the $O(md^2)$ asked for. (I am assuming $d \ll n$ here; otherwise your solution is indeed faster.)

To obtain what is expected, rewrite $$\begin{align*} \sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i}^\intercal \mathbf{w}−\mathbf{b}_{j}^\intercal \mathbf{w})^2 &= \sum_{i=1}^{n}\sum_{j=1}^{n} \sum_{k=1}^d\sum_{\ell=1}^d(\mathbf{a}_{i,k}−\mathbf{b}_{j,k})\mathbf{w}_k(\mathbf{a}_{i,\ell}−\mathbf{b}_{j,\ell})\mathbf{w}_\ell \\ &= \sum_{k=1}^d\sum_{\ell=1}^d \underbrace{\left( \sum_{i=1}^{n}\sum_{j=1}^{n}(\mathbf{a}_{i,k}−\mathbf{b}_{j,k})(\mathbf{a}_{i,\ell}−\mathbf{b}_{j,\ell}) \right)}_{\gamma_{k,\ell}}\mathbf{w}_k\mathbf{w}_\ell \end{align*}$$ Now, using the same separation idea as in your eq. 1, show that each $\gamma_{k,\ell}$ (which does not depend on $\mathbf{w}$!) can be computed ahead of time in $O(n)$. Thus, computing all $d^2$ of them takes $O(nd^2)$ preprocessing time.

Once you have that, given a new $\mathbf{w}$ you only have to compute $$ f(\mathbf{w}) = \sum_{k=1}^d\sum_{\ell=1}^d \gamma_{k,\ell} \mathbf{w}_k\mathbf{w}_\ell + \lambda\sum_{k=1}^d \mathbf{w}_k^2 $$ which takes $O(d^2+d)=O(d^2)$ time.
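Putting the answer together, here is a hedged NumPy sketch (function names are my own). Expanding $(\mathbf{a}_{i,k}−\mathbf{b}_{j,k})(\mathbf{a}_{i,\ell}−\mathbf{b}_{j,\ell})$ and separating the $i$ and $j$ sums gives $\Gamma = n\,A^\intercal A + n\,B^\intercal B - s_A s_B^\intercal - s_B s_A^\intercal$, where $A, B$ stack the $\mathbf{a}_i^\intercal, \mathbf{b}_j^\intercal$ as rows and $s_A = \sum_i \mathbf{a}_i$, $s_B = \sum_j \mathbf{b}_j$; each query is then the quadratic form $\mathbf{w}^\intercal \Gamma \mathbf{w} + \lambda \lVert \mathbf{w} \rVert_2^2$:

```python
import numpy as np

def preprocess(A, B):
    # Gamma[k, l] = sum_{i,j} (a_{i,k} - b_{j,k}) * (a_{i,l} - b_{j,l}).
    # Expanding the product separates the i- and j-sums, so the O(n^2 d^2)
    # quadruple sum collapses to a few O(n d^2) matrix products:
    #   Gamma = n A^T A + n B^T B - sA sB^T - sB sA^T
    n = A.shape[0]
    sA, sB = A.sum(axis=0), B.sum(axis=0)
    return (n * (A.T @ A) + n * (B.T @ B)
            - np.outer(sA, sB) - np.outer(sB, sA))

def f_fast(Gamma, w, lam):
    # O(d^2) per query: a quadratic form plus the regularizer.
    return w @ Gamma @ w + lam * (w @ w)
```

After the one-time `preprocess(A, B)` call, every new $\mathbf{w}$ costs only the $d \times d$ quadratic form, regardless of $n$.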