Gradient descent overshoot - why does it diverge?

Submitted by: @import:stackexchange-cs

Problem

I'm thinking about gradient descent, but I don't get it.

I understand that it can overshoot the minimum when the learning rate is too large. But I can't understand why it would diverge.

Let's say we have

$$J(\theta_0, \theta_1) = \frac{1}{2m}\sum_{i=1}^m (h_\theta(x^i)-y^i)^2$$

$$\theta_1 := \theta_1-\alpha\frac{\partial}{\partial\theta_1}J(\theta_0, \theta_1)$$

When the slope is negative the cost will converge to the minimum from the left of the graph, as $\theta_1$ will increase.

When the slope is positive it will converge from the right of the graph, as $\theta_1$ will decrease.

Now it might overshoot when the learning rate $\alpha$ is too large.

In that case it should overshoot again.

But not by much, so shouldn't it just circle around the minimum? Why would it diverge?

Solution

I think it's useful to understand first why gradient descent converges. We assume that our function has an $L$-Lipschitz gradient:

$$\|\nabla f(x) - \nabla f(y)\| \le L \|x - y\|,$$

or, in the 1-dimensional case:

$$|f'(x) - f'(y)| \le L |x-y|$$

The main tool in non-convex optimization is the Descent Lemma:
$$f(y) \le f(x) + \langle \nabla f(x), y - x \rangle + \frac L2 \|y - x\|^2$$
or, in the 1-dimensional case:
$$f(y) \le f(x) + f'(x) (y - x) + \frac L2 (y - x)^2$$
(The Descent Lemma makes intuitive sense in light of the Taylor expansion, since $L$ is an upper bound on $|f''(x)|$ by the Mean Value Theorem applied to $f'$.)
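
As a quick numerical sanity check of the 1-dimensional Descent Lemma, here is a minimal Python sketch. The test function $f(x) = \sin x$ (so $f'(x) = \cos x$ and $L = 1$, because $|f''(x)| = |\sin x| \le 1$), the helper name `descent_lemma_gap`, and the sampling range are illustrative assumptions, not part of the original answer.

```python
import math
import random


def descent_lemma_gap(f, df, L, x, y):
    """Upper bound from the Descent Lemma minus f(y); non-negative if the lemma holds."""
    upper = f(x) + df(x) * (y - x) + 0.5 * L * (y - x) ** 2
    return upper - f(y)


random.seed(0)  # fixed seed so the sampled pairs are reproducible
pairs = [(random.uniform(-10, 10), random.uniform(-10, 10)) for _ in range(1000)]

# Assumed example: f = sin, f' = cos, L = 1.
gaps = [descent_lemma_gap(math.sin, math.cos, 1.0, x, y) for x, y in pairs]
min_gap = min(gaps)

# A small negative tolerance absorbs floating-point rounding.
print("lemma holds on all sampled pairs:", min_gap >= -1e-12)
```

A check like this only samples points, so it can catch a wrong $L$ but cannot prove the inequality.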

Now, let's make a gradient descent step: $y \gets x - \gamma \nabla f(x)$. Substituting this into the Descent Lemma, we have:

\begin{align}
f(y)
&\le f(x) + \langle \nabla f(x), -\gamma \nabla f(x) \rangle + \frac L2 \|\gamma \nabla f(x)\|^2 \\
&= f(x) - \gamma \|\nabla f(x)\|^2 + \frac {L \gamma^2}2 \|\nabla f(x)\|^2 \\
&= f(x) - \gamma(1 - \frac {L \gamma} 2) \|\nabla f(x)\|^2
\end{align}

Therefore, when $0 < \gamma \le \frac 2L$, we have $f(y) \le f(x)$, and the decrease is strict when $\gamma < \frac 2L$ and $\nabla f(x) \ne 0$. Something like $\gamma = \frac 1L$ suffices for convergence of canonical gradient descent.
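
One way to see why $\gamma = \frac 1L$ is a natural choice (a short worked step added for illustration, using only the bound derived above): the guaranteed decrease per step is $\gamma(1 - \frac{L\gamma}{2}) \|\nabla f(x)\|^2$, and the coefficient is a concave quadratic in $\gamma$. Setting its derivative $1 - L\gamma$ to zero gives $\gamma = \frac 1L$, and substituting back:

$$f(y) \le f(x) - \frac 1L \left(1 - \frac 12\right) \|\nabla f(x)\|^2 = f(x) - \frac{1}{2L} \|\nabla f(x)\|^2$$

So with this step size, every iteration is guaranteed to lower $f$ by at least $\frac{1}{2L} \|\nabla f(x)\|^2$.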

But what if $\gamma > \frac 2L$? The thing is, there are cases when the Descent Lemma is tight: for example, when $f$ is quadratic. Let $f(x) = x^2$. It has $f'(x) = 2x$ and $L=2$ (since $|2x - 2y| \le 2 |x-y|$). Then, by selecting $\gamma > \frac 2L = 1$, we have:

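$$(x - 2 \gamma x)^2 = (2 \gamma - 1)^2 x^2 > (2 - 1)^2 x^2 = x^2$$

for any $x \ne 0$. In other words, every step multiplies the distance to the minimum by $|2\gamma - 1| > 1$, so the overshoots keep growing instead of settling around the minimum, i.e. gradient descent diverges.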
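
The three regimes for $f(x) = x^2$ can be simulated directly. This is a minimal sketch: the function name `gradient_descent_quadratic`, the starting point $x_0 = 1$, the step count, and the three sample values of $\gamma$ are assumptions chosen for illustration.

```python
def gradient_descent_quadratic(x0, gamma, steps):
    """Iterate x <- x - gamma * f'(x) for f(x) = x**2, where f'(x) = 2 * x."""
    xs = [x0]
    for _ in range(steps):
        x = xs[-1]
        xs.append(x - gamma * 2.0 * x)
    return xs


# Assumed sample step sizes: below, at, and above the threshold 2/L = 1.
converging = gradient_descent_quadratic(1.0, 0.4, 20)   # factor 1 - 2*0.4 = 0.2
oscillating = gradient_descent_quadratic(1.0, 1.0, 20)  # factor 1 - 2*1.0 = -1
diverging = gradient_descent_quadratic(1.0, 1.1, 20)    # factor 1 - 2*1.1 = -1.2

for name, xs in [("gamma=0.4", converging),
                 ("gamma=1.0", oscillating),
                 ("gamma=1.1", diverging)]:
    print(f"{name}: first iterates {xs[:4]}, |x_20| = {abs(xs[-1]):.3e}")
```

With $\gamma = 1$ the iterate flips between $x_0$ and $-x_0$ forever, which is the "circle around the minimum" behaviour the question expects. That happens only at exactly this boundary value. For any larger $\gamma$, each overshoot lands farther away than the previous one, and the larger gradient there makes the next step bigger still.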

Context

StackExchange Computer Science Q#54541, answer score: 4
