Differentiating through Optimization Paths


I'm reading the paper "Optimizing Millions of Hyperparameters by Implicit Differentiation". The key contribution of the paper is to show that you can replace differentiating through the optimization process/path with implicit gradients in order to optimize hyperparameters efficiently.

Problem:

I don't quite understand how the optimization path fits into the derivation. For example, per their derivation, let's assume that we want to optimize some hyperparameters $\lambda$, where $w^{*}(\lambda)$ denotes the locally optimal base-parameters for a given value of $\lambda$. Furthermore, let us assume we are using some vanilla gradient-descent-style optimization process for both the inner and outer optimization.

\begin{align} \lambda_{new} &= \lambda - \frac{\partial}{\partial \lambda}\mathcal{L}(w^{*}(\lambda)) \\ &= \lambda - \frac{\partial}{\partial w^{*}(\lambda)} \mathcal{L}(w^{*}(\lambda)) \cdot \frac{\partial}{\partial \lambda} w^{*}(\lambda) \end{align}
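To make the outer update concrete, here is a minimal 1-D sketch of that chain rule, using a toy regularized inner loss and a toy outer (validation-style) loss of my own choosing, not from the paper:

```python
# Toy 1-D instance of the outer update above (all concrete losses and
# numbers are hypothetical choices for illustration, not from the paper).
# Inner loss:  L_train(w, lam) = 0.5*(w - c)^2 + 0.5*lam*w^2
# Outer loss:  L_val(w)        = 0.5*(w - t)^2
c, t = 2.0, 0.5

def w_star(lam):
    # Closed-form inner optimum: dL_train/dw = (w - c) + lam*w = 0
    return c / (1.0 + lam)

def hypergradient(lam):
    # Chain rule: dL_val/dlam = (dL_val/dw at w*) * (dw*/dlam)
    w = w_star(lam)
    dLval_dw = w - t
    dwstar_dlam = -c / (1.0 + lam) ** 2
    return dLval_dw * dwstar_dlam

lam = 1.0
# One (unit-step) outer gradient-descent update on lam, as in the display above
lam_new = lam - hypergradient(lam)

# Sanity check against central finite differences
def outer(lam):
    return 0.5 * (w_star(lam) - t) ** 2

eps = 1e-6
fd = (outer(lam + eps) - outer(lam - eps)) / (2 * eps)
print(abs(hypergradient(lam) - fd))  # should be tiny
```

Here the inner problem has a closed-form optimum, so $\frac{\partial}{\partial\lambda}w^{*}(\lambda)$ is available directly; the point of the paper is how to get that quantity when no closed form exists.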

Here $\frac{\partial}{\partial\lambda}w^{*}(\lambda)$ can be replaced with a closed-form expression via the implicit function theorem. Since $w^{*}(\lambda)$ is a stationary point of the inner problem, $\frac{\partial}{\partial w}\mathcal{L}(w^{*}(\lambda), \lambda) = 0$ for every $\lambda$, and differentiating this condition with respect to $\lambda$ gives:

\begin{align} \frac{\partial}{\partial w} \mathcal{L}(w^{*}(\lambda), \lambda) & = 0 \\ \frac{d}{d\lambda}\bigg[\frac{\partial}{\partial w} \mathcal{L}(w^{*}(\lambda), \lambda) \bigg] & = 0 \\ \frac{\partial^{2}\mathcal{L}}{\partial w \partial w^{T}} \cdot \frac{\partial w^{*}}{\partial\lambda} + \frac{\partial^{2}\mathcal{L}}{\partial\lambda\partial w^{T}} & = 0 \\ \frac{\partial^{2}\mathcal{L}}{\partial w \partial w^{T}} \cdot \frac{\partial w^{*}}{\partial\lambda} & = - \frac{\partial^{2}\mathcal{L}}{\partial\lambda\partial w^{T}} \\ \frac{\partial w^{*}}{\partial\lambda} & = - \bigg[\frac{\partial^{2}\mathcal{L}}{\partial w \partial w^{T}}\bigg]^{-1} \cdot \frac{\partial^{2}\mathcal{L}}{\partial\lambda\partial w^{T}} \end{align}
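As a sanity check of the final identity, here is a small numpy sketch on a toy quadratic inner loss (my own choice, not from the paper), where the inner optimum has a closed form that we can differentiate by finite differences:

```python
import numpy as np

# Toy inner loss: L(w, lam) = 0.5 w^T A w - b^T w + 0.5*lam*||w||^2
# (a hypothetical quadratic example, chosen so w*(lam) has a closed form)
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
A = A @ A.T + 3 * np.eye(3)      # symmetric positive definite
b = rng.standard_normal(3)

def w_star(lam):
    # Closed-form inner optimum: (A + lam*I) w = b
    return np.linalg.solve(A + lam * np.eye(3), b)

lam = 0.7
w = w_star(lam)

# Implicit-function-theorem gradient:
# dw*/dlam = -[d^2L/dw dw^T]^{-1} d^2L/dlam dw^T
H = A + lam * np.eye(3)          # Hessian of L w.r.t. w
mixed = w                        # d/dlam of grad_w L, holding w fixed
dw_dlam_implicit = -np.linalg.solve(H, mixed)

# Finite-difference check against the closed-form solution
eps = 1e-6
dw_dlam_fd = (w_star(lam + eps) - w_star(lam - eps)) / (2 * eps)
print(np.max(np.abs(dw_dlam_implicit - dw_dlam_fd)))  # should be tiny
```

In a real network there is no closed-form $w^{*}(\lambda)$ and the Hessian inverse is intractable, which is why the paper approximates the inverse-Hessian-vector product with a Neumann series.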

However, I don't understand how $\frac{\partial}{\partial \lambda} w^{*}(\lambda)$ gets transformed into the following:

\begin{align} \frac{\partial}{\partial \lambda} w^{*}(\lambda) &= \cdots \\ &= \cdots \\ &= \frac{\partial}{\partial\lambda}\bigg[\frac{\partial}{\partial w} \mathcal{L}(w(\lambda), \lambda) \bigg] \end{align}

Intuitively, $w^{*}(\lambda)$ is the result of taking $k$ gradient steps along the optimization path:

\begin{align} w_{0} &= \cdots\\ w_{1} &= w_{0} - \frac{\partial}{\partial w} \mathcal{L}(w_{0}(\lambda), \lambda) \\ w_{2} &= w_{1} - \frac{\partial}{\partial w} \mathcal{L}(w_{1}(\lambda), \lambda) \\ \vdots \\ w^{*} &= w_{k-1} - \frac{\partial}{\partial w} \mathcal{L}(w_{k-1}(\lambda), \lambda) \\ \end{align}

We can substitute this back into our equation as follows:

$$\frac{\partial}{\partial \lambda} w^{*}(\lambda) = \frac{\partial}{\partial \lambda} \bigg[w_0 - \sum^{k}_{i=1}\frac{\partial}{\partial w} \mathcal{L}(w_{i-1}(\lambda), \lambda)\bigg]$$

However, this doesn't look like the quantity we want, since

$$\frac{\partial}{\partial\lambda}\bigg[\frac{\partial}{\partial w} \mathcal{L}(w(\lambda), \lambda) \bigg] \neq \frac{\partial}{\partial \lambda} \bigg[w_0 - \sum^{k}_{i=1}\frac{\partial}{\partial w} \mathcal{L}(w_{i-1}(\lambda), \lambda)\bigg]$$
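To compare the two quantities numerically, I unrolled the inner loop on a toy quadratic loss (again my own hypothetical choice of $A$, $b$, step size $\alpha$, and step count $k$) and propagated $\frac{\partial w_t}{\partial\lambda}$ forward through the path by hand:

```python
import numpy as np

# Unroll k gradient-descent steps on L(w, lam) = 0.5 w^T A w - b^T w
# + 0.5*lam*||w||^2 and propagate J_t = dw_t/dlam through the path
# (forward-mode differentiation of the update rule, written out by hand).
rng = np.random.default_rng(1)
A = rng.standard_normal((3, 3))
A = A @ A.T + 3 * np.eye(3)      # symmetric positive definite
b = rng.standard_normal(3)
lam, alpha, k = 0.5, 0.05, 2000

H = A + lam * np.eye(3)          # Hessian of L in w (constant for a quadratic)
w = np.zeros(3)                  # w_0, chosen independent of lam
J = np.zeros(3)                  # J_0 = dw_0/dlam = 0

for _ in range(k):
    grad = H @ w - b             # dL/dw at (w_t, lam)
    J = J - alpha * (H @ J + w)  # chain rule through w_{t+1} = w_t - alpha*grad
    w = w - alpha * grad

# Implicit-function-theorem gradient at the (numerically) converged w:
J_implicit = -np.linalg.solve(H, w)
print(np.max(np.abs(J - J_implicit)))  # small once the inner loop has converged
```

In this toy case the derivative obtained by differentiating through the path does approach the implicit-gradient expression as the inner loop converges, but I don't see why the two should coincide in general. Can anyone help resolve this issue for me? Thank you very much.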

Can anyone help resolve this issue for me? Thank you very much.