How does one rigorously prove that gradient descent indeed decreases the function in question locally i.e. show $f(x^{(t+1)}) \leq f(x^{(t)})$?


How does one prove that gradient descent indeed decreases the function in question locally? In other words, if we take a step in the direction of the negative gradient, as in:

$$ x^{(t+1)} = x^{(t)} - \eta \nabla f( x^{(t)} )$$

how does one show rigorously that (for a correctly chosen step size, or sufficiently small step size) we have:

$$ f(x^{(t+1)}) \leq f( x^{(t)} )$$

I am not sure what type of assumptions one needs for such a mathematical statement to hold but it seems obvious to me that it should be true for (locally) convex functions.

I don't know how to make this rigorous, but here is what I have tried (a hand-wavy but roughly correct attempt with the mathematics I remember):

Recall the Taylor expansion of a (single-variable) function:

$$ f(x + \epsilon) = f(x) + \epsilon f'(x) + \frac{ \epsilon^2 }{2} f''(x)+ \frac{ \epsilon^3 }{3!} f'''(x) + \cdots $$

$$ f(x + \epsilon) = \left[ f(x) + \epsilon f'(x) \right] + \left[ \frac{ \epsilon^2 }{2} f''(x)+ \frac{ \epsilon^3 }{3!} f'''(x) + \cdots \right] = Linear(x,\epsilon)+Err(x,\epsilon)$$

where $ Linear(x,\epsilon) = f(x) + \epsilon f'(x) $ is the linear approximation of $f$ for small $\epsilon$, and $Err(x,\epsilon) = \frac{ \epsilon^2 }{2} f''(x) + \frac{ \epsilon^3 }{3!} f'''(x) + \cdots$ is the error term, which satisfies $ Err(x,\epsilon) = O( \epsilon^2 f''(x) )$ for sufficiently small $\epsilon$ (the $\epsilon^2$ term dominates).
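As a quick numerical sanity check of that $O(\epsilon^2)$ claim (just a sketch; the test function $e^x$ and the base point are arbitrary choices of mine):

```python
import math

# Arbitrary test function f(x) = exp(x); note f'(x) = exp(x) as well.
f = math.exp
x = 1.0

def linear_error(eps):
    """Error of the linear approximation f(x) + eps * f'(x)."""
    return abs(f(x + eps) - (f(x) + eps * f(x)))

# If Err = O(eps^2), halving eps should roughly quarter the error.
e1 = linear_error(1e-3)
e2 = linear_error(5e-4)
print(e1 / e2)  # close to 4
```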

Now substitute $\epsilon = - \eta f'(x)$. In that case we have:

$$ f(x - \eta f'(x)) = f(x) - \eta f'(x)^2 + \frac{ \eta^2 {f'(x)}^2 }{2} f''(x) + \cdots $$

Intuitively, we want to choose $\eta$ so that the negative linear term $-\eta f'(x)^2$ is large enough in magnitude to dominate $Err(x,\epsilon)$, so that the step actually decreases the function. If $ \eta f'(x)^2 $ exceeds the leading term of the error, then $f$ should decrease (since the linear part, which is the largest, decreases). So we require:

$$ - \eta f'(x)^2 + \frac{ \eta^2 {f'(x)}^2 }{2} f''(x) \leq 0 \iff \eta f'(x)^2 \geq \frac{ \eta^2 {f'(x)}^2 }{2} f''(x)$$

If we restrict ourselves to $\eta > 0 $ (so that the step actually follows the negative of the gradient) and assume $f'(x) \neq 0$, we can divide both sides by $\eta f'(x)^2$:

$$1 \geq \frac{ \eta }{2} f''(x)$$

and, assuming additionally that $f''(x) > 0$ (the local convexity suspected above), this is equivalent to

$$ \frac{2}{f''(x)} \geq \eta > 0$$
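For a quadratic this bound is exact, so it can be checked numerically. A minimal sketch (the function $f(x) = x^2$, the starting point, and the helper names are my own choices); here $f''(x) = 2$, so the bound says $0 < \eta \leq 1$ guarantees descent:

```python
def f(x):
    return x ** 2      # f''(x) = 2, so the bound reads 0 < eta <= 2/2 = 1

def grad(x):
    return 2 * x

def descends(eta, x0=3.0, steps=20):
    """Return True if every gradient step decreases (or keeps) f."""
    x = x0
    for _ in range(steps):
        x_next = x - eta * grad(x)
        if f(x_next) > f(x):
            return False
        x = x_next
    return True

print(descends(0.9))   # eta within the bound: monotone decrease
print(descends(1.1))   # eta above the bound: f increases
```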

I think this choice of step size ensures the function decreases: $f(x^{(t+1)}) \leq f(x^{(t)})$.

Now ignore the error terms that are too small (powers of $\epsilon$ greater than $2$):

$$ f(x^{(t+1)}) = f(x - \eta f'(x)) \approx f(x) - \eta f'(x)^2 + \frac{ \eta^2 {f'(x)}^2 }{2} f''(x) $$

where $\eta$ was chosen so that $ - \eta f'(x)^2 + \frac{ \eta^2 {f'(x)}^2 }{2} f''(x) \leq 0$ holds. Therefore we must have:

$$ f(x^{(t+1)}) \approx f(x) - \eta f'(x)^2 + \frac{ \eta^2 {f'(x)}^2 }{2} f''(x) = f(x) + ( \text{term } \leq 0 ) $$

Since this term satisfies $\frac{ \eta^2 {f'(x)}^2 }{2} f''(x) - \eta f'(x)^2 \leq 0 $, the function value cannot increase at any gradient descent step, i.e. $f(x^{(t+1)}) \leq f(x^{(t)})$.

I know the argument I presented is not a real proof, but I was wondering if someone who knows more rigorous analysis or optimization could help me polish off the rough edges of this nearly-correct proof.

[Note that the proof is actually rigorous in the case where $f(x)$ has first and second derivatives and all derivatives of higher order are zero. An example of such a function is a quadratic.]
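To make this bracketed note concrete: for a generic quadratic $f(x) = a x^2 + b x + c$ with $a > 0$, all derivatives of order three and higher vanish, so the expansion above is an exact identity:

\begin{align*} f(x - \eta f'(x)) &= f(x) - \eta f'(x)^2 + \frac{ \eta^2 f'(x)^2 }{2} f''(x) \\ &= f(x) - \eta f'(x)^2 \left( 1 - \eta a \right) \end{align*}

which is strictly smaller than $f(x)$ whenever $f'(x) \neq 0$ and $0 < \eta < 1/a = 2/f''(x)$.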

[Feel free to generalize my attempt to multivariable functions]
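For the multivariable generalization, the standard assumption is that $\nabla f$ is Lipschitz continuous with constant $L$ ($L$-smoothness), which plays the role of the bound on $f''$ above. The well-known descent lemma then states

\begin{align*} f(y) \leq f(x) + \nabla f(x)^\top (y - x) + \frac{L}{2} \lVert y - x \rVert^2 \end{align*}

and substituting $y = x^{(t+1)} = x^{(t)} - \eta \nabla f(x^{(t)})$ gives

\begin{align*} f(x^{(t+1)}) \leq f(x^{(t)}) - \eta \left( 1 - \frac{L \eta}{2} \right) \lVert \nabla f(x^{(t)}) \rVert^2 \end{align*}

so any step size $0 < \eta \leq 2/L$ guarantees $f(x^{(t+1)}) \leq f(x^{(t)})$, the exact analogue of $\eta \leq 2/f''(x)$.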


There are 2 best solutions below


An extensive proof can be found here, although some tiny details are missing; hence I would like to put up this link as well, which contains the entire proof that I have posted on my GitHub. Since the answer is far too long, and copying it here would be unnecessary duplication, I decided to attach the links instead.


Consider any scalar-valued function $f:X\subseteq \mathbf{R}^2 \to \mathbf{R}$. You could think of $f$ as a temperature or pressure function. The partial derivative $\frac{\partial f}{\partial x} (a,b)$ is the slope at the point $(a,b,f(a,b))$ of the curve obtained as the intersection of the surface $z=f(x,y)$ with the plane $y=b$. The other partial derivative $\frac{\partial f}{\partial y} (a,b)$ has a similar geometric interpretation. If $P(a,b,f(a,b))$ is an arbitrary point on this surface, you could draw an infinite number of curves passing through it, whose slopes we may choose to measure.

As you perhaps know, the directional derivative of a function is defined as the rate of change of the function value $f(x,y)$ along the direction $\mathbf{d}$:

\begin{align*} \frac{\partial f}{\partial \mathbf{d}} = \lim_{\alpha \to 0} \frac{f(\mathbf{a}+\alpha\mathbf{d})-f(\mathbf{a})}{\alpha} \end{align*}

where $\mathbf{a} = (a,b)$.

But, the coolest thing is, if you know the gradient of a function and a direction $\mathbf{d}$, you can easily compute the directional derivative.

Define the single variable function

$\phi(\alpha) = f(\mathbf{a} + \alpha \mathbf{d})$

The definition of the directional derivative becomes:

\begin{align*} \frac{\partial f}{\partial \mathbf{d}} &= \lim_{\alpha \to 0} \frac{f(\mathbf{a}+\alpha\mathbf{d})-f(\mathbf{a})}{\alpha}\\ &= \lim_{\alpha \to 0} \frac{\phi(\alpha) - \phi(0)}{\alpha} = \phi'(0) \end{align*}

Thus,

\begin{align*} \frac{d\phi(\alpha)}{d\alpha} &= Df(\mathbf{a} + \alpha\mathbf{d})\cdot \mathbf{d} \quad \{ \text{By chain rule} \} \\ &= \nabla f(\mathbf{a}+\alpha \mathbf{d}) \cdot \mathbf{d} \end{align*}

So,

\begin{align*} \phi'(0) = \frac{\partial f}{\partial \mathbf{d}} = \nabla f(\mathbf{a}) \cdot \mathbf{d} \end{align*}

Consequently, if $\mathbf{d}$ is chosen opposite to the gradient vector, i.e. $\mathbf{d} = -\nabla f(\mathbf{a})$, the directional derivative is as negative as it can be for that magnitude of $\mathbf{d}$: it equals $\nabla f(\mathbf{a}) \cdot (-\nabla f(\mathbf{a})) = - \lVert \nabla f(\mathbf{a}) \rVert ^2$.

Thus, for a sufficiently small step along that direction, gradient descent indeed decreases the value of the function.
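This can be checked numerically with a small sketch (the test function $f(x,y) = x^2 + 3y^2$, the base point, and all names are my own choices): the finite-difference directional derivative along $\mathbf{d} = -\nabla f(\mathbf{a})$ should come out close to $- \lVert \nabla f(\mathbf{a}) \rVert^2$.

```python
# Arbitrary test function f(x, y) = x^2 + 3*y^2 with gradient (2x, 6y).
def f(x, y):
    return x ** 2 + 3 * y ** 2

def grad(x, y):
    return (2 * x, 6 * y)

a = (1.0, 2.0)                 # arbitrary base point
g = grad(*a)                   # gradient at a: (2.0, 12.0)
d = (-g[0], -g[1])             # steepest-descent direction d = -grad f(a)

# Finite-difference directional derivative (f(a + alpha*d) - f(a)) / alpha.
alpha = 1e-6
dd = (f(a[0] + alpha * d[0], a[1] + alpha * d[1]) - f(*a)) / alpha

# dd should be close to -||grad f(a)||^2 = -(4 + 144) = -148.
print(dd)
```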