What does 'overfitting' (machine learning field) mean in linear algebra context?


We know that overfitting the training data leads to memorizing the training set, which results in poor performance on the test samples. But what does overfitting mean in the context of linear algebra, given that the ML objective can be phrased as solving $Ax = b$?


TL;DR: Overfitting corresponds to constructing your design matrix $A$ so that it perfectly interpolates your data.

Here is some intuition. Consider the following simple regression problem:

You have a set of $N$ 2D points $\{(x_i, y_i)\}_{i=1}^N$ and you would like to find a function that interpolates these points. We can start with a linear function:

$$ \begin{bmatrix}1&x_1\\\vdots&\vdots\\1&x_N\end{bmatrix} \begin{bmatrix}\theta_1\\\theta_2\end{bmatrix} = \begin{bmatrix}y_1\\\vdots\\y_N\end{bmatrix}$$ $$X\theta = y \quad\text{or}\quad Ax=b$$

We call $X$ (or $A$) the design matrix and $\theta$ (or $x$) the coefficients, or free parameters. Now if $N = 2$, we have two points, which can be fitted perfectly by a line. Two points, two unknowns: it all comes together. We can also try a polynomial of order 2:
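As a quick sketch of the $N = 2$ case (the point values here are made up for illustration), the square system $X\theta = y$ can be solved directly:

```python
import numpy as np

# Two points: (0, 1) and (2, 5) -- illustrative values, not from the answer.
x = np.array([0.0, 2.0])
y = np.array([1.0, 5.0])

# Design matrix: a column of ones (intercept) next to the x-values.
X = np.column_stack([np.ones_like(x), x])

# N = 2 points, 2 unknowns: the square system X @ theta = y has a unique solution.
theta = np.linalg.solve(X, y)
print(theta)  # [intercept, slope] of the interpolating line
```

Here `np.linalg.solve` succeeds because the design matrix is square and invertible; the resulting line passes exactly through both points.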

$$ \begin{bmatrix}1&x_1&x_1^2\\\vdots&\vdots&\vdots\\1&x_N&x_N^2\end{bmatrix} \begin{bmatrix}\theta_1\\\theta_2\\\theta_3\end{bmatrix} = \begin{bmatrix}y_1\\\vdots\\y_N\end{bmatrix}$$

Observe that we now need at least 3 points $(x_i, y_i)$ to determine a function that fits them perfectly. In general we have to distinguish 3 cases, where $p$ is the order of our polynomial:

  • N < p+1: underdetermined system of equations; there are more unknowns than equations, so there is no unique solution (infinitely many polynomials pass through the points). Basically, not enough points to pin down the coefficients.
  • N = p+1: (given our matrix $X$ is invertible) a uniquely solvable system; we can fit our points perfectly, and the function will pass through every point. For $p = 2$: $\theta_1 + \theta_2 x_j + \theta_3 x_j^2 = y_j$.
  • N > p+1: overdetermined system of equations, also known as linear regression. We have more points than we would need to fit our data perfectly. This can be solved by the normal equations: $\theta = (X^TX)^{-1}X^Ty$, a formula you might be familiar with if you have some background in machine learning. Our function will now pass through some points, but certainly not through all. For the other points the fitted value, which we denote $\hat{y}$, will differ slightly from $y$.
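The overdetermined case can be sketched as follows (the noisy data here is made up for illustration, and `np.vander` is used to build the polynomial design matrix):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 10)              # N = 10 points
y = 2 * x + 1 + rng.normal(0, 0.1, 10)     # noisy line (illustrative data)

p = 1  # fit a line: p + 1 = 2 parameters, so N > p + 1 -> overdetermined
X = np.vander(x, N=p + 1, increasing=True)  # columns [1, x]

# Normal equations: theta = (X^T X)^{-1} X^T y
# (solve the system instead of forming the inverse explicitly)
theta = np.linalg.solve(X.T @ X, X.T @ y)

y_hat = X @ theta  # fitted values; in general y_hat != y
```

In practice one would call `np.linalg.lstsq(X, y)` rather than forming the normal equations, since it is numerically more stable, but the formula above matches the answer's derivation.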

To address your initial question: as you wrote, overfitting refers to memorizing the training set. Suppose we have $N$ training data points and we evaluate our model (the function parametrized by the coefficients $\theta$ we found earlier) with the loss function $\mathcal{L}(\hat{y},y) = \frac{1}{N}\sum_{i=1}^N\|\hat{y}_i - y_i\|^2_2$, also known as the Mean Squared Error. We can achieve a train loss of 0 if we choose the degree of the polynomial to be $p = N-1$. However, if you get a new set of points $\{(\tilde{x}_i, \tilde{y}_i)\}$ (the test set), it is highly unlikely that your function will pass through these points as well, and your test loss $\mathcal{L}(\hat{\tilde{y}},\tilde{y})$ will be much higher than your train loss. We overparametrized our model; this is called overfitting.
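This effect can be demonstrated numerically. The sketch below (with a made-up underlying function and noise level) interpolates $N$ noisy training points with a degree-$(N-1)$ polynomial and compares train and test loss:

```python
import numpy as np

rng = np.random.default_rng(1)
f = lambda x: np.sin(2 * np.pi * x)        # true underlying function (illustrative)

N = 8
x_train = np.linspace(0.0, 1.0, N)
y_train = f(x_train) + rng.normal(0, 0.1, N)

# Degree N-1 polynomial: square Vandermonde system, interpolates the
# noisy training points exactly.
X_train = np.vander(x_train, N=N, increasing=True)
theta = np.linalg.solve(X_train, y_train)

mse = lambda y_hat, y: np.mean((y_hat - y) ** 2)
train_loss = mse(X_train @ theta, y_train)   # essentially 0: exact interpolation

# Fresh test points from the same distribution.
x_test = rng.uniform(0.0, 1.0, 100)
y_test = f(x_test) + rng.normal(0, 0.1, 100)
X_test = np.vander(x_test, N=N, increasing=True)
test_loss = mse(X_test @ theta, y_test)      # much larger than train_loss
```

The polynomial threads every noisy training point, so the train loss is (numerically) zero, while between those points it oscillates away from the true function, driving the test loss up.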

For more explanation, have a look at Bias-variance tradeoff & Understanding the Bias-Variance Tradeoff.