Neural ODE definition of derivative $\frac{d L}{dz(t)}$ (adjoint)


The authors of Neural Ordinary Differential Equations (NeurIPS 2018 best paper award) propose to model machine learning problems with an ODE

$$ \dot z(t) = f(t, z(t), \theta) \quad{\text{s.t.}}\quad z(t_0)=z_0$$

where $f$ is, e.g., a neural network with parameters $\theta$. The gradient of a loss function $L$ is computed via the adjoint method:

$$\begin{aligned} \frac{d L}{d\theta} = -\int_{t_1}^{t_0}a(t)^T\frac{\partial f}{\partial \theta} d t \quad\text{where}\quad \dot a(t) = -a(t)^T\frac{\partial f}{\partial z},\, a(t_1)=\frac{dL}{dz(t_1)} \end{aligned}$$
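To make the procedure concrete, here is a minimal numerical sketch of the adjoint method for a scalar ODE (the function names and the Euler discretization are my own simplifications, not from the paper): integrate $z$ forward from $t_0$ to $t_1$, then integrate the adjoint $a$ and the gradient accumulator backward from $t_1$ to $t_0$.

```python
import math

def adjoint_gradient(f, df_dz, df_dtheta, z0, theta, t0, t1, dL_dz1, n=10000):
    """Compute dL/dtheta via the adjoint method for a scalar ODE.

    dL/dtheta = -int_{t1}^{t0} a(t) df/dtheta dt = +int_{t0}^{t1} a(t) df/dtheta dt,
    with da/dt = -a df/dz and a(t1) = dL/dz(t1). Plain Euler steps for clarity.
    """
    h = (t1 - t0) / n
    # forward pass: store the trajectory so the backward pass can reuse z(t)
    zs = [z0]
    for i in range(n):
        t = t0 + i * h
        zs.append(zs[-1] + h * f(t, zs[-1], theta))
    # backward pass: step a and the gradient accumulator from t1 down to t0
    a = dL_dz1(zs[-1])
    grad = 0.0
    for i in range(n, 0, -1):
        t = t0 + i * h
        z = zs[i]
        grad += h * a * df_dtheta(t, z, theta)  # accumulate +int_{t0}^{t1} a df/dtheta dt
        a += h * a * df_dz(t, z, theta)         # a(t-h) = a(t) + h * a * df/dz
    return grad
```

On the example below, $\dot z = wz$ with $L=\frac12|z(t_1)-z_1|^2$, this reproduces the analytic derivative $\frac{dL}{dw}$ up to discretization error.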

I have a problem with interpreting the adjoint state $a(t)$. The paper claims that the adjoint state is given by $a(t)=\frac{dL}{dz(t)}$. However, let's consider the simple example$\,\dot z = w z,\, z(t_0)=z_0$ with solution $\hat z(t) := z^*(t;t_0, z_0, w) = e^{w (t-t_0)}z_0$ and loss function $L(w)=\frac{1}{2}|\hat z(t_1) - z_1|^2$. Then, solving the adjoint system yields:

$$\left.\begin{aligned} \dot a(t) &= - a(t)^T \tfrac{\partial f}{\partial z} = -w \cdot a(t) \\ a(t_1) &= \frac{d L }{d z(t_1)} = \hat z(t_1) - z_1 \end{aligned} \right\rbrace\implies a(t) = e^{-w(t-t_1)}(\hat z(t_1) - z_1) $$

I verified that this is indeed correct: integrating $-\int_{t_1}^{t_0} a(t)\frac{\partial f}{\partial\theta}\,dt$ gives the correct derivative $\frac{dL}{dw}$.
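This verification can be sketched numerically. In this example $\frac{\partial f}{\partial w} = z$, so the integrand is $a(t)\,\hat z(t)$ (which here happens to be constant in $t$). A quick check against a central finite difference of $L(w)$, with arbitrary parameter values chosen only for the test:

```python
import math

# arbitrary check values for the worked example dz/dt = w z
w, z0, z1, t0, t1 = 0.3, 1.5, 2.0, 0.0, 1.0

zhat = lambda t: math.exp(w * (t - t0)) * z0                   # forward solution
a    = lambda t: math.exp(-w * (t - t1)) * (zhat(t1) - z1)     # adjoint solution
L    = lambda w_: 0.5 * (math.exp(w_ * (t1 - t0)) * z0 - z1) ** 2

# dL/dw = -int_{t1}^{t0} a(t) dfdw dt = +int_{t0}^{t1} a(t) zhat(t) dt
n, h = 1000, (t1 - t0) / 1000
grad_adjoint = sum(a(t0 + (i + 0.5) * h) * zhat(t0 + (i + 0.5) * h)
                   for i in range(n)) * h

# central finite difference of L(w) as ground truth
eps = 1e-6
grad_fd = (L(w + eps) - L(w - eps)) / (2 * eps)
```

The two values agree to high precision, confirming the adjoint formula for this example.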

Plugging back the solution $\hat z$ of the ODE yields $a(t)=e^{-w(t-t_1)}(e^{w (t_1-t_0)}z_0 - z_1)$ (I think I got all the signs correct here.)

However, I do not see how to mechanistically form the derivative $\frac{d L}{dz(t)}$ directly and arrive at the same result.

If we interpret $\frac{dL}{dz(t)}$ as $\frac{d}{dz}(\frac{1}{2}|z-z_1|^2)\big|_{z=\hat z(t)}$, we get $\hat z(t)-z_1$, which is not equal to $a(t)$. So how is $\frac{dL}{dz(t)}$ formally defined, such that it ends up equal to $a(t)$? The paper never explains this.

Secondly, I would like to know if there is a more direct way to derive the adjoint equation. I am aware of the authors' derivation in the appendix as well as the derivation via Pontryagin's maximum principle. The first appears unmotivated, while the second for some reason requires the consideration of an optimization problem, when all we want to do is compute a derivative.

As a first step toward resolving the conundrum, let $\hat z(t)=z^*(t; t_0, z_0, \theta)$ with the initial condition $(t_0, z_0)$ held fixed. Then the incorrect and the correct interpretation of the term are given by:

$$\begin{alignedat}{3} \text{incorrect:}&& \frac{dL}{dz(t)} &\overset{\text{def}}{=} \frac{d}{d\hat z(t)}L\big(\underbrace{z^{*}(t;t_0,z_0,\theta)}_{=\hat z(t)}\big) &&\overset{\text{def}}{=}\frac{\partial}{\partial z}L(z)\Big|_{z=\hat z(t)} \\ \text{correct:}&& \frac{dL}{dz(t)} &\overset{\text{def}}{=} \frac{d}{d\hat z(t)}L\big(z^{*}(t_1;t,\hat z(t),\theta)\big) &&\overset{\text{def}}{=} \frac{\partial}{\partial z}L\big(z^*(t_1; t,z, \theta)\big)\Big|_{z=\hat z(t)} \end{alignedat}$$

Indeed, in the example scenario we observe:

$$\begin{aligned} \frac{\partial}{\partial z}L(z)\Big|_{z=\hat z(t)} &= \frac{\partial}{\partial z}\Big[\frac{1}{2}\|z-z_1\|_2^2\Big]\Bigg|_{z=\hat z(t)} \\&= (z-z_1)\Big|_{z=\hat z(t)} \\&= \hat z(t)-z_1 = e^{w(t-t_0)}z_0 - z_1{\huge\color{red}{↯}} \\ \frac{\partial}{\partial z}L(z^*(t_1; t,z, \theta))\Big|_{z=\hat z(t)} &=\frac{\partial}{\partial z}\Big[\frac{1}{2}\|e^{w(t_1-t)}z-z_1\|_2^2\Big]\Bigg|_{z=\hat z(t)} \\&= e^{w(t_1-t)}(e^{w(t_1-t)}z-z_1)\Big|_{z=\hat z(t)} \\&= e^{w(t_1-t)}(e^{w(t_1-t)}\hat z(t)-z_1) \\&=e^{w(t_1-t)}(e^{w(t_1-t)}e^{w(t-t_0)}z_0-z_1) \\&=e^{w(t_1-t)}(e^{w(t_1-t_0)}z_0-z_1) \\&=e^{w(t_1-t)}(\hat z(t_1)-z_1) {\huge\color{green}{\checkmark}} \end{aligned}$$
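The correct interpretation can also be confirmed numerically as a "loss-to-go" derivative: perturb the state at an interior time $t$, flow it forward to $t_1$ with the closed-form solution $e^{w(t_1-t)}z$, differentiate the resulting loss, and compare with $a(t)$. A sketch with arbitrary check values:

```python
import math

# arbitrary check values; t is an interior time in (t0, t1)
w, z0, z1, t0, t1 = 0.3, 1.5, 2.0, 0.0, 0.6
t = 0.25

zhat_t  = math.exp(w * (t - t0)) * z0               # state at time t
zhat_t1 = math.exp(w * (t1 - t0)) * z0              # state at time t1
a_t     = math.exp(w * (t1 - t)) * (zhat_t1 - z1)   # adjoint at time t

def loss_from(z_at_t):
    """Loss-to-go: flow the perturbed state from t forward to t1, then evaluate L."""
    return 0.5 * (math.exp(w * (t1 - t)) * z_at_t - z1) ** 2

# central finite difference of the loss-to-go with respect to z(t)
eps = 1e-6
fd = (loss_from(zhat_t + eps) - loss_from(zhat_t - eps)) / (2 * eps)
```

The finite difference `fd` matches `a_t`, whereas differentiating $\frac12\|z-z_1\|^2$ directly at $\hat z(t)$ does not, in line with the derivation above.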