I'm trying to understand the math behind the backward pass in Neural ODEs.
I understand it boils down to solving an IVP backwards in time for a state $\mathbf a(t)$ (the so-called adjoint state) such that $$\mathbf a(t) = \frac{dL}{d\mathbf z(t)},$$ where $L(\cdot)$ is some loss function. From this, it's possible to derive $$\frac{d\mathbf a(t)}{dt} = -\mathbf a(t)^T\frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}.$$
Assuming $L$ only depends on the hidden state at the final time point, $\mathbf z(t_N)$, the gradients w.r.t. $\mathbf z(t_N)$ are trivial. To get gradients w.r.t. $\mathbf z(t_0)$, we must have stored $\mathbf z(t_N)$ from the forward pass, and we solve the backwards IVP: $$\mathbf a(t_0) = \mathbf a(t_N) + \int_{t_N}^{t_0} \frac{d\mathbf a(t)}{dt}dt = \mathbf a(t_N) - \int_{t_N}^{t_0} \mathbf a(t)^T\frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}dt. $$
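To sanity-check this backward IVP numerically, here is a minimal sketch (my own toy setup, not from the paper): I use a hypothetical scalar linear ODE $\dot z = \theta z$, whose solution is known in closed form, and verify that solving the adjoint ODE backwards from $a(t_N) = 1$ (for the loss $L = z(t_N)$) recovers $dL/dz(t_0)$:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical toy problem (not from the paper): dz/dt = theta * z,
# with closed-form solution z(t) = z0 * exp(theta * (t - t0)).
theta, t0, tN, z0 = -0.7, 0.0, 2.0, 1.5

# Forward pass: store z(tN).
fwd = solve_ivp(lambda t, z: theta * z, (t0, tN), [z0],
                rtol=1e-10, atol=1e-12)
zN = fwd.y[0, -1]

# Loss L = z(tN), so the adjoint starts at a(tN) = dL/dz(tN) = 1.
# Backward pass: da/dt = -a * df/dz, integrated from tN down to t0
# (df/dz = theta for this linear f).
bwd = solve_ivp(lambda t, a: -a * theta, (tN, t0), [1.0],
                rtol=1e-10, atol=1e-12)
a_t0 = bwd.y[0, -1]

# Closed form to compare against: dL/dz0 = exp(theta * (tN - t0)).
```

For this linear $f$ the adjoint ODE does not actually depend on $z(t)$; in general $\partial f/\partial \mathbf z$ is evaluated along the trajectory, which is why the paper re-solves $\mathbf z$ backwards alongside $\mathbf a$.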
To get gradients w.r.t. $t$ and $\theta$, the adjoint state is augmented with $\theta(t)$ (constant) and $t(t)$, which have trivial derivatives $\frac{d\theta(t)}{dt} = 0$ and $\frac{dt(t)}{dt} = 1$. We obtain (omitting some steps, which can be found in Appendix B of the paper): $$ \frac{d\mathbf a_{\text{aug}}}{dt} = -[\mathbf a(t) \ \mathbf a_\theta(t) \ \mathbf a_t(t)]\frac{\partial f_{\text{aug}}}{\partial [\mathbf z, \theta, t]}(t) = - \left[\mathbf a \frac{\partial f}{\partial \mathbf z} \ \ \mathbf a \frac{\partial f}{\partial \theta} \ \ \mathbf a \frac{\partial f}{\partial t}\right],$$
where the first element is our adjoint state before augmentation, the second gives us gradients w.r.t. $\theta$, and the third gives us gradients w.r.t. $t$.
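To make the augmented system concrete, a small sketch (again a hypothetical toy setup, not the paper's experiments): for $\dot z = \theta z$ with loss $L = z(t_N)$, the state $[\,z,\ a,\ a_\theta\,]$ is integrated backwards in one go, and $a_\theta(t_0)$ is compared against the closed-form $dL/d\theta$:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical toy problem: dz/dt = f(z) = theta * z, loss L = z(tN).
theta, t0, tN, z0 = -0.7, 0.0, 2.0, 1.5

# Augmented backward state s = [z, a, a_theta]:
#   dz/dt       = f(z)            (re-solve z backwards alongside a)
#   da/dt       = -a * df/dz      (df/dz     = theta here)
#   da_theta/dt = -a * df/dtheta  (df/dtheta = z here)
def aug_rhs(t, s):
    z, a, a_th = s
    return [theta * z, -a * theta, -a * z]

zN = z0 * np.exp(theta * (tN - t0))    # z(tN) stored from the forward pass
sol = solve_ivp(aug_rhs, (tN, t0), [zN, 1.0, 0.0], rtol=1e-10, atol=1e-12)
dL_dtheta = sol.y[2, -1]

# Closed form to compare against: L = z0 * exp(theta * (tN - t0)), so
# dL/dtheta = z0 * (tN - t0) * exp(theta * (tN - t0)).
```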
This is all straightforward to me, as long as $L$ only depends (directly) on $\mathbf z(t_N)$. If instead we have a loss that depends directly on $\mathbf z(.)$ at additional intermediate time points, the authors state that one should simply repeat this procedure for each pair of consecutive time points and sum up the obtained gradients.
My intuition is a bit lost as to why this is the case. I think part of it is the fact that future time points can't influence past time points, but from Figure 2 in the paper it looks as though, when we go (backwards) from $\mathbf z(t_{i+1})$ to $\mathbf z(t_{i})$ and then "jump" there to repeat the procedure from $\mathbf z(t_{i})$ to $\mathbf z(t_{i-1})$, we keep $\mathbf z(t_{i-1})$ from influencing $\mathbf z(t_{i+1})$.
Bluntly and to summarize: the paper states that "If the loss depends directly on the state at multiple observation times, the adjoint state must be updated in the direction of the partial derivative of the loss with respect to each observation.", but I lack some intuition as to why that is the case.
To compute $dL/d\theta$, let's write the loss down explicitly; for simplicity, consider the scalar case $z$:
$$L(\theta) = L(z(t_0, \theta), … ,z(t_N, \theta))$$
$L$ doesn’t depend on $t$; therefore, we can treat the individual functions $z_i(\theta) = z(t_i, \theta)$ as independent functions of $\theta$. While we define $z(t, \theta)$ as the solution to an ODE, mathematically it is just a function of two variables; $z_i$ is not a function of $z_{i-1}$, they are simply values of $z$ at two different points. We compute the total derivative of $L$ in terms of the partial derivatives of $L$ with respect to $z_i$ and the derivatives of $z_i$:
$$ \frac{dL}{d\theta} = \sum_{i=0}^{N} \frac{\partial L}{\partial z_i} \frac{dz_i}{d\theta}. $$
The right-hand side is a sum of terms, each of which is $dL/d\theta$ calculated as if the loss depended only on the individual $z_i$; it can be computed by running Algorithm 1 from the NeuralODE paper $N$ times, once for each interval $(t_0, t_i)$. This can be avoided if we take advantage of the fact that the sum of derivatives is the derivative of the sum. Let's write down each initial value problem for the adjoints:
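The naive "one adjoint solve per observation, then sum" approach can be sketched as follows (a hypothetical linear toy problem of my choosing, $\dot z = \theta z$ with loss $L = \sum_i z(t_i)$, checked against finite differences):

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical toy problem: dz/dt = theta * z, loss L = sum_i z(t_i).
theta, t0, z0 = -0.7, 0.0, 1.5
obs_times = [0.5, 1.0, 2.0]

def grad_single(ti):
    """dL_i/dtheta from one adjoint solve over (t0, ti), as if the loss
    were L_i = z(ti) alone; backward state s = [z, a, a_theta]."""
    zi = z0 * np.exp(theta * (ti - t0))           # forward value at ti
    def aug_rhs(t, s):
        z, a, a_th = s                            # df/dz = theta, df/dtheta = z
        return [theta * z, -a * theta, -a * z]
    sol = solve_ivp(aug_rhs, (ti, t0), [zi, 1.0, 0.0], rtol=1e-10, atol=1e-12)
    return sol.y[2, -1]

# Sum the per-observation gradients, one backward solve each.
dL_dtheta = sum(grad_single(ti) for ti in obs_times)

# Finite-difference check of the same total derivative dL/dtheta.
def loss(th):
    return sum(z0 * np.exp(th * (ti - t0)) for ti in obs_times)
eps = 1e-6
fd = (loss(theta + eps) - loss(theta - eps)) / (2 * eps)
```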
$$ \begin{cases}\frac{d a_i }{dt}= -a_i\frac{\partial f(z(t),t,\theta)}{\partial z}\\ a_i(t_i) = \frac{\partial L}{\partial z(t_i)}\end{cases}. $$
Note that trajectories $z$ are the same for all components, but the initial values are different.
$$ \frac{dL}{d\theta} = - \sum_{i=0}^N \int_{t_i}^{t_0}a_i(t)\frac{\partial f(z(t),t,\theta)}{\partial\theta}dt $$
Now we want to compute one integral, instead of a sum of integrals. To do that, we combine $a_i(t)$ into a single $a(t)$:
$$ a(t) = \sum_{i=0}^N a_i(t) H(t_i - t), $$
where $H$ is the Heaviside step function. The ODE for $a(t)$ is as follows:
$$ \begin{aligned} \frac{d a}{dt} &= \sum_i \left[\frac{d a_i }{dt} H(t_i-t)-a_i(t)\,\delta(t_i-t)\right] \\ &= \sum_i \left[-a_i(t)\frac{\partial f(z(t),t,\theta)}{\partial z}H(t_i-t)-a_i(t)\,\delta(t_i-t)\right] \\ &= -\frac{\partial f(z(t),t,\theta)}{\partial z}\sum_i a_i(t)H(t_i-t)-\sum_i a_i(t)\,\delta(t_i-t) \\ &= -a(t)\frac{\partial f(z(t),t,\theta)}{\partial z}-\sum_i a_i(t)\,\delta(t_i-t). \end{aligned} $$
The initial value is $a(t_N) = a_N(t_N) = \partial L / \partial z(t_N)$.
Computationally, we accommodate the delta-function jumps in $a(t)$ in the reverse-mode derivative by breaking the integration into a sequence of separate segments. At each intermediate point, the adjoint $a$ is adjusted in the direction of the corresponding partial derivative $a_i(t_i) = \partial L /\partial z(t_i)$.
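Putting this together, here is a self-contained sketch of the single backward sweep with jumps (again a hypothetical linear ODE $\dot z = \theta z$ with loss $L = \sum_i z(t_i)$, so $\partial L/\partial z(t_i) = 1$ at each observation):

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical toy problem: dz/dt = theta * z, loss L = sum_i z(t_i),
# so the adjoint jump at every observation time is dL/dz(t_i) = 1.
theta, t0, z0 = -0.7, 0.0, 1.5
obs_times = [0.5, 1.0, 2.0]

# Augmented backward state s = [z, a, a_theta] (df/dz = theta, df/dtheta = z).
def aug_rhs(t, s):
    z, a, a_th = s
    return [theta * z, -a * theta, -a * z]

boundaries = [t0] + obs_times
z = z0 * np.exp(theta * (obs_times[-1] - t0))   # z(tN) from the forward pass
a, a_th = 0.0, 0.0
# Sweep backwards segment by segment; at each observation time `hi`, bump
# the adjoint by dL/dz(hi) = 1, then integrate down to the next boundary.
for lo, hi in zip(reversed(boundaries[:-1]), reversed(boundaries[1:])):
    a += 1.0
    sol = solve_ivp(aug_rhs, (hi, lo), [z, a, a_th], rtol=1e-10, atol=1e-12)
    z, a, a_th = sol.y[:, -1]
dL_dtheta = a_th

# Closed form: dL/dtheta = sum_i z0 * (t_i - t0) * exp(theta * (t_i - t0)),
# which equals the sum of per-observation adjoint solves, by linearity.
```

One backward sweep thus replaces $N$ separate solves, at the cost of restarting the solver at each observation time to apply the jump.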