Can gradient descent be used to find value of exponent?


I'm experimenting with machine learning and I'm trying to develop a model that'll find the exponent that the input will need to be raised to in order to result in the output. For example, if input=$[0, 1, 2, 3]$ and output=$[0, 1, 8, 27]$ then the exponent is $3$.

The loss function I'm using is $L(g)=(k^g-k^3)^2$ where $g$ is the model's current guess. I found the derivative of this function to be $L'(g)=2(k^g-k^3)\cdot k^g \cdot \ln(k)$

The guess is then improved by subtracting its derivative multiplied by the learning rate, i.e., $g_{n+1}=g_n-r\cdot L'(g_n)$ for each $k$ in the training data, repeated for some number of training cycles.
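In code, the update loop described above might look like the following sketch (the function name `train` and its default values are my own choices, not from the question):

```python
import math

def train(ks, target=3, lr=1e-4, cycles=3000, g0=2.0):
    """Gradient descent on L(g) = (k^g - k^target)^2, looping over the data."""
    g = g0
    for _ in range(cycles):
        for k in ks:
            if k <= 1:          # ln(k) is zero or undefined here, so skip
                continue
            # L'(g) = 2 (k^g - k^target) * k^g * ln(k)
            grad = 2 * (k**g - k**target) * (k**g) * math.log(k)
            g -= lr * grad
    return g
```

Starting below the optimum (e.g. $g_0 = 2$) this converges for small $k$; the trouble described below appears once the guess lands above the optimum.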

The problem I found is that even when $g$ is close to $3$, the derivative of the loss function is too extreme and ends up missing the zero as seen in the picture:

The above picture is the graph of $r\cdot L'(g)$ where $r=0.0001$. It seems like for any $g$ even considerably greater than $3$, the gradient blows up and ends up shooting the next guess way too far left. I've already given up on the idea of a constant learning rate. I tried basing the learning rate on the loss function, so that the smaller the error, the smaller the learning rate, and the lower the chance of missing the zero. However, that did not work at all, and I'm wondering whether gradient descent can be used to solve this problem at all. Thank you in advance.
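The overshoot is easy to reproduce numerically. A minimal sketch, using $k=3$ from the training data and a guess a couple of units above the optimum (the starting value $g=5$ is just an illustrative choice):

```python
import math

k, r, target = 3, 1e-4, 3

def grad(g):
    # L'(g) = 2 (k^g - k^3) * k^g * ln(k)
    return 2 * (k**g - k**target) * (k**g) * math.log(k)

g = 5.0                    # a couple of units right of the optimum g = 3
g_next = g - r * grad(g)
print(g_next)              # the next guess lands far to the LEFT of 3
```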



I'm a bit confused by this question, because it seems like you presuppose that you know the exponent in advance. I would have thought you'd want to be given a batch of data and try to discover which exponent best fits the data. Nevertheless, the comments indicate you've been working in a different direction, though you haven't posted any updates yet, so I won't look into that.

Instead, I want to consider why your loss function is so unstable. The first thing to try is stability theory, which is usually studied in the context of non-linear dynamical systems. More precisely, we can use linear stability analysis to see how the dynamics behave around the fixed point.

Assume we evaluate the loss for a single value $k$ (presumably you are using batches in your implementation, though). Let $$L(g_t) = \frac{1}{2}(k^{g_t} - k^\alpha)^2$$ be the loss function. The gradient descent update ($g_t = g_{t-1} - \eta\, \partial_g L(g_{t-1})$) can be viewed as a forward Euler discretization of the differential equation $$ \partial_t g_t = - \partial_g L(g_t) = v(g_t), $$ where I'm writing $g_t = g(t)$ for the value of $g$ over time. You can see this by setting $\eta=\Delta t$ and noting that $ (g_t - g_{t-1})/\Delta t \rightarrow \partial_t g_t $ as $\Delta t \rightarrow 0$. Note that $v$ is the "velocity" of $g$. Then \begin{align*} v(g_t) &= -\partial_g L(g_t) = -k^{g_t} \ln(k) (k^{g_t} - k^\alpha) \\ v'(g_t) &= -\partial_{gg} L(g_t) = -k^{g_t} \ln(k)^2 (2k^{g_t} - k^\alpha) =: J(g_t). \end{align*} Now, since $v(\alpha)=0$, when $g_t\approx \alpha$ we get $v(g_t)\approx J(\alpha)[g_t - \alpha]$, meaning that if $y_t = g_t - \alpha$, then $\partial_t y_t = J(\alpha) y_t$, whose solution is $y_t = y_0\, e^{J(\alpha) t}$. In other words, the system is stable (i.e., $y_t \rightarrow 0 $, so $g_t\rightarrow \alpha$) if $J(\alpha) < 0$. But $ J(\alpha) = -k^{2\alpha} \ln(k)^2 < 0 $, so indeed the system is stable near the optimal point $g_t=\alpha$.
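A quick numerical sanity check of this analysis (the constants $k=4$, $\alpha=3$, $\eta=10^{-4}$ are just example values):

```python
import math

k, alpha, eta = 4.0, 3.0, 1e-4

def grad(g):
    # dL/dg for L(g) = 0.5 * (k^g - k^alpha)^2
    return (k**g) * math.log(k) * (k**g - k**alpha)

def J(g):
    # derivative of the velocity v(g) = -dL/dg
    return -(k**g) * math.log(k)**2 * (2 * k**g - k**alpha)

print(J(alpha) < 0)                 # True: the fixed point is linearly stable

delta = 0.01                        # start very close to the optimum
g = alpha + delta
g_next = g - eta * grad(g)
print(abs(g_next - alpha) < delta)  # True: one step shrinks the error
```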

Huh, so if you get close enough to the solution you should be OK. This suggests to me that the problem is in the numerics. Perhaps the error induced by using forward Euler is the problem? Recall that the local truncation error (to leading order in $\eta = r$) is given by $$ \Delta_{LTE}(\delta) = \frac{1}{2} \eta^2 \partial_{tt} g_t = \frac{1}{2} \eta^2 J(\alpha + \delta) v(\alpha + \delta) $$ where we evaluate the functions at $g_t = \alpha + \delta$ (i.e., $\delta$ is how far the starting point of the forward Euler step is from the optimum $\alpha$). Note: $ \partial_{tt} g_t = \partial_t v(g_t) = v'(g_t) \partial_t g_t = J(g_t) v(g_t) $. This is the error in a single step, meaning $ \Delta_{LTE} $ measures how far the estimate of $g_t$ is from its true value after one step. (This error accumulates at every step, though $\delta$ changes at each step.) It works out to be $$ \Delta_{LTE}(\delta) = \frac{1}{2} \eta^2 k^{4\alpha} k^{2\delta} \ln(k)^3 (k^\delta - 1)(2k^\delta - 1) \tag{1} $$ So let's say that $\alpha = 3$, $k = 4$, and $\eta = 10^{-4}$, which seem reasonable given the question. Then, if we start at $\delta = 1$ (i.e., from $g_t=4$), we get $\Delta_{LTE}\approx 75$. That is off by quite a bit: we started only a distance of $1$ away from the optimum, and a single step already introduces an error of about $75$.
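Plugging those numbers directly into equation (1) (the constants come from the paragraph above):

```python
import math

alpha, k, eta, delta = 3.0, 4.0, 1e-4, 1.0

# equation (1): single-step forward Euler truncation error
lte = 0.5 * eta**2 * k**(4 * alpha) * k**(2 * delta) * math.log(k)**3 \
      * (k**delta - 1) * (2 * k**delta - 1)
print(lte)   # roughly 75
```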

Essentially, equation (1) tells us that the error is exponential in $\delta$, meaning that as you go further to the right of $\alpha$ (i.e., as $\delta$ increases), the error increases exponentially quickly. This explains your quote

> It seems like for any g even considerably greater than 3, the gradient blows up and ends up shooting the next guess way too far left.

As for the dependence on $k$: in big-$O$ terms the error grows polynomially, but it's a nasty polynomial. Suppose $\delta \approx 1$ and $\alpha \approx 3$. Then $\Delta_{LTE} = O(k^{16} \ln(k)^3)$.

On the other hand, decreasing the learning rate $\eta$ will only shrink the error quadratically. This is probably not going to be very helpful for even moderately high values of $k$, $\alpha$, or $\delta$. In other words, the numerics when using this loss function are fundamentally quite cruel. :(
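To see how lopsided this trade-off is, compare the growth in $k$ against the quadratic gain from shrinking $\eta$, using equation (1) (the specific values here are only illustrative):

```python
import math

def lte(k, eta, alpha=3.0, delta=1.0):
    # single-step truncation error from equation (1)
    return 0.5 * eta**2 * k**(4 * alpha) * k**(2 * delta) * math.log(k)**3 \
           * (k**delta - 1) * (2 * k**delta - 1)

for k in (2, 4, 8):                     # error explodes as k grows...
    print(k, lte(k, eta=1e-4))

print(lte(4, 5e-5) / lte(4, 1e-4))      # ...but halving eta only quarters it
```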