Relationship between model variance and dataset size for SGD


I'm looking for some function that describes the variance of model weights when trained with Stochastic Gradient Descent for $m$ independent minibatches.

I can apply the central limit theorem to a single minibatch of SGD training to conclude that the gradient estimate (and therefore the updated model) has variance $O\left(\frac{1}{n}\right)$, where $n$ is the minibatch size. However, I am having trouble characterizing the variance of the model weights produced by multiple sequential minibatches, since that depends on how the model changes between minibatches.
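To make the single-minibatch claim concrete, here is a minimal simulation sketch (the synthetic least-squares setup and all parameter values are my own assumptions, not from any particular reference): it estimates the variance of the minibatch gradient for two batch sizes and checks that quadrupling $n$ roughly quarters the variance.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: least-squares loss on synthetic data,
# gradient evaluated at a fixed weight w.
N = 100_000                      # population size
X = rng.normal(size=N)
y = 2.0 * X + rng.normal(size=N)
w = 0.0                          # fixed weight at which the gradient is taken

def minibatch_grad_var(n, trials=2000):
    """Empirical variance of the minibatch gradient estimate for batch size n."""
    grads = np.empty(trials)
    for t in range(trials):
        idx = rng.integers(0, N, size=n)
        # gradient of 0.5*(w*x - y)^2 w.r.t. w, averaged over the minibatch
        grads[t] = np.mean((w * X[idx] - y[idx]) * X[idx])
    return grads.var()

v16, v64 = minibatch_grad_var(16), minibatch_grad_var(64)
# O(1/n) scaling: the ratio should come out close to 64/16 = 4.
print(v16 / v64)
```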

Additionally, if we assume the model doesn't change at all between minibatches (which is obviously false), then the variance would be $O\left(\frac{m}{n}\right)$, which is not useful: after many minibatches the model converges (at least for many classes of objective) rather than accumulating variance. I'm looking for a proof of something like "the variance of the model weights from SGD is $O\left(\frac{1}{mn}\right)$" (or some other function that decreases in both $m$ and $n$).
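As a sanity check that an $O\left(\frac{1}{mn}\right)$ rate is at least plausible, here is a toy sketch under assumptions I've chosen myself: SGD on the 1-D strongly convex quadratic $f(w) = \frac{1}{2}\mathbb{E}[(w - z)^2]$ with $z \sim \mathcal{N}(0,1)$ and a Robbins-Monro step size $\eta_t = \frac{1}{t+1}$. With this particular schedule the recursion makes $w_m$ exactly the average of the $m$ minibatch means, so $\mathrm{Var}(w_m) = \frac{\sigma^2}{mn}$ holds exactly in this special case.

```python
import numpy as np

rng = np.random.default_rng(1)

def sgd_final_weight(m, n):
    """m minibatch SGD steps of size n on f(w) = 0.5*E[(w - z)^2], z ~ N(0,1).

    Minimizer is w* = 0.  With eta_t = 1/(t+1) the update
    w_{t+1} = (t/(t+1)) w_t + mean(z_t)/(t+1) makes w_m the running
    average of the minibatch means, so Var(w_m) = 1/(m*n) exactly.
    """
    w = 1.0
    for t in range(m):
        z = rng.normal(size=n)      # fresh minibatch
        grad = np.mean(w - z)       # stochastic gradient of f at w
        w -= grad / (t + 1)
    return w

def weight_var(m, n, runs=500):
    """Monte Carlo estimate of the variance of the final iterate."""
    return np.var([sgd_final_weight(m, n) for _ in range(runs)])

v_small, v_large = weight_var(50, 8), weight_var(200, 8)
# Quadrupling m should roughly quarter the variance, matching O(1/(m*n)).
print(v_small, v_large)
```

Of course this toy case is far simpler than general strongly convex objectives, but it illustrates the kind of decreasing-in-$m$-and-$n$ bound the question is after.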

I figure a result like this must exist for strongly convex functions (and perhaps even under weaker assumptions), but I haven't been able to find anything so far. Does such a proof exist?