Proof of Neural Ordinary Differential Equations’ augmented adjoint state

In this blog post, my objective is to clarify the proofs in Appendix B.2 of this paper. This is a follow-up to my previous article about Neural ODEs, which detailed the maths but without many proofs. Here, I will rewrite the proof of the derivative of the augmented adjoint state, that is, the gradients of the loss L w.r.t. z(t), \theta(t), and t.

Let

  1. f(z(t), \theta(t), t) = \dfrac{dz(t)}{dt}
  2. Z(t) = \begin{bmatrix} z(t), & \theta(t), & t \end{bmatrix}, the augmented state
  3. a(t) = \dfrac{dL}{dZ(t)} = \begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix} = \begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix}
  4. g(Z(t), t) = \dfrac{dZ(t)}{dt} = \begin{bmatrix} \dfrac{dz(t)}{dt}, & \dfrac{d\theta(t)}{dt}, & \dfrac{dt}{dt} \end{bmatrix} = \begin{bmatrix} f, & 0_\textbf{E}, & 1 \end{bmatrix}, where 0_\textbf{E} is a zero block of the same size as \theta (the parameters are constant along the trajectory)
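
To make definition 4 concrete, here is a minimal Python sketch of the augmented dynamics g (my own toy illustration, not code from the paper): the parameter block gets a null derivative and time advances at rate 1.

```python
import numpy as np

def f(z, theta, t):
    # Toy dynamics assumed purely for illustration: dz/dt = tanh(theta * z).
    return np.tanh(theta * z)

def g(Z, dim_z, dim_theta):
    # Augmented dynamics dZ/dt = [f, 0_E, 1]: unpack Z = [z, theta, t],
    # keep theta constant (null derivative) and let t advance at rate 1.
    z = Z[:dim_z]
    theta = Z[dim_z:dim_z + dim_theta]
    t = Z[-1]
    return np.concatenate([f(z, theta, t), np.zeros(dim_theta), [1.0]])

# Example with dim_z = dim_theta = 2, so Z has five components.
Z = np.array([0.5, -0.3, 1.0, 2.0, 0.0])
print(g(Z, 2, 2))  # the theta entries are 0.0 and the last entry is 1.0
```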

Then, we can define
Z(t+\varepsilon) = Z(t) + \int_{t}^{t+\varepsilon}{g(Z(s), s)\,ds} = T_\varepsilon(Z(t), t)
Notice that, by the chain rule,
a(t) = \dfrac{dL}{dZ(t)} = \dfrac{dL}{dZ(t+\varepsilon)}\dfrac{dZ(t+\varepsilon)}{dZ(t)} = a(t+\varepsilon)\dfrac{\partial T_\varepsilon(Z(t), t)}{\partial Z(t)}
Thus, we can compute the time derivative of a(t):

\dfrac{da(t)}{dt} = \lim_{\varepsilon \rightarrow 0^{+}} \dfrac{a(t+\varepsilon)-a(t)}{\varepsilon}
\\= \lim_{\varepsilon \rightarrow 0^{+}} \dfrac{a(t+\varepsilon)-a(t+\varepsilon) \dfrac{\partial}{\partial Z(t)} T_{\varepsilon}(Z(t), t)}{\varepsilon}
\\= \lim_{\varepsilon \rightarrow 0^{+}} \dfrac{a(t+\varepsilon)-a(t+\varepsilon) \dfrac{\partial}{\partial Z(t)}\left(Z(t)+\varepsilon g(Z(t), t)+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon}
\\= \lim_{\varepsilon \rightarrow 0^{+}} \dfrac{a(t+\varepsilon)-a(t+\varepsilon)\left(I+\varepsilon \dfrac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon}
\\= \lim_{\varepsilon \rightarrow 0^{+}} \dfrac{-\varepsilon a(t+\varepsilon) \dfrac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon^2)}{\varepsilon}
\\= \lim_{\varepsilon \rightarrow 0^{+}} -a(t+\varepsilon) \dfrac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon)
\\= -a(t) \dfrac{\partial g(Z(t), t)}{\partial Z(t)}
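
As a quick numerical sanity check of the Taylor step above (a toy example of my own, not from the paper), one can verify by finite differences that the Jacobian of the flow satisfies \dfrac{\partial T_\varepsilon}{\partial Z} = I + \varepsilon\dfrac{\partial g}{\partial Z} + \mathcal{O}(\varepsilon^2):

```python
import numpy as np

def g(Z):
    # Toy augmented dynamics with Z = [z, theta, t]: dz/dt = tanh(theta * z).
    z, theta, t = Z
    return np.array([np.tanh(theta * z), 0.0, 1.0])

def T(Z, eps, steps=100):
    # Approximate the exact flow T_eps with many small Euler steps.
    h = eps / steps
    for _ in range(steps):
        Z = Z + h * g(Z)
    return Z

def jacobian(fun, Z, h=1e-6):
    # Central finite differences; column j approximates d(fun)/d(Z_j).
    cols = []
    for j in range(Z.size):
        e = np.zeros_like(Z)
        e[j] = h
        cols.append((fun(Z + e) - fun(Z - e)) / (2 * h))
    return np.stack(cols, axis=1)

Z0, eps = np.array([0.5, 0.8, 0.0]), 1e-3
lhs = jacobian(lambda Z: T(Z, eps), Z0)
rhs = np.eye(3) + eps * jacobian(g, Z0)
print(np.max(np.abs(lhs - rhs)))  # small: the O(eps^2) remainder
```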

As you can see, the proof of Appendix B.1 holds for any state vector Z. What’s more interesting is how we can exploit this to get \dfrac{da_\theta(t)}{dt}:

\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix}\dfrac{\partial f(z(t), \theta(t), t)}{\partial z} & \dfrac{\partial f(z(t), \theta(t), t)}{\partial \theta} & \dfrac{\partial f(z(t), \theta(t), t)}{\partial t} \\ \dfrac{\partial (d\theta/dt)}{\partial z} & \dfrac{\partial (d\theta/dt)}{\partial \theta} & \dfrac{\partial (d\theta/dt)}{\partial t} \\ \dfrac{\partial (dt/dt)}{\partial z} & \dfrac{\partial (dt/dt)}{\partial \theta} & \dfrac{\partial (dt/dt)}{\partial t} \end{pmatrix}
But as mentioned in the paper, \dfrac{d\theta(t)}{dt} = 0_\textbf{E} and \dfrac{dt}{dt} = 1, so the second and third rows of the Jacobian are null (derivatives of constants). Hence:
= -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix} \dfrac{\partial f}{\partial z} & \dfrac{\partial f}{\partial \theta} & \dfrac{\partial f}{\partial t} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \end{pmatrix} \\= -\begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{\partial f}{\partial \theta} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{\partial f}{\partial t} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E} \end{bmatrix} \\= -\begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z}, & a_z(t)\dfrac{\partial f}{\partial \theta}, & a_z(t)\dfrac{\partial f}{\partial t} \end{bmatrix}
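
This block structure is easy to check numerically. Reusing the same toy g as before (again my own illustration, not the paper’s code), the finite-difference Jacobian of g has a single non-null row, so the product a(t) \cdot \dfrac{\partial g}{\partial Z} only ever involves a_z(t):

```python
import numpy as np

def g(Z):
    # Toy augmented dynamics with Z = [z, theta, t]: dz/dt = tanh(theta * z).
    z, theta, t = Z
    return np.array([np.tanh(theta * z), 0.0, 1.0])

def jacobian(fun, Z, h=1e-6):
    # Central finite differences; column j approximates d(fun)/d(Z_j).
    cols = []
    for j in range(Z.size):
        e = np.zeros_like(Z)
        e[j] = h
        cols.append((fun(Z + e) - fun(Z - e)) / (2 * h))
    return np.stack(cols, axis=1)

Z = np.array([0.5, 0.8, 0.0])
J = jacobian(g, Z)
print(np.allclose(J[1:], 0.0))          # True: the theta and t rows are null
a = np.array([2.0, -1.0, 0.5])          # an arbitrary [a_z, a_theta, a_t]
print(np.allclose(a @ J, a[0] * J[0]))  # True: only a_z enters the product
```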

Now, we’ve got:
\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t)\dfrac{\partial f}{\partial z}, & a_z(t)\dfrac{\partial f}{\partial \theta}, & a_z(t)\dfrac{\partial f}{\partial t} \end{bmatrix} = \dfrac{d}{dt}\begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix}
Hence, reading off the \theta component and integrating backwards from t_1 to t_0 with the initial condition a_\theta(t_1) = 0 (as in the paper):
\dfrac{d}{dt}\left(\dfrac{dL}{d\theta(t)}\right) = -a_z(t)\dfrac{\partial f}{\partial \theta} \implies \dfrac{dL}{d\theta(t_0)} = -\int_{t_1}^{t_0}{a_z(t)\dfrac{\partial f}{\partial \theta}dt}

QED
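
As a final sanity check, outside the proof proper, here is a small sketch that integrates this formula numerically on a toy problem of my own: f(z, \theta) = \theta z with L = z(t_1), whose analytic gradient is z_0 (t_1 - t_0) e^{\theta(t_1 - t_0)}. The solver comes from SciPy; everything else is assumed for illustration.

```python
import numpy as np
from scipy.integrate import solve_ivp

theta, z0, t0, t1 = 0.7, 1.3, 0.0, 2.0

# Forward pass: solve dz/dt = f(z, theta) = theta * z from t0 to t1.
fwd = solve_ivp(lambda t, z: theta * z, (t0, t1), [z0],
                dense_output=True, rtol=1e-10, atol=1e-12)

# Backward pass from t1 down to t0 on the state [a_z, accumulated dL/dtheta]:
#   da_z/dt         = -a_z * df/dz     = -a_z * theta
#   d(dL/dtheta)/dt = -a_z * df/dtheta = -a_z * z(t)
def adjoint(t, s):
    a_z, _ = s
    z = fwd.sol(t)[0]
    return [-a_z * theta, -a_z * z]

# L = z(t1), so a_z(t1) = dL/dz(t1) = 1; the accumulator starts at a_theta(t1) = 0.
bwd = solve_ivp(adjoint, (t1, t0), [1.0, 0.0], rtol=1e-10, atol=1e-12)

grad_adjoint = bwd.y[1, -1]
grad_analytic = z0 * (t1 - t0) * np.exp(theta * (t1 - t0))
print(grad_adjoint, grad_analytic)  # the two agree to solver tolerance
```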

The PDF version of this post (with better formatting) is available at this link.