Is there a theoretical NEED for iterative optimization of an ML model instead of direct solution calculation if batch size= dataset size?


I stumbled upon a doubt regarding machine learning basics. For any set of parameters $\theta_1,\theta_2,\dots,\theta_n$ that solve an optimization problem by minimizing a loss function $J$, the ultimate goal is reaching a global minimum of $J$. Setting the partial derivative of $J$ with respect to each parameter $\theta_i$ to zero, $\partial J/\partial \theta_i = 0$, solving for $\theta_1,\theta_2,\dots,\theta_n$, and then applying the second derivative test would do the job and give us what we want, right? So, from a theoretical point of view, do we NEED to apply iterative optimization techniques?

I picked up the toy example of linear regression. Say I want to fit the best line $y=mx+b$ for a dataset with $n$ pairs $(x, y)$, and my loss function is the squared error, $J(y',y)=(y'-y)^2$, where $y'$ is our predicted value and $y$ is the real value. (This choice of loss is not the best one, but that's irrelevant to the current discussion.)

I begin optimizing the model:

$$\partial J/ \partial m =\sum_{i=1}^B 2(y_i-mx_i-b)(-x_i)$$ $$\partial J/ \partial b =\sum_{i=1}^B 2(y_i-mx_i-b)(-1)$$ where $B$ is the batch size.

Now, theoretically, what we aim for is $\partial J / \partial m=0$ and $\partial J / \partial b=0$. And I understand the need to optimize the parameters iteratively using $$m_{i+1}=m_{i}-\alpha\, \partial J / \partial m$$ $$b_{i+1}=b_{i}-\alpha\, \partial J / \partial b$$ because when we use batching, different batches will suggest different improvements, so the updates must be applied iteratively. The changes are also kept small so that we don't keep swinging between the optima of different batches, and so that we avoid never converging at all (which can happen if the updates are applied at too great a rate).
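The iterative update above can be sketched in a few lines. This is a minimal full-batch gradient-descent example on made-up data (the true line $y = 2x + 1$, the learning rate, and the step count are all my own illustrative choices):

```python
import numpy as np

# Hypothetical toy data: points near the line y = 2x + 1
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 50)
y = 2 * x + 1 + rng.normal(0, 0.05, 50)

# Full-batch gradient descent on J = sum_i (y_i - m*x_i - b)^2
m, b, alpha = 0.0, 0.0, 0.01
for _ in range(5000):
    r = y - m * x - b           # residuals
    dJ_dm = -2 * np.sum(r * x)  # dJ/dm from the formula above
    dJ_db = -2 * np.sum(r)      # dJ/db from the formula above
    m -= alpha * dJ_dm
    b -= alpha * dJ_db

print(m, b)  # approaches m ≈ 2, b ≈ 1
```

Note that with too large an $\alpha$ this loop diverges, which is exactly the "too great a rate" failure mode mentioned above.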

But if I set $B=n$, i.e. batch size = dataset size, then setting $\partial J / \partial m=0$ and $\partial J / \partial b=0$ seems sufficient and eliminates the need for iterative improvement, right?

If so, then solving the following system of equations for our particular linear regression problem should lead us to the solution. (Strictly speaking this is not sufficient in general, since we would need the second derivative test to ensure we are at a minimum, but I'll skip that here since it isn't needed for linear regression and is beside the point.)

$$\partial J / \partial m=0 \qquad \partial J / \partial b=0$$

  1. $$\sum_{i=1}^n 2(y_i-mx_i-b)(-x_i) =0 \qquad \sum_{i=1}^n 2(y_i-mx_i-b)(-1) =0$$
  2. $$2m\sum_{i=1}^n x_i^2 +2b\sum_{i=1}^nx_{i}-2\sum_{i=1}^n y_ix_i =0 \qquad 2bn+2m \sum_{i=1}^nx_{i}-2\sum_{i=1}^ny_{i}=0$$
  3. $$m\sum_{i=1}^n x_i^2 +b\sum_{i=1}^nx_{i}=\sum_{i=1}^n y_ix_i \qquad bn+m \sum_{i=1}^nx_{i}=\sum_{i=1}^ny_{i}$$
  4. $$m=\frac {\sum_{i=1}^ny_i \sum_{i=1}^nx_i-n\sum_{i=1}^nx_iy_i}{(\sum_{i=1}^nx_i)^2-n\sum_{i=1}^nx_i^2} \qquad b=\frac{\sum_{i=1}^nx_iy_i\sum_{i=1}^nx_i-\sum_{i=1}^nx_i^2\sum_{i=1}^ny_i}{(\sum_{i=1}^nx_i)^2-n\sum_{i=1}^nx_i^2} $$

And the maths checks out ("calculated line" is derived from the equations above, "learned line" from an iteratively trained linear regression model; the two coincide on my data).
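The closed-form formulas from step 4 can be checked numerically. A minimal sketch on made-up data (the true line and noise level are my own choices), cross-checked against NumPy's built-in least-squares fit:

```python
import numpy as np

# Hypothetical toy data: points near the line y = 3x - 2
rng = np.random.default_rng(1)
n = 100
x = rng.uniform(0, 5, n)
y = 3 * x - 2 + rng.normal(0, 0.1, n)

# The sums appearing in the closed-form solution
Sx, Sy = x.sum(), y.sum()
Sxx, Sxy = (x * x).sum(), (x * y).sum()

# Closed-form m and b from the normal equations (step 4 above)
m = (Sy * Sx - n * Sxy) / (Sx**2 - n * Sxx)
b = (Sxy * Sx - Sxx * Sy) / (Sx**2 - n * Sxx)

# Cross-check against NumPy's degree-1 least-squares fit
m_np, b_np = np.polyfit(x, y, 1)
print(m, b, m_np, b_np)  # the two pairs agree
```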

So then I gave it a lot of thought to understand why the above approach is not applied, and I could only come up with the following reasons:

  1. For any given model with parameters $\theta_1,\theta_2,\dots,\theta_n$, we would have to solve the following system of equations for the chosen loss function $J$: $$\partial J/ \partial \theta_1 =0$$ $$\partial J/ \partial \theta_2 =0$$ $$...$$ $$\partial J/ \partial \theta_n =0$$ and then also apply the second derivative test to distinguish the minima from the maxima and saddle points.

    And for each model we would have to derive the solution by hand every time, which would be extremely difficult and would make all of machine learning far less automatable. So instead the chain rule is used: rather than working out the full solution analytically, the corrections for each parameter are "localized" — every parameter learns on the basis of its input, the function it applies to that input, the output it produces, and the error backpropagated into that output.

    The chain rule makes the maths itself decomposable into pieces: because of it, we can just plug together and run mathematical building blocks (convolutions, linear layers, etc.).
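The "localization" idea in point 1 can be made concrete with a tiny hand-rolled backward pass. This sketch uses one made-up training pair and the composition $J = (wx + b - y)^2$; each backward line only needs its stage's local derivative and the gradient passed back from the stage after it:

```python
# Forward: z = w * x;  y_hat = z + b;  J = (y_hat - y)**2
x, y = 2.0, 7.0   # one hypothetical training pair
w, b = 1.0, 0.0   # parameters

# Forward pass
z = w * x
y_hat = z + b
J = (y_hat - y) ** 2

# Backward pass: each line is a purely local derivative
dJ_dyhat = 2 * (y_hat - y)  # loss stage: dJ/d(y_hat)
dJ_db = dJ_dyhat * 1.0      # bias add: d(y_hat)/db = 1
dJ_dz = dJ_dyhat * 1.0      # bias add passes the gradient through
dJ_dw = dJ_dz * x           # multiply stage: dz/dw = x

print(dJ_dw, dJ_db)  # prints -20.0 -10.0
```

Differentiating $J(w,b)=(wx+b-y)^2$ directly gives $\partial J/\partial w = 2(wx+b-y)x = -20$ and $\partial J/\partial b = -10$ for these values, matching the stage-by-stage computation; no stage ever needed the global formula for $J$.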

  2. It would very often be prohibitively expensive to set batch size = dataset size.

So I guess it's pretty much automation and ease that motivate optimization methods like gradient descent, Gauss–Newton, etc.

So, ultimately, my question is: is it just automatability and practicality that lead us to use iterative optimization?

If there is any other fundamental reason why we don't simply set the first-order derivatives to zero and solve directly, beyond what I mentioned above, please do let me know.

Thanks a ton!