Why can't we derive Pontryagin's adjoint method by the chain rule?


I'm trying to read the supplement to Neural Ordinary Differential Equations, one of the NeurIPS 2018 best paper award winners, and I am having trouble understanding the derivation of Pontryagin's adjoint method.

So let's say we have a function

$$L(z(t)), \tag{1}$$

with

$$\frac{dz(t)}{dt} = f(z(t),t,\theta), \tag{2}$$

and we define the adjoint

$$a(t) = \frac{dL}{dz}. \tag{3}$$

Pontryagin's adjoint theorem then says that

$$\frac{da(t)}{dt} = -a(t)\frac{df(z(t),t,\theta)}{dz}. \tag{4}$$
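A quick numerical sanity check of (4) may help make the statement concrete. The sketch below reads $a(t)$ as the sensitivity of a loss evaluated at the final state, $L(z(t_1))$, with respect to the state $z(t)$ at an earlier time, which is how the Neural ODE paper uses it. The dynamics $f(z,t,\theta) = -\theta z^3$, the loss $L = z(t_1)^2$, the constants, and all function names are hypothetical choices for illustration only. It integrates $z$ and $a$ backward together from $t_1$ using (4) and compares the resulting $a(t_0)$ with a finite-difference gradient of the forward solve.

```python
# Hypothetical test problem (not from the paper): dz/dt = -theta * z**3, loss L = z(t1)**2
THETA, T0, T1, Z0, N = 0.5, 0.0, 1.0, 1.2, 1000


def f(z, t, theta):
    return -theta * z**3


def df_dz(z, t, theta):
    return -3.0 * theta * z**2


def rk4(rhs, y, t_start, t_end, n):
    """Classical RK4 for a list-valued ODE dy/dt = rhs(y, t); t_end < t_start integrates backward."""
    h = (t_end - t_start) / n
    t = t_start
    y = list(y)
    for _ in range(n):
        k1 = rhs(y, t)
        k2 = rhs([yi + 0.5 * h * ki for yi, ki in zip(y, k1)], t + 0.5 * h)
        k3 = rhs([yi + 0.5 * h * ki for yi, ki in zip(y, k2)], t + 0.5 * h)
        k4 = rhs([yi + h * ki for yi, ki in zip(y, k3)], t + h)
        y = [yi + h / 6.0 * (p + 2 * q + 2 * r + s)
             for yi, p, q, r, s in zip(y, k1, k2, k3, k4)]
        t += h
    return y


def final_state(z0):
    """Solve dz/dt = f forward from t0 to t1 and return z(t1)."""
    return rk4(lambda y, t: [f(y[0], t, THETA)], [z0], T0, T1, N)[0]


def adjoint_gradient(z0):
    """dL/dz(t0) obtained by integrating z and a backward from t1 with da/dt = -a * df/dz."""
    z1 = final_state(z0)

    def rhs(y, t):
        z, a = y
        return [f(z, t, THETA), -a * df_dz(z, t, THETA)]

    _, a0 = rk4(rhs, [z1, 2.0 * z1], T1, T0, N)  # start from a(t1) = dL/dz = 2 z(t1)
    return a0


def finite_difference_gradient(z0, eps=1e-6):
    loss = lambda z: final_state(z) ** 2
    return (loss(z0 + eps) - loss(z0 - eps)) / (2 * eps)


print(adjoint_gradient(Z0), finite_difference_gradient(Z0))
```

For this particular test problem (with $\theta = 0.5$ and $t_1 - t_0 = 1$) the flow has a closed form, and the exact gradient is $2z_0/(1+z_0^2)^2$, so both printed numbers can also be checked by hand.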

I find this very confusing, because I seem to get the wrong answer if I try to use the chain rule directly. (Presumably I am misunderstanding the chain rule.)

My approach:

$\frac{da}{dt} = \frac{d}{dt}\frac{dL}{dz}$ by (3),

which is $\frac{d}{dt}L_z(z(t))$. By the chain rule this gives

$L_{zz}(z(t))\frac{dz}{dt}$, and from (2) we have $\frac{da}{dt} = L_{zz}(z(t))\, f(z(t),t,\theta)$.

...this final expression is not what Pontryagin's expression in (4) says.

Any thoughts on what I'm doing wrong?

Thanks!

Best answer:

You have only started the manipulation; you need to push further to reach the desired form. The term $L_{zz}$ can be manipulated further, and you will find that you need $\frac{\partial\dot{L}}{\partial z}=0$.
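One way to see what the chain-rule computation in the question leaves out, assuming $L$ is read as the loss at the final time $t_1$ expressed as a function of the state at time $t$: call that function $V(z,t)$ (a hypothetical name, not from the question). Then $V$ depends on $t$ explicitly, and $\dot V = V_t + V_z f = 0$ along trajectories for every $z$. Differentiating this identity in $z$ gives $V_{zt} + V_{zz} f = -V_z f_z$, so $\frac{da}{dt} = V_{zt} + V_{zz} f$ rather than $V_{zz} f$ alone. The sketch below checks this with finite differences for a hypothetical test problem, $f = -\theta z^3$ with $L = z(t_1)^2$, whose flow has the closed form used in `V`.

```python
# Hypothetical test problem: dz/dt = -theta * z**3, loss L = z(t1)**2
THETA, T1 = 0.5, 1.0


def f(z, t):
    return -THETA * z**3


def df_dz(z, t):
    return -3.0 * THETA * z**2


def V(z, t):
    # Loss z(t1)**2 written as a function of the state z at time t; the exact flow of
    # dz/dt = -theta*z**3 gives z(t1)**2 = z**2 / (1 + 2*theta*z**2*(t1 - t))
    return z**2 / (1.0 + 2.0 * THETA * z**2 * (T1 - t))


H = 1e-4  # finite-difference step


def V_z(z, t):
    return (V(z + H, t) - V(z - H, t)) / (2 * H)


def V_zz(z, t):
    return (V(z + H, t) - 2 * V(z, t) + V(z - H, t)) / H**2


def V_zt(z, t):
    return (V(z + H, t + H) - V(z + H, t - H)
            - V(z - H, t + H) + V(z - H, t - H)) / (4 * H**2)


z, t = 1.2, 0.0
a = V_z(z, t)                                         # adjoint a(t) = dV/dz
pontryagin = -a * df_dz(z, t)                         # right-hand side of (4)
chain_rule_partial = V_zz(z, t) * f(z, t)             # the question's L_zz * f only
chain_rule_full = V_zt(z, t) + V_zz(z, t) * f(z, t)   # adds the explicit t-dependence
print(pontryagin, chain_rule_partial, chain_rule_full)
```

The first and third printed values should agree to finite-difference accuracy, while the second differs from them by exactly the mixed term $V_{zt}$.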

The usual and more intuitive way to derive the costate equation is as follows. Suppose we perturb $z$ by $\delta z$ at time $t$ and flow forward for a time $\delta t$. Expanding separately to first order in $\delta t$ and to first order in $\delta z$ (note that we do not assume $\delta t\,\delta z$ is negligible) gives

\begin{align*} \require{color} z(t)+\delta z+f(z(t)+\delta z,t,\theta)\delta t &={\color{red}z(t)}+\delta z+\left({\color{red}f(z(t),t,\theta)} + \frac{\partial f}{\partial z}(z(t),t,\theta) \delta z\right) \delta t\\ &={\color{red}z(t + \delta t)} + \left(1 + \frac{\partial f}{\partial z}(z(t),t,\theta) \delta t\right) \delta z. \end{align*}

In other words, the perturbation $\delta z$ at time $t$ propagates to a perturbation $\left(1 + \frac{\partial f}{\partial z}(z(t),t,\theta) \delta t\right) \delta z$ at time $t+\delta t$. Since the loss depends on the perturbation at time $t$ only through the perturbation it produces at time $t+\delta t$, the resulting change in the loss is $a(t)\,\delta z = a(t+\delta t)\left(1+\frac{\partial f}{\partial z}(z(t),t,\theta)\delta t\right)\delta z$. Hence the costate satisfies
$$ a(t)=a(t+\delta t)\left(1+\frac{\partial f}{\partial z}(z(t),t,\theta)\delta t\right). $$
Substituting $a(t+\delta t) = a(t) + a'(t)\,\delta t + O(\delta t^2)$ and equating terms of order $\delta t$ gives
$$ 0=a'(t)+a(t)\frac{\partial f}{\partial z}(z(t),t,\theta). $$
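The perturbation argument has a direct discrete counterpart, sketched below as an illustration (it is not the paper's algorithm). With forward Euler steps $z_{k+1} = z_k + f(z_k,t_k,\theta)\,\delta t$, a perturbation is multiplied by exactly $1 + \frac{\partial f}{\partial z}(z_k,t_k,\theta)\,\delta t$ per step, so running $a_k = a_{k+1}\left(1 + \frac{\partial f}{\partial z}(z_k,t_k,\theta)\,\delta t\right)$ backward from $a_N = L'(z_N)$ gives the exact gradient of the discretized loss. The test problem ($f = -\theta z^3$ with no explicit $t$-dependence, $L = z^2$), the step count, and all names are hypothetical.

```python
# Hypothetical test problem: dz/dt = -theta * z**3, loss L = z(t1)**2
THETA, T0, T1, N = 0.5, 0.0, 1.0, 2000
DT = (T1 - T0) / N


def f(z, t):
    return -THETA * z**3


def df_dz(z, t):
    return -3.0 * THETA * z**2


def euler_trajectory(z0):
    """Forward Euler z_{k+1} = z_k + f(z_k, t_k) * DT; returns [z_0, ..., z_N]."""
    zs = [z0]
    for k in range(N):
        zs.append(zs[-1] + f(zs[-1], T0 + k * DT) * DT)
    return zs


def discrete_adjoint(z0):
    """Backward recursion a_k = a_{k+1} * (1 + df/dz(z_k, t_k) * DT), from a_N = dL/dz(z_N)."""
    zs = euler_trajectory(z0)
    a = 2.0 * zs[-1]
    for k in reversed(range(N)):
        # Multiply by the factor a perturbation picks up over the Euler step from t_k to t_{k+1}
        a *= 1.0 + df_dz(zs[k], T0 + k * DT) * DT
    return a


def finite_difference_gradient(z0, eps=1e-6):
    loss = lambda z: euler_trajectory(z)[-1] ** 2
    return (loss(z0 + eps) - loss(z0 - eps)) / (2 * eps)


print(discrete_adjoint(1.2), finite_difference_gradient(1.2))
```

Because each backward step is just the chain rule for one Euler step, `discrete_adjoint` should match a finite difference of the same discretized map up to rounding, while it differs from the continuous-time gradient by an $O(\delta t)$ discretization error.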