Proof of Neural Ordinary Differential Equations’ augmented adjoint state

In this blog post, my objective is to clarify the proofs of Appendix B.2 of this paper. It is a follow-up to my previous article about Neural ODEs, which detailed the maths without many proofs. Here, I will rewrite the proof of the dynamics of the augmented adjoint state, that is, of the gradient of the loss $L$ w.r.t. $z(t)$, $\theta(t)$, and $t$.

Let

1. $f(z(t), \theta(t), t) = \dfrac{dz(t)}{dt}$
2. $Z(t) = \begin{bmatrix} z(t), & \theta(t), & t \end{bmatrix}$, the augmented state
3. $a(t) = \dfrac{dL}{dZ(t)} = \begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix} = \begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix}$
4. $g(Z(t), t) = \dfrac{dZ(t)}{dt} = \begin{bmatrix} \dfrac{dz(t)}{dt}, & \dfrac{d\theta(t)}{dt}, & \dfrac{dt}{dt} \end{bmatrix} = \begin{bmatrix} f, & 0_\textbf{E}, & 1 \end{bmatrix}$ where $0_\textbf{E}$ is a null matrix of appropriate size
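To make the augmented state concrete, here is a minimal Python sketch. It assumes a hypothetical scalar toy dynamics $f(z, \theta, t) = \theta z + t$ (not from the paper), and the names `f`, `g` and `euler` are illustrative only. Integrating $Z$ forward with explicit Euler steps shows that $\theta$ never moves and that the third component of $Z$ simply tracks time.

```python
def f(z, theta, t):
    # Hypothetical toy dynamics; any differentiable f(z, theta, t) would do.
    return theta * z + t

def g(Z):
    """Augmented dynamics dZ/dt = [f, 0, 1] for Z = [z, theta, t]."""
    z, theta, t = Z
    return [f(z, theta, t), 0.0, 1.0]

def euler(Z, dt, steps):
    """Integrate the augmented state forward with explicit Euler steps."""
    for _ in range(steps):
        Z = [zi + dt * gi for zi, gi in zip(Z, g(Z))]
    return Z

Z_end = euler([1.0, 0.5, 0.0], dt=0.01, steps=100)
print(Z_end)  # theta is still 0.5 and the last entry is (up to rounding) 1.0
```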

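Then, we can define
$$Z(t+\varepsilon) = Z(t) + \int_{t}^{t+\varepsilon}{g(Z(s), s)\,ds} = T_\varepsilon(Z(t), t)$$
Notice that, by the chain rule,
$$a(t) = \dfrac{dL}{dZ(t)} = \dfrac{dL}{dZ(t+\varepsilon)}\dfrac{dZ(t+\varepsilon)}{dZ(t)} = a(t+\varepsilon)\dfrac{\partial T_\varepsilon(Z(t), t)}{\partial Z(t)}$$
For small $\varepsilon$, a first-order Taylor expansion of the integral gives $T_\varepsilon(Z(t), t) = Z(t) + \varepsilon g(Z(t), t) + \mathcal{O}(\varepsilon^2)$. Thus we can differentiate $a(t)$ with respect to $t$: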

$$\begin{aligned}
\frac{d a(t)}{d t} &= \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t)}{\varepsilon} \\
&= \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon) \frac{\partial}{\partial Z(t)} T_{\varepsilon}(Z(t), t)}{\varepsilon} \\
&= \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon) \frac{\partial}{\partial Z(t)}\left(Z(t)+\varepsilon g(Z(t), t)+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon} \\
&= \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon)\left(I+\varepsilon \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon} \\
&= \lim_{\varepsilon \rightarrow 0^{+}} \frac{-\varepsilon a(t+\varepsilon) \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon^2)}{\varepsilon} \\
&= \lim_{\varepsilon \rightarrow 0^{+}} \left(-a(t+\varepsilon) \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon)\right) \\
&= -a(t) \frac{\partial g(Z(t), t)}{\partial Z(t)}
\end{aligned}$$
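The identity $\frac{da(t)}{dt} = -a(t)\frac{\partial g}{\partial Z(t)}$ can be sanity-checked numerically. The sketch below assumes a hypothetical linear toy problem, $f = \theta z$ with loss $L = z(t_1)$, so that $z(t) = z_0 e^{\theta (t - t_0)}$. Writing $L = z(t)e^{\theta(t_1 - t)}$ and differentiating with respect to each component of $Z(t)$ gives a closed-form augmented adjoint: $a_z(t) = e^{\theta(t_1 - t)}$, $a_\theta(t) = z(t)(t_1 - t)e^{\theta(t_1 - t)}$ and $a_t(t) = -\theta z(t) e^{\theta(t_1 - t)}$. The code compares a central finite difference of $a(t)$ with $-a(t)\frac{\partial g}{\partial Z(t)}$; all constants and function names are illustrative.

```python
import math

# Hypothetical toy problem: f(z, theta, t) = theta * z, loss L = z(T1).
Z0, THETA, T0, T1 = 1.5, 0.7, 0.0, 2.0

def z(t):
    """Closed-form trajectory z(t) = z0 * exp(theta * (t - t0))."""
    return Z0 * math.exp(THETA * (t - T0))

def a(t):
    """Closed-form augmented adjoint [a_z, a_theta, a_t] for this toy problem."""
    a_z = math.exp(THETA * (T1 - t))
    a_theta = z(t) * (T1 - t) * a_z
    a_t = -THETA * z(t) * a_z
    return [a_z, a_theta, a_t]

def jacobian_g(zt, theta, t):
    """dg/dZ: rows are the components of g = [f, 0, 1], columns are (z, theta, t)."""
    return [[theta, zt, 0.0],  # partial derivatives of f = theta * z
            [0.0, 0.0, 0.0],   # d(theta)/dt = 0 is constant
            [0.0, 0.0, 0.0]]   # dt/dt = 1 is constant

def rhs(t):
    """-a(t) times dg/dZ, computed as a row vector times a matrix."""
    at, J = a(t), jacobian_g(z(t), THETA, t)
    return [-sum(at[i] * J[i][j] for i in range(3)) for j in range(3)]

def lhs(t, h=1e-5):
    """Central finite-difference estimate of da/dt."""
    ap, am = a(t + h), a(t - h)
    return [(p - m) / (2 * h) for p, m in zip(ap, am)]

t = 0.8
print(lhs(t))
print(rhs(t))  # should agree with the line above up to finite-difference error
```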

Nothing in the derivation above uses the particular structure of $z$, so the proof of Appendix B.1 holds for any state vector, in particular for the augmented state $Z$. What is more interesting is how we can exploit this to get $\dfrac{da_\theta(t)}{dt}$. Writing $\dfrac{\partial g}{\partial Z(t)}$ as a Jacobian whose rows are the components of $g$ and whose columns are the components of $Z$, with $f$ standing for $f(z(t), \theta(t), t)$:

$$\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix}\dfrac{\partial f}{\partial z} & \dfrac{\partial f}{\partial \theta} & \dfrac{\partial f}{\partial t} \\ \dfrac{\partial (d\theta/dt)}{\partial z} & \dfrac{\partial (d\theta/dt)}{\partial \theta} & \dfrac{\partial (d\theta/dt)}{\partial t} \\ \dfrac{\partial (dt/dt)}{\partial z} & \dfrac{\partial (dt/dt)}{\partial \theta} & \dfrac{\partial (dt/dt)}{\partial t} \end{pmatrix}$$
But as mentioned in the paper, $\dfrac{d\theta(t)}{dt} = 0$ and $\dfrac{dt}{dt} = 1$, which means that the second and third rows are null (derivatives of constants). Hence:
$$\begin{aligned}
\dfrac{da(t)}{dt} &= -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix} \dfrac{\partial f}{\partial z} & \dfrac{\partial f}{\partial \theta} & \dfrac{\partial f}{\partial t} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \end{pmatrix} \\
&= -\begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{\partial f}{\partial \theta} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{\partial f}{\partial t} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E} \end{bmatrix} \\
&= -\begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z}, & a_z(t)\dfrac{\partial f}{\partial \theta}, & a_z(t)\dfrac{\partial f}{\partial t} \end{bmatrix}
\end{aligned}$$
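A small numerical illustration of why only $a_z$ survives: the sketch below, again under the hypothetical toy dynamics $f(z, \theta, t) = \theta z + t$, estimates $\frac{\partial g}{\partial Z}$ by central finite differences and multiplies it by two adjoint row vectors that share $a_z$ but differ in $a_\theta$ and $a_t$. Because the last two rows of the Jacobian are zero, both products coincide. The helper names are illustrative, not from the paper.

```python
def f(z, theta, t):
    # Hypothetical toy dynamics with a nonzero dependence on t.
    return theta * z + t

def g(Z):
    """Augmented dynamics [f, 0, 1] for Z = [z, theta, t]."""
    z, theta, t = Z
    return [f(z, theta, t), 0.0, 1.0]

def jacobian(fun, Z, h=1e-6):
    """J[i][j] = d fun_i / d Z_j, estimated with central differences."""
    cols = []
    for j in range(len(Z)):
        Zp = Z[:j] + [Z[j] + h] + Z[j + 1:]
        Zm = Z[:j] + [Z[j] - h] + Z[j + 1:]
        cols.append([(p - m) / (2 * h) for p, m in zip(fun(Zp), fun(Zm))])
    # cols[j][i] holds d fun_i / d Z_j; transpose it into rows.
    return [list(row) for row in zip(*cols)]

def adjoint_rhs(a, J):
    """Row vector -a J, the right-hand side of da/dt."""
    return [-sum(a[i] * J[i][j] for i in range(len(a))) for j in range(len(J[0]))]

Z = [2.0, 0.5, 1.0]
J = jacobian(g, Z)
print(J)  # rows 2 and 3 are exactly zero
r1 = adjoint_rhs([3.0, 0.0, 0.0], J)
r2 = adjoint_rhs([3.0, -7.0, 11.0], J)
print(r1, r2)  # identical: a_theta and a_t drop out
```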

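Now, we’ve got:
$$\dfrac{d}{dt}\begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix} = \dfrac{da(t)}{dt} = - \begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z}, & a_z(t)\dfrac{\partial f}{\partial \theta}, & a_z(t)\dfrac{\partial f}{\partial t} \end{bmatrix}$$
Hence, reading off the second component:
$$\dfrac{d}{dt}\left(\dfrac{dL}{d\theta(t)}\right) = -a_z(t)\dfrac{\partial f}{\partial \theta}$$
Let $t_0$ be the initial time and $t_1$ the final time. Since $L$ depends on $\theta$ only through the trajectory, perturbing $\theta$ at the final time has no effect on $L$, so $a_\theta(t_1) = 0$. Integrating backward from $t_1$ to $t_0$ then gives
$$\dfrac{dL}{d\theta} = a_\theta(t_0) = -\int_{t_1}^{t_0}{a_z(t)\dfrac{\partial f}{\partial \theta}\,dt}$$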

QED
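The formula for $\frac{dL}{d\theta}$ can be turned into a small adjoint solver. The following is a minimal sketch, not the paper's implementation: it assumes the hypothetical toy problem $f = \theta z$ with $L = z(t_1)$ (so $a_z(t_1) = 1$), uses a fixed-step classical Runge-Kutta (RK4) integrator with an arbitrarily chosen number of steps, and re-integrates $z$ backward alongside $a_z$ and $a_\theta$ instead of storing the forward trajectory. For this problem $\frac{dL}{d\theta} = z_0 (t_1 - t_0) e^{\theta (t_1 - t_0)}$, which gives an exact value to compare against. All function names are illustrative.

```python
import math

# Hypothetical toy problem: f(z, theta, t) = theta * z, loss L = z(t1).
def f(z, theta, t):
    return theta * z

def df_dz(z, theta, t):
    return theta

def df_dtheta(z, theta, t):
    return z

def dynamics(s, theta, t):
    """Time derivative of s = [z, a_z, a_theta]; the adjoint rows follow da/dt = -a_z df/d(.)."""
    z, a_z, _ = s
    return [f(z, theta, t),
            -a_z * df_dz(z, theta, t),
            -a_z * df_dtheta(z, theta, t)]

def rk4_step(s, theta, t, h):
    """One classical Runge-Kutta step of size h (h may be negative)."""
    def shifted(k, c):
        return [si + c * ki for si, ki in zip(s, k)]
    k1 = dynamics(s, theta, t)
    k2 = dynamics(shifted(k1, h / 2), theta, t + h / 2)
    k3 = dynamics(shifted(k2, h / 2), theta, t + h / 2)
    k4 = dynamics(shifted(k3, h), theta, t + h)
    return [si + h / 6 * (p + 2 * q + 2 * r + w)
            for si, p, q, r, w in zip(s, k1, k2, k3, k4)]

def integrate(s, theta, t_start, t_end, n=1000):
    """Fixed-step RK4 from t_start to t_end (backward if t_end < t_start)."""
    h = (t_end - t_start) / n
    for i in range(n):
        s = rk4_step(s, theta, t_start + i * h, h)
    return s

def grad_theta_adjoint(z0, theta, t0, t1):
    # Forward pass: only z evolves; the adjoint slots start and stay at zero.
    z1 = integrate([z0, 0.0, 0.0], theta, t0, t1)[0]
    # Backward pass from t1 to t0 with a_z(t1) = dL/dz(t1) = 1 and a_theta(t1) = 0.
    _, _, a_theta_t0 = integrate([z1, 1.0, 0.0], theta, t1, t0)
    return a_theta_t0

z0, theta, t0, t1 = 1.5, 0.7, 0.0, 2.0
adjoint = grad_theta_adjoint(z0, theta, t0, t1)
exact = z0 * (t1 - t0) * math.exp(theta * (t1 - t0))  # dL/dtheta in closed form
print(adjoint, exact)
```

Re-integrating $z$ backward avoids storing the forward trajectory, at the cost of possible numerical drift for stiff or chaotic dynamics; for this well-behaved linear toy problem the two values agree closely.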