Proof of Neural Ordinary Differential Equation’s augmented state

In this blog post, my objective is to clarify the proofs of Appendix B.2 of this paper. This is a follow-up to my previous article about Neural ODEs, which detailed the maths but without many proofs. Here, I will rewrite the proof of the derivative of the augmented adjoint state, that is, of the gradient of the loss L w.r.t. z(t), \theta(t), and t.

We use the following notation:

  1. f(z(t), \theta(t), t) = \dfrac{dz(t)}{dt}
  2. Z(t) = \begin{bmatrix} z(t), & \theta(t), & t \end{bmatrix}
  3. a(t) = \dfrac{dL}{dZ(t)} = \begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix} = \begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix}
  4. g(Z(t), t) = \dfrac{dZ(t)}{dt} = \begin{bmatrix} \dfrac{dz(t)}{dt}, & \dfrac{d\theta(t)}{dt}, & \dfrac{dt}{dt} \end{bmatrix} = \begin{bmatrix} f, & 0_\textbf{E}, & 1 \end{bmatrix} where 0_\textbf{E} is a zero matrix of appropriate size (the parameters \theta do not change with time)
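To make the augmented state concrete, here is a small NumPy sketch of g. It assumes a hypothetical toy dynamics f(z, \theta, t) = \tanh(\theta \odot z) + \sin(t) with one parameter per state dimension; the flat layout [z, \theta, t] of Z is also an illustrative choice of mine, not the paper's.

```python
import numpy as np

def f(z, theta, t):
    # Hypothetical toy dynamics: element-wise "layer" with an explicit time term.
    return np.tanh(theta * z) + np.sin(t)

def g(Z, n):
    """Augmented dynamics dZ/dt = [f, 0_E, 1] for the flat vector Z = [z, theta, t].

    The first n entries of Z are z(t), the next n are theta(t), the last one is t.
    """
    z, theta, t = Z[:n], Z[n:2 * n], Z[2 * n]
    return np.concatenate([
        f(z, theta, t),         # dz/dt = f(z(t), theta(t), t)
        np.zeros_like(theta),   # dtheta/dt = 0_E: the parameters are constant
        [1.0],                  # dt/dt = 1
    ])

n = 3
Z = np.array([0.5, -1.0, 2.0,    # z(t)
              1.0, 0.3, -0.7,    # theta(t)
              0.25])             # t
print(g(Z, n))
```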

Then, we can define
Z(t+\varepsilon) = Z(t) + \int_{t}^{t+\varepsilon}{g(Z(s), s)ds} = T_\varepsilon(Z(t), t)
Notice that, by the chain rule,
a(t) = \dfrac{dL}{dZ(t)} = \dfrac{dL}{dZ(t+\varepsilon)}\dfrac{dZ(t+\varepsilon)}{dZ(t)} = a(t+\varepsilon)\dfrac{\partial T_\varepsilon(Z(t), t)}{\partial Z(t)}
Thus we can compute the time derivative of a(t):

\frac{d a(t)}{d t} = \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t)}{\varepsilon} \\= \lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon) \frac{\partial}{\partial Z(t)} T_{\varepsilon}(Z(t), t)}{\varepsilon} \\=\lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon) \frac{\partial}{\partial Z(t)}\left(Z(t)+\varepsilon g(Z(t), t)+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon}\\=\lim_{\varepsilon \rightarrow 0^{+}} \frac{a(t+\varepsilon)-a(t+\varepsilon)\left(I+\varepsilon \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}\left(\varepsilon^{2}\right)\right)}{\varepsilon}\\=\lim_{\varepsilon \rightarrow 0^{+}} \frac{-\varepsilon a(t+\varepsilon) \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon^2)}{\varepsilon}\\=\lim_{\varepsilon \rightarrow 0^{+}} -a(t+\varepsilon) \frac{\partial g(Z(t), t)}{\partial Z(t)}+\mathcal{O}(\varepsilon)\\=-a(t) \frac{\partial g(Z(t), t)}{\partial Z(t)}
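A quick numerical sanity check of this result, under the hypothetical assumption that the augmented dynamics are linear, g(Z, t) = A Z. Then T_\varepsilon(Z) = e^{\varepsilon A} Z exactly and the adjoint has the closed form a(t) = a(t_1) e^{A(t_1 - t)}, so both the chain-rule identity and the final formula can be compared against numbers. The matrix exponential is a plain Taylor-series helper written for this sketch (scipy.linalg.expm would be the usual choice).

```python
import numpy as np

def expm(M, terms=40):
    """Matrix exponential via its Taylor series (fine for small matrices of modest norm)."""
    result = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms):
        term = term @ M / k
        result = result + term
    return result

rng = np.random.default_rng(0)
A = 0.5 * rng.normal(size=(4, 4))   # hypothetical linear dynamics g(Z, t) = A Z
c = rng.normal(size=4)              # a(t1) = dL/dZ(t1) for the loss L = c . Z(t1)
t1 = 1.0

def a(t):
    # For g = A Z, T_eps(Z) = expm(eps A) Z, so chaining dT_eps/dZ back from t1 gives
    # a(t) = a(t1) expm(A (t1 - t)) (row vector times matrix).
    return c @ expm(A * (t1 - t))

t, eps, h = 0.3, 1e-3, 1e-5
# The identity used in the proof: a(t) = a(t + eps) dT_eps/dZ(t)
step_identity_error = np.max(np.abs(a(t) - a(t + eps) @ expm(eps * A)))
# The result of the proof: da/dt = -a(t) dg/dZ, with dg/dZ = A here
da_dt_fd = (a(t + h) - a(t - h)) / (2 * h)   # central finite difference
da_dt_formula = -a(t) @ A
print(step_identity_error, np.max(np.abs(da_dt_fd - da_dt_formula)))
```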

As you can see, the proof of Appendix B.1 holds for any vector Z, including our augmented state. What’s more interesting is the way we can exploit this to get \dfrac{da_\theta(t)}{dt}. Writing \dfrac{\partial g(Z(t), t)}{\partial Z(t)} block by block (one row per component of g, one column per component of Z), the result above becomes:

\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix}\dfrac{df(z(t), \theta(t), t)}{dz} & \dfrac{df(z(t), \theta(t), t)}{d\theta} & \dfrac{df(z(t), \theta(t), t)}{dt} \\ \dfrac{d(d\theta/dt)}{dz} & \dfrac{d(d\theta/dt)}{d\theta} & \dfrac{d(d\theta/dt)}{dt} \\ \dfrac{d(dt/dt)}{dz} & \dfrac{d(dt/dt)}{d\theta} & \dfrac{d(dt/dt)}{dt} \end{pmatrix}
But as mentioned in the paper, \dfrac{d\theta(t)}{dt} = 0 and \dfrac{dt}{dt} = 1, which means that the second and third rows are zero (derivatives of constants). Hence:
\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix} \dfrac{df}{dz} & \dfrac{df}{d\theta} & \dfrac{df}{dt} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \end{pmatrix} \\= -\begin{bmatrix}a_z(t)\dfrac{df}{dz} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{df}{d\theta} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{df}{dt} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E} \end{bmatrix} \\= -\begin{bmatrix}a_z(t)\dfrac{df}{dz}, & a_z(t)\dfrac{df}{d\theta}, & a_z(t)\dfrac{df}{dt}\end{bmatrix}
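The block structure can be checked numerically. This sketch reuses the hypothetical toy dynamics f(z, \theta, t) = \tanh(\theta \odot z) + \sin(t) (chosen only because its Jacobians are diagonal and easy to write by hand), builds the full Jacobian of g by finite differences, and compares -a(t)\dfrac{\partial g}{\partial Z} with the three a_z(t) products; a_\theta and a_t drop out, as in the derivation.

```python
import numpy as np

def f(z, theta, t):
    # Hypothetical toy dynamics: tanh(theta * z) + sin(t), element-wise.
    return np.tanh(theta * z) + np.sin(t)

def g(Z, n):
    # Augmented dynamics [f, 0_E, 1] for the flat state Z = [z, theta, t].
    z, theta, t = Z[:n], Z[n:2 * n], Z[2 * n]
    return np.concatenate([f(z, theta, t), np.zeros(n), [1.0]])

def jacobian_fd(func, x, h=1e-6):
    """Central finite-difference Jacobian: J[i, j] ~ d func_i / d x_j."""
    fx = func(x)
    J = np.zeros((fx.size, x.size))
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = h
        J[:, j] = (func(x + e) - func(x - e)) / (2 * h)
    return J

rng = np.random.default_rng(1)
n = 3
z, theta, t = rng.normal(size=n), rng.normal(size=n), 0.4
Z = np.concatenate([z, theta, [t]])
a = rng.normal(size=2 * n + 1)          # a(t) = [a_z, a_theta, a_t]
a_z = a[:n]

# Left-hand side: -a(t) dg/dZ with the full (2n+1) x (2n+1) Jacobian
da_dt = -a @ jacobian_fd(lambda Zv: g(Zv, n), Z)

# Right-hand side: only a_z survives, multiplied by df/dz, df/dtheta and df/dt
s = np.tanh(theta * z)
expected = -np.concatenate([
    a_z * (1 - s ** 2) * theta,          # a_z df/dz     (df/dz is diagonal here)
    a_z * (1 - s ** 2) * z,              # a_z df/dtheta (also diagonal here)
    [np.cos(t) * a_z.sum()],             # a_z df/dt     (df/dt = cos(t) for every entry)
])
print(np.max(np.abs(da_dt - expected)))
```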

Now, we’ve got:
\dfrac{da(t)}{dt} = - \begin{bmatrix} a_z(t)\dfrac{df}{dz}, & a_z(t)\dfrac{df}{d\theta}, & a_z(t)\dfrac{df}{dt} \end{bmatrix} = \dfrac{d}{dt}\begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix}
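\dfrac{d}{dt}\left(\dfrac{dL}{d\theta(t)}\right) = -a_z(t)\dfrac{df}{d\theta}\\ \implies \dfrac{dL}{d\theta(t_0)} = \dfrac{dL}{d\theta(t_1)} - \int_{t_1}^{t_0}{a_z(t)\dfrac{df}{d\theta}dt} = -\int_{t_1}^{t_0}{a_z(t)\dfrac{df}{d\theta}dt}

where the last equality uses the initial condition \dfrac{dL}{d\theta(t_1)} = 0 set in the paper (Appendix B.2, between eqs. 50 and 51).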
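To see this integral at work, here is a minimal check on a hypothetical scalar example where everything is known in closed form: \dfrac{dz}{dt} = f(z, \theta, t) = \theta z + t with the loss L = z(t_1), so that a_z(t) = e^{\theta(t_1 - t)} and \dfrac{df}{d\theta} = z(t). The integral is evaluated with the trapezoidal rule and compared with a finite-difference derivative of z(t_1) with respect to \theta.

```python
import numpy as np

theta, z0, t0, t1 = 0.7, 0.5, 0.0, 1.0   # hypothetical values

def z_of_t(t, theta):
    # Closed-form solution of dz/dt = theta * z + t with z(t0) = z0
    C = z0 + t0 / theta + 1 / theta ** 2
    return C * np.exp(theta * (t - t0)) - t / theta - 1 / theta ** 2

def a_z(t):
    # a_z(t) = dL/dz(t) for L = z(t1): here dz(t1)/dz(t) = exp(theta (t1 - t))
    return np.exp(theta * (t1 - t))

# dL/dtheta(t0) = -int_{t1}^{t0} a_z(t) df/dtheta dt = int_{t0}^{t1} a_z(t) z(t) dt,
# since df/dtheta = z(t) for f = theta * z + t (trapezoidal rule below).
ts = np.linspace(t0, t1, 2001)
vals = a_z(ts) * z_of_t(ts, theta)
dL_dtheta_adjoint = np.sum((vals[1:] + vals[:-1]) * np.diff(ts)) / 2

# Reference: differentiate L(theta) = z(t1; theta) directly by central differences
h = 1e-6
dL_dtheta_fd = (z_of_t(t1, theta + h) - z_of_t(t1, theta - h)) / (2 * h)
print(dL_dtheta_adjoint, dL_dtheta_fd)
```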


The PDF version of this is available at this link (with better formatting)

Neural Ordinary Differential Equations

This paper is not for the mathematically faint-hearted, so I’d advise a quick review of linear algebra and multivariable differential calculus beforehand. To follow along with the review, you can play with this amazing visualisation. The goal of this article is to get data scientists and programmers up to speed with the mathematical background needed to understand this paper, not to merely restate what the paper says.


Layer paradigm

This paper presents a novel view of the whole neural network paradigm.
Up to that point, networks mainly consisted of assembled layers (sequentially or not), whose outputs are denoted h_t. These layers are in reality functions depending on parameters (denoted \theta_t) that influence their output in diverse ways. You’ve hopefully heard of the classical Dense/Logistic/Linear/Fully Connected layers, which are usually followed by a non-linear activation function (sigmoid, tanh, ReLU and many others), but also of the convolutional layers used in CNNs.

In classical networks, these assembled layers are trained using gradient descent (and its variants), an algorithm that tries to minimise the loss function (denoted L(h_{t+T})) by tuning the parameters of these layers.

To do that, it first computes the derivative of the loss with respect to each intermediate layer’s output {h_t, h_{t+1}, h_{t+2}, ..., h_{t+T}}, starting from the last one (\dfrac{dL(h_{t+T})}{dh_{t+T}}) and moving backwards, and then the derivative of the loss with respect to each parameter. With the convention h_{t+1} = f(h_t, \theta_t) used below, the last layer’s parameters give \dfrac{dL(h_{t+T})}{d\theta_{t+T-1}} = \dfrac{dL(h_{t+T})}{dh_{t+T}} \dfrac{dh_{t+T}}{d\theta_{t+T-1}}. This is called back propagation (a.k.a. reverse-mode differentiation). Gradient descent then updates each parameter by a small step in the direction opposite to its derivative.
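A minimal sketch of back propagation, assuming hypothetical scalar layers h_{k+1} = \tanh(\theta_k h_k) and the loss L = h_T^2/2 (neither comes from the paper). The backward loop walks from the last layer to the first, reusing the outputs stored during the forward pass, and applies the chain-rule products described above.

```python
import numpy as np

def forward(h0, thetas):
    """Forward pass through hypothetical scalar layers h_{k+1} = tanh(theta_k * h_k)."""
    hs = [h0]
    for theta in thetas:
        hs.append(np.tanh(theta * hs[-1]))
    return hs

def backward(hs, thetas):
    """Back propagation for the loss L = h_T**2 / 2, from the last layer to the first."""
    grad_h = hs[-1]                               # dL/dh_T
    grads_theta = [0.0] * len(thetas)
    for k in reversed(range(len(thetas))):
        local = 1 - hs[k + 1] ** 2                # tanh'(theta_k h_k), from the stored output
        grads_theta[k] = grad_h * local * hs[k]   # dL/dtheta_k = dL/dh_{k+1} * dh_{k+1}/dtheta_k
        grad_h = grad_h * local * thetas[k]       # dL/dh_k     = dL/dh_{k+1} * dh_{k+1}/dh_k
    return grad_h, grads_theta

thetas = [0.8, -1.2, 0.5]
hs = forward(0.9, thetas)
grad_h0, grads_theta = backward(hs, thetas)
print(grads_theta)
```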

To make predictions using those layers, we perform forward propagation, which simply applies each layer’s operation in turn, starting from our input, until we reach the output. Formally, for one layer (the function f) with tuned parameters \theta_t, this can be written as:

h_{t+1} = f(h_t, \theta_t)

However, a new kind of convolutional neural network recently outperformed the state of the art significantly by adding a simple little tweak: bypassing the layers. The idea is as follows: once a sufficient depth is reached and the network judges it needn’t go deeper, it should be able to bypass itself and skip the useless layers. To make this easy, they simply added each layer’s input to its output (a so-called skip connection), so the layer itself only has to learn the residual f(h_t, \theta_t) = h_{t+1} - h_t; hence the name Residual Neural Networks.
Mathematically, that tweak looks like this:

h_{t+1} = \bm{h_t} + f(h_t, \theta_t)
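A minimal NumPy sketch contrasting the two update rules, assuming a hypothetical \tanh(W h) layer. When every layer’s weights are zero (the layers have learnt nothing useful), the plain network destroys the signal, while the residual network simply passes its input through, which is the “skip useless layers” behaviour described above.

```python
import numpy as np

def layer(h, W):
    return np.tanh(W @ h)            # hypothetical f(h_t, theta_t) with theta_t = W

def plain_net(h, weights):
    for W in weights:
        h = layer(h, W)              # h_{t+1} = f(h_t, theta_t)
    return h

def residual_net(h, weights):
    for W in weights:
        h = h + layer(h, W)          # h_{t+1} = h_t + f(h_t, theta_t)
    return h

h0 = np.array([1.0, -2.0, 0.5])
useless = [np.zeros((3, 3))] * 4     # four layers with all-zero weights
print(plain_net(h0, useless))        # the signal is destroyed: all zeros
print(residual_net(h0, useless))     # the input passes through unchanged
```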

Differential Equations

To dive nose first into the mathematical requirements, let’s talk a bit about differential equations. Differential equations are mathematical objects whose study dates back to the invention of calculus itself (chapter 2 of Newton’s 1671 work “Methodus fluxionum et Serierum Infinitarum”). They are equations relating a function to its derivative(s). They allow physicists to describe how complex dynamical systems evolve through time, like the Navier-Stokes equations, which let us model the weather somewhat accurately.

Formally, what we call a linear (homogeneous) Ordinary Differential Equation (ODE) of the n-th order is an equation of the following form:

a_1(t)z(t) + a_2(t)\dfrac{dz(t)}{dt} + a_3(t)\dfrac{d^{2}z(t)}{dt^2} + ... + a_{n+1}(t)\dfrac{d^{n}z(t)}{dt^n} = 0

This equation can be rewritten in the more digestible and familiar form:

\dfrac{dz(t)}{dt} = f(z(t), t)

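(Simply isolating \dfrac{dz(t)}{dt} is not enough, since the remaining terms still contain higher derivatives of z. Instead, we stack z and its first n-1 derivatives into a state vector \left(z(t), \dfrac{dz(t)}{dt}, \dots, \dfrac{d^{n-1}z(t)}{dt^{n-1}}\right): the derivative of each component is the next one, and the derivative of the last one is obtained by solving the equation for \dfrac{d^{n}z(t)}{dt^n}. Calling this state z(t) again gives the form above, but that’s beside the point.)

Notice that this form doesn’t care about the ODE’s order: with the same stacking trick, any ODE, linear or not, can be written that way, as long as it can be solved for its highest derivative!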
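A minimal sketch of the stacking trick, assuming constant coefficients for simplicity: the coefficients of z, \dfrac{dz}{dt}, ..., \dfrac{d^{n}z}{dt^n} are passed as a list, the last one non-zero, and the helper (a hypothetical name) builds f for the state (z, \dfrac{dz}{dt}, ..., \dfrac{d^{n-1}z}{dt^{n-1}}). For z'' + z = 0, whose solution is z(t) = \cos(t), plugging the exact state into f returns its exact derivative.

```python
import numpy as np

def first_order_system(coeffs):
    """Hypothetical helper: turn c_0 z + c_1 z' + ... + c_n z^(n) = 0 (constant
    coefficients, c_n != 0) into dZ/dt = f(Z, t) with state Z = (z, z', ..., z^(n-1))."""
    coeffs = np.asarray(coeffs, dtype=float)
    def f(Z, t):
        highest = -(coeffs[:-1] @ Z) / coeffs[-1]   # solve the equation for z^(n)
        return np.append(Z[1:], highest)            # (z', ..., z^(n-1), z^(n))
    return f

f = first_order_system([1.0, 0.0, 1.0])   # z'' + z = 0, solved by z(t) = cos(t)
t = 0.8
Z = np.array([np.cos(t), -np.sin(t)])     # state (z, z') of the exact solution
print(f(Z, t))                            # equals the exact derivative (-sin t, -cos t)
```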

ODEs also allow us to predict the future state(s) of a system given its current state. This kind of problem is called an Initial Value Problem and is very well studied in mathematics. A fundamental theorem for this kind of problem is the Picard-Lindelöf theorem (or Cauchy-Lipschitz theorem, as I was taught). It tells us that if f(z(t), t) is continuous in t and Lipschitz continuous in z, then for a given initial state there exists a unique solution, at least on some time interval around the initial time.

That’s all well and good, but what do ODEs have to do with neural networks?

That’s what you might be wondering at this point. The authors address that question at the very beginning of their work, but understanding the link requires understanding both ODEs and residual neural networks as described above.

There are two ways to solve an ODE: analytically and numerically. The former is what you’re usually taught in college: you use calculus and arithmetic tricks to get an analytical solution to the equation that you can then write on your test to get points. It usually resembles z(t) = e^{ut}(A\cos(vt) + B\sin(vt)), but sometimes reality is messy and maths can’t describe it exactly in a few symbols. In that case we’re forced to use the latter: we compute a numerical solution, which is essentially a list of values of z(t) at the times you are interested in. To do that, we use solvers (Euler’s method, Runge-Kutta methods and their adaptive variants) that simulate the passing of time from your initial condition in small steps, evaluating f at each step to estimate the next value.
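As an illustration of the numerical route, here is the simplest solver, explicit Euler (chosen for readability, not accuracy): it repeatedly steps z \leftarrow z + \Delta t \, f(z, t). On \dfrac{dz}{dt} = -z, whose analytical solution is z(t) = z_0 e^{-t}, the error at t = 1 shrinks roughly tenfold each time the number of steps is multiplied by ten.

```python
import numpy as np

def euler_solve(f, z0, t0, t1, n_steps):
    """Explicit Euler: repeatedly step z <- z + dt * f(z, t) from t0 to t1."""
    dt = (t1 - t0) / n_steps
    z, t = z0, t0
    for _ in range(n_steps):
        z = z + dt * f(z, t)
        t = t + dt
    return z

def decay(z, t):
    return -z                 # dz/dt = -z, exact solution z(t) = z0 * exp(-t)

exact = np.exp(-1.0)
for n_steps in (10, 100, 1000):
    approx = euler_solve(decay, 1.0, 0.0, 1.0, n_steps)
    print(n_steps, abs(approx - exact))
```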

The connection between that and neural networks is the following:

z(t+1) = z(t) + f(z(t), \theta(t)) \\ \Leftrightarrow \dfrac{z(t+1) - z(t)}{1} = f(z(t), \theta(t)) \\ \leadsto \dfrac{dz(t)}{dt} = f(z(t), \theta(t))

Here, the derivative arises from its definition (the difference quotient):

\dfrac{dz(t)}{dt} = \lim_{\varepsilon\to 0}\frac{z(t+\varepsilon) - z(t)}{\varepsilon}

If we set \varepsilon = 1, we get the Residual Neural Network’s formula, where \frac{z(t+\varepsilon) - z(t)}{\varepsilon} = f(z(t), \theta(t)), but when \varepsilon \to 0, we get the derivative. We’re therefore left with a differential equation!
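The correspondence can be checked in a few lines, assuming a hypothetical shared layer f(z) = \tanh(W z) with fixed weights W: a residual block is exactly one explicit Euler step with \varepsilon = 1, and shrinking \varepsilon (while taking T/\varepsilon steps) makes the iterates converge towards the solution of the ODE.

```python
import numpy as np

W = 0.5 * np.array([[0.0, -1.0], [1.0, 0.0]])   # hypothetical shared layer weights

def f(z):
    return np.tanh(W @ z)

def residual_block(z):
    return z + f(z)                  # ResNet update: h_{t+1} = h_t + f(h_t, theta)

def euler_step(z, eps):
    return z + eps * f(z)            # z(t + eps) ~ z(t) + eps * dz/dt

def integrate(z, T, eps):
    for _ in range(int(round(T / eps))):
        z = euler_step(z, eps)
    return z

z0 = np.array([1.0, 0.0])
print(np.array_equal(residual_block(z0), euler_step(z0, 1.0)))   # True: same formula
for eps in (1.0, 0.1, 0.01, 0.001):
    print(eps, integrate(z0, 3.0, eps))   # iterates settle down as eps -> 0
```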

But the careful reader will notice that t is missing from the function’s arguments. Indeed, our layer still doesn’t know what time it is, so the derivative cannot depend explicitly on time. We can therefore give t to f(z(t), \theta(t)) to get the ODE that appears on the paper’s first page:

\dfrac{dz(t)}{dt} = f(z(t), t, \theta(t))

We changed h_t and \theta_t to z(t) and \theta(t) to signify that they’re not values at discrete timesteps t \in [\![0,T]\!] \subset \mathbb{N} but in a continuous range t \in [0 ,T] \subset \mathbb{R}.

This change of step from \varepsilon = 1 to \varepsilon \to 0 is illustrated in the paper’s Figure 1:


Figure 1: Left: A Residual network defines a discrete sequence of finite transformations. Right: An ODE network defines a vector field, which continuously transforms the state. Both: Circles represent evaluation locations.

With that realisation under our belt, we can start to devise strategies to optimise the way we train such neural networks.

How to get the gradients?

To perform gradient descent, we need two gradients: \dfrac{\partial L}{\partial z(t_0)} and \dfrac{\partial L}{\partial \theta(t_0)}. Normally, you’d backpropagate through the ODE solver’s operations to know how to tune \theta, but the authors have a better idea: let’s use the adjoint sensitivity method, a 1962 Lenin Prize of Science and Technology winning technique, to compute them with a memory cost that does not grow with the number of solver steps, and linear complexity!

Here is what they suggest. Suppose we have 3 values:

\left[\begin{matrix} z(t_1) & \dfrac{\partial L}{\partial z(t_1)} & \dfrac{\partial L}{\partial \theta(t_1)}\end{matrix}\right] (even though \dfrac{\partial L}{\partial \theta(t_1)} = 0, cf. Appendix B.2, between eqs. 50 and 51)

Since we know how z(t) evolves in time (f(z(t), t, \theta)), and how \dfrac{\partial L}{\partial z(t)} and \dfrac{\partial L}{\partial \theta(t)} evolve in time too (you will see how we find that later, I promise), we can give an ODE solver those 3 initial conditions (obtained at the end of the forward pass) and those 3 derivatives, and ask it to compute backwards in time what these 3 quantities were at t_0, that is:

\left[\begin{matrix} z(t_0) & \dfrac{\partial L}{\partial z(t_0)} & \dfrac{\partial L}{\partial \theta(t_0)}\end{matrix}\right]
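To make this backward solve concrete, here is a minimal NumPy sketch of the procedure, not the authors’ implementation. It assumes a hypothetical dynamics f(z(t), t, \theta) = \tanh(W z) with \theta = W, the loss L = \frac{1}{2}\|z(t_1)\|^2, a fixed-step RK4 solver instead of an adaptive one, and hand-written vector-Jacobian products instead of automatic differentiation. The augmented state starts from [z(t_1), \dfrac{\partial L}{\partial z(t_1)}, 0] and is integrated back to t_0 with the dynamics [f, -\mathbf{a}^{\top}\frac{\partial f}{\partial z}, -\mathbf{a}^{\top}\frac{\partial f}{\partial \theta}] stated in the next paragraphs; the result is then compared with finite differences of the forward solve.

```python
import numpy as np

def f(z, t, W):
    return np.tanh(W @ z)                        # hypothetical dynamics, theta = W

def rk4(func, y, t0, t1, n_steps):
    """Fixed-step classical RK4 from t0 to t1 (integrates backward when t1 < t0)."""
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = func(y, t)
        k2 = func(y + dt / 2 * k1, t + dt / 2)
        k3 = func(y + dt / 2 * k2, t + dt / 2)
        k4 = func(y + dt * k3, t + dt)
        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + dt
    return y

def loss(z1):
    return 0.5 * np.sum(z1 ** 2)                 # so dL/dz(t1) = z(t1)

def adjoint_gradients(z0, W, t0, t1, n_steps=200):
    n = z0.size
    # Forward pass: only the final state z(t1) is kept, no intermediate activations.
    z1 = rk4(lambda z, t: f(z, t, W), z0, t0, t1, n_steps)

    def augmented_dynamics(s, t):
        z, a = s[:n], s[n:2 * n]
        y = np.tanh(W @ z)
        local = a * (1 - y ** 2)                 # a^T diag(tanh'(W z))
        da = -local @ W                          # -a^T df/dz (vector-Jacobian product)
        da_W = -np.outer(local, z)               # -a^T df/dW (vector-Jacobian product)
        return np.concatenate([f(z, t, W), da, da_W.ravel()])

    # Augmented state at t1: [z(t1), dL/dz(t1), dL/dW(t1) = 0]
    s1 = np.concatenate([z1, z1, np.zeros(n * n)])
    s0 = rk4(augmented_dynamics, s1, t1, t0, n_steps)   # solve backward to t0
    return s0[n:2 * n], s0[2 * n:].reshape(n, n)        # dL/dz(t0), dL/dW

def finite_difference_grad(fun, x, h=1e-6):
    """Central finite differences of a scalar function, for checking only."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        e = np.zeros_like(x)
        e[idx] = h
        grad[idx] = (fun(x + e) - fun(x - e)) / (2 * h)
    return grad

rng = np.random.default_rng(0)
n = 3
z0, W = rng.normal(size=n), 0.5 * rng.normal(size=(n, n))
dz0, dW = adjoint_gradients(z0, W, 0.0, 1.0)

def forward_loss_z0(z_init):
    return loss(rk4(lambda z, t: f(z, t, W), z_init, 0.0, 1.0, 200))

def forward_loss_W(V):
    return loss(rk4(lambda z, t: f(z, t, V), z0, 0.0, 1.0, 200))

print(np.max(np.abs(dz0 - finite_difference_grad(forward_loss_z0, z0))))
print(np.max(np.abs(dW - finite_difference_grad(forward_loss_W, W))))
```

The two printed numbers measure the gap between the adjoint gradients and the finite-difference ones; they should be tiny, limited only by the solver’s step size and the finite-difference step.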

The adjoint method defines a function of time \mathbf{a}(t) = \dfrac{\partial L}{\partial z(t)} and the authors prove in Appendix B.1 that its derivative is given by:

\dfrac{d\mathbf{a}(t)}{dt} = - \mathbf {a}(t)^{\top}\frac{\partial f(z(t), t, \theta (t))}{\partial z}

The authors also generalise the proof in Appendix B.2 to define \mathbf{a_\theta}(t) = \dfrac{\partial L}{\partial \theta(t)}, with derivative \dfrac{d\mathbf{a_\theta}(t)}{dt} = - \mathbf {a}(t)^{\top}\frac{\partial f(z(t), t, \theta (t))}{\bm{\partial \theta}}

where, surprisingly, \mathbf{a}(t) is still the same function \mathbf{a}(t) = \dfrac{\partial L}{\partial z(t)} (this and the above derivative puzzled me at first; I rewrote the proof of Appendix B.2 for clarity in a separate post).

I won’t venture into the Continuous Normalizing Flows part of the paper here, since it would require an article of its own, but quite heavy maths are involved there as well. They are mostly well explained in this blog article series.

The most challenging parts of these last two derivatives are the terms \frac{\partial f(z(t), t, \theta (t))}{\partial z(t)} and \frac{\partial f(z(t), t, \theta (t))}{\partial \theta(t)}, since both are Jacobian matrices.

Jacobian matrices are what link differential calculus and linear algebra together: they locally characterise the behaviour of a vector-valued function, i.e. they give the linear transformation that best approximates the function around each point, telling us how a small change along one input dimension affects each and every output dimension. For a function with m outputs and n inputs, the Jacobian matrix is:

\frac{\partial f(z(t),t,\theta(t))}{\partial z(t)} = \left[\begin{array}{ccc} \frac{\partial f(z(t),t,\theta(t))_{1}}{\partial z(t)_{1}} & \cdots & \frac{\partial f(z(t),t,\theta(t))_{1}}{\partial z(t)_{n}} \\ \vdots & \ddots & \vdots \\ \frac{\partial f(z(t),t,\theta(t))_{m}}{\partial z(t)_{1}} & \cdots & \frac{\partial f(z(t),t,\theta(t))_{m}}{\partial z(t)_{n}} \end{array}\right] \in \mathbb{R}^{m \times n}

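This matters here because f is a vector-valued function, so both \frac{\partial f}{\partial z(t)} \in \mathbb{R}^{n_z \times n_z} and \frac{\partial f}{\partial \theta(t)} \in \mathbb{R}^{n_z \times n_\theta} are full matrices: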
f : \mathbb{R}^{n_z} \times \mathbb{R} \times \mathbb{R}^{n_\theta} \to \mathbb{R}^{n_z}
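A small sketch to make the shape convention concrete, assuming a hypothetical function from \mathbb{R}^3 to \mathbb{R}^2: each column of the numerical Jacobian below measures how one input dimension moves every output, so the result is m \times n = 2 \times 3.

```python
import numpy as np

def jacobian_fd(func, x, h=1e-6):
    """Numerical Jacobian: J[i, j] ~ d func(x)_i / d x_j, shape (m, n)."""
    fx = func(x)
    J = np.zeros((fx.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        J[:, j] = (func(x + step) - func(x - step)) / (2 * h)   # column j: effect of input j
    return J

def func(x):
    # Hypothetical R^3 -> R^2 function, so the Jacobian is 2 x 3
    return np.array([x[0] * x[1], np.sin(x[2])])

x = np.array([2.0, -1.0, 0.5])
J = jacobian_fd(func, x)
J_exact = np.array([[x[1], x[0], 0.0],
                    [0.0, 0.0, np.cos(x[2])]])
print(J.shape)   # (2, 3)
```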

Now that you’re familiar with Jacobian matrices, you understand why -\mathbf {a}(t)^{\top}\frac{\partial f(z(t), t, \theta (t))}{\partial z} is referred to in the paper as a Vector-Jacobian Product (named vjp in the code): a row vector multiplied by a Jacobian matrix. More information can be found in this amazing blog post.
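A short NumPy sketch of why this product is computed as a VJP rather than by forming the Jacobian, assuming a hypothetical dynamics f(z) = \tanh(W z). Both routes give the same vector, but the second never materialises the n \times n matrix, which is what reverse-mode automatic differentiation does for you (in PyTorch, for instance, torch.autograd.grad(outputs, inputs, grad_outputs=v) computes the VJP with v).

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
W = rng.normal(size=(n, n))          # hypothetical parameters of f(z) = tanh(W z)
z = rng.normal(size=n)
a = rng.normal(size=n)               # plays the role of a(t)

y = np.tanh(W @ z)
# Explicit route: build the full n x n Jacobian df/dz = diag(1 - y^2) W, then multiply.
jacobian = (1 - y ** 2)[:, None] * W
vjp_explicit = a @ jacobian
# VJP route: never materialise the Jacobian, just push a backwards through the operations.
vjp_direct = (a * (1 - y ** 2)) @ W
print(np.max(np.abs(vjp_explicit - vjp_direct)))
```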


OK, first things first: the authors provide two things for us right off the bat. Firstly, a pseudo-code example (Appendix C) detailing the high-level workings of their code. Secondly, in Appendix D, a rather large code snippet (implemented with Autograd in Python) which runs an ODE solver backward in time, starting from a given augmented state (a vector) containing \left[\begin{matrix} z(t_1) & \dfrac{\partial L}{\partial z(t_1)} & \dfrac{\partial L}{\partial \theta(t_1)}\end{matrix}\right], and returns a list of gradient vectors at the requested timesteps \left[\begin{matrix} \dfrac{\partial L}{\partial z(t_0)} & \dfrac{\partial L}{\partial \theta(t_0)}\end{matrix}\right]

Lastly and more importantly, the authors actively maintain a GitHub repo hosting a Python package that makes applying the adjoint method to an arbitrary function f a breeze (as long as you’re comfortable using PyTorch, but it’s not really hard to pick up).

For those already familiar with that framework: to use the torchdiffeq library in your project, you simply need to define a new torch.nn.Module subclass and implement its forward method, which takes two arguments: the first is the time t at which the solver evaluates the function, the second is the state z(t) for which the solver wants to know the derivative. You can then combine these two inputs as you please (with any learnable layers) and return the derivative \dfrac{dz(t)}{dt}; this module is what will learn to model your ODE.

To train it, you can call torchdiffeq‘s odeint_adjoint function (plain odeint also works, but it backpropagates through the solver’s internal operations instead of using the adjoint method) in another module’s forward method and return its output as that module’s output. Then, in the main training loop, just use any optimiser and loss function you like: when you call loss.backward(), PyTorch’s autograd will call torchdiffeq‘s custom backward pass, which uses the adjoint method to compute the gradients behind the scenes. These gradients can then be chained with other traditional layers from torch.nn. All of this runs on a GPU like any other PyTorch model, and your ODE model will train to fit the actual ODE (described only partially by your data).
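The calling convention described above can be mimicked without PyTorch. The sketch below is a hypothetical NumPy stand-in, not torchdiffeq: ODEFunc plays the role of the nn.Module (called as func(t, z) and returning \dfrac{dz(t)}{dt}), and the toy odeint, a fixed-step RK4 with an assumed substeps parameter, returns the state at every requested time stacked along the first axis, as torchdiffeq’s odeint(func, y0, t) does. It does not implement the adjoint backward pass.

```python
import numpy as np

class ODEFunc:
    """Hypothetical stand-in for a torch.nn.Module: called as func(t, z), returns dz/dt."""
    def __init__(self, W):
        self.W = W                                   # the "learnable" parameters theta

    def __call__(self, t, z):
        return np.tanh(self.W @ z)

def odeint(func, y0, t, substeps=20):
    """Toy fixed-step RK4 mimicking the odeint(func, y0, t) calling convention:
    returns the states at every requested time, stacked along the first axis."""
    ys = [y0]
    y = y0
    for ta, tb in zip(t[:-1], t[1:]):
        dt = (tb - ta) / substeps
        s = ta
        for _ in range(substeps):
            k1 = func(s, y)
            k2 = func(s + dt / 2, y + dt / 2 * k1)
            k3 = func(s + dt / 2, y + dt / 2 * k2)
            k4 = func(s + dt, y + dt * k3)
            y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            s = s + dt
        ys.append(y)
    return np.stack(ys)

func = ODEFunc(np.array([[0.0, -1.0], [1.0, 0.0]]))
t = np.linspace(0.0, 2.0, 5)
traj = odeint(func, np.array([1.0, 0.0]), t)
print(traj.shape)   # (5, 2): one state per requested time
```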

To get familiar with the torchdiffeq library, I’d suggest reading through their example and rewriting it from the ground up, as it covers the basic usage of the library’s odeint function. Don’t hesitate to look into the library’s internals too!