In this blog post, my objective is to clarify the proofs of Appendix B.2 of this paper. This is a follow-up to my previous article about Neural ODEs, which detailed the maths but included few proofs. Here, I will rewrite the proofs for the derivative of the augmented adjoint state, i.e. of the gradient of the loss L w.r.t. z(t), \theta(t), and t.
Let
- f(z(t), \theta(t), t) = \dfrac{dz(t)}{dt}
- Z(t) = \begin{bmatrix} z(t), & \theta(t), & t \end{bmatrix}
- a(t) = \dfrac{dL}{dZ(t)} = \begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix} = \begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix}
- g(Z(t), t) = \dfrac{dZ(t)}{dt} = \begin{bmatrix} \dfrac{dz(t)}{dt}, & \dfrac{d\theta(t)}{dt}, & \dfrac{dt}{dt} \end{bmatrix} = \begin{bmatrix} f, & 0_\textbf{E}, & 1 \end{bmatrix} where 0_\textbf{E} is a null matrix of appropriate size
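To make the bookkeeping concrete, here is a minimal Python/NumPy sketch of the augmented dynamics g. It assumes f is given as a callable f(z, theta, t); the names and the flat-vector packing of Z are my own illustrative choices, not taken from the paper's code:

```python
import numpy as np

def g(Z, t, f, n_z, n_theta):
    """dZ/dt for the augmented state Z = [z, theta, t] (flattened)."""
    z, theta = Z[:n_z], Z[n_z:n_z + n_theta]
    dz_dt = np.atleast_1d(f(z, theta, t))   # dz/dt = f(z(t), theta(t), t)
    dtheta_dt = np.zeros(n_theta)           # d(theta)/dt = 0: parameters are constant in t
    dt_dt = np.ones(1)                      # dt/dt = 1
    return np.concatenate([dz_dt, dtheta_dt, dt_dt])
```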
Then, we can define
Z(t+\varepsilon) = Z(t) + \int_{t}^{t+\varepsilon}{g(Z(s), s)\,ds} = T_\varepsilon(Z(t), t)
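In code, T_\varepsilon is nothing more than a call to an ODE solver on g from t to t+\varepsilon. Here is a sketch using scipy.integrate.solve_ivp, which is just one concrete solver choice (the derivation only assumes some ODE solver); it reuses g from the sketch above:

```python
from scipy.integrate import solve_ivp

def T_eps(Z, t, eps, f, n_z, n_theta):
    """Z(t + eps), obtained by integrating the augmented dynamics g from t to t + eps."""
    sol = solve_ivp(lambda s, Zs: g(Zs, s, f, n_z, n_theta),
                    (t, t + eps), Z, rtol=1e-8, atol=1e-8)
    return sol.y[:, -1]
```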
notice that
a(t) = \dfrac{dL}{dZ(t)} = \dfrac{dL}{dZ(t+\varepsilon)}\dfrac{dZ(t+\varepsilon)}{dZ(t)} = a(t+\varepsilon)\dfrac{\partial T_\varepsilon(Z(t), t)}{\partial Z(t)}
Thus we can derive \dfrac{da(t)}{dt}: as in Appendix B.1, Taylor-expand T_\varepsilon(Z(t), t) around \varepsilon = 0 inside the chain rule above and take the limit \varepsilon \to 0^+, which gives
\dfrac{da(t)}{dt} = -a(t)\dfrac{\partial g(Z(t), t)}{\partial Z(t)}
As you can see, the proof of Appendix B.1 holds for any vector Z. What's more interesting is the way we can exploit this to get \dfrac{da_\theta(t)}{dt}. Writing the product out with the Jacobian of g:
\dfrac{da(t)}{dt} = -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix}\dfrac{df(z(t), \theta(t), t)}{dz} & \dfrac{df(z(t), \theta(t), t)}{d\theta} & \dfrac{df(z(t), \theta(t), t)}{dt} \\ \dfrac{d(d\theta/dt)}{dz} & \dfrac{d(d\theta/dt)}{d\theta} & \dfrac{d(d\theta/dt)}{dt} \\ \dfrac{d(dt/dt)}{dz} & \dfrac{d(dt/dt)}{d\theta} & \dfrac{d(dt/dt)}{dt} \end{pmatrix}
But as mentioned in the paper, \dfrac{d\theta(t)}{dt} = 0 and \dfrac{dt}{dt} = 1, which means that the second and third rows of the Jacobian are zero (derivatives of constants). Hence:
= -\begin{bmatrix} a_z(t), & a_\theta(t), & a_t(t) \end{bmatrix} \cdot \begin{pmatrix} \dfrac{df}{dz} & \dfrac{df}{d\theta} & \dfrac{df}{dt} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \\ 0_\textbf{E} & 0_\textbf{E} & 0_\textbf{E} \end{pmatrix} \\= -\begin{bmatrix}a_z(t)\dfrac{df}{dz} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{df}{d\theta} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E}, & a_z(t)\dfrac{df}{dt} + a_\theta(t) 0_\textbf{E} + a_t(t) 0_\textbf{E} \end{bmatrix} \\= -\begin{bmatrix}a_z(t)\dfrac{df}{dz}, & a_z(t)\dfrac{df}{d\theta}, & a_z(t)\dfrac{df}{dt}\end{bmatrix}
Now, we’ve got:
\dfrac{da(t)}{dt} = - \begin{bmatrix} a_z(t)\dfrac{df}{dz}, & a_z(t)\dfrac{df}{d\theta}, & a_z(t)\dfrac{df}{dt} \end{bmatrix} = \dfrac{d}{dt}\begin{bmatrix} \dfrac{dL}{dz(t)}, & \dfrac{dL}{d\theta(t)}, & \dfrac{dL}{dt} \end{bmatrix}
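In practice, this row vector is exactly the dynamics that the backward pass integrates. Below is a sketch of it; the partial derivatives of f are passed in as hand-written callables, which is an assumption of this illustration (a real implementation would obtain the products a_z\,\dfrac{df}{dz}, a_z\,\dfrac{df}{d\theta}, a_z\,\dfrac{df}{dt} from an autodiff framework as vector-Jacobian products), and z(t) is re-integrated alongside the adjoints because the partials are evaluated at z(t):

```python
import numpy as np

def adjoint_dynamics(t, s, f, df_dz, df_dtheta, df_dt, theta, n_z):
    """d/dt of the augmented adjoint state s = [z, a_z, a_theta, a_t].

    df_dz, df_dtheta, df_dt return the Jacobians of f with shapes
    (n_z, n_z), (n_z, n_theta) and (n_z,) respectively; theta is constant
    during the solve and is passed separately for convenience.
    """
    z, a_z = s[:n_z], s[n_z:2 * n_z]
    dz_dt       = np.atleast_1d(f(z, theta, t))    # dz/dt       =  f
    da_z_dt     = -a_z @ df_dz(z, theta, t)        # da_z/dt     = -a_z df/dz
    da_theta_dt = -a_z @ df_dtheta(z, theta, t)    # da_theta/dt = -a_z df/dtheta
    da_t_dt     = -a_z @ df_dt(z, theta, t)        # da_t/dt     = -a_z df/dt
    return np.concatenate([dz_dt, np.atleast_1d(da_z_dt),
                           np.atleast_1d(da_theta_dt), np.atleast_1d(da_t_dt)])
```

Returning to the derivation, we can now read off the \theta-component of \dfrac{da(t)}{dt}.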
Hence:
\dfrac{d}{dt}\left(\dfrac{dL}{d\theta(t)}\right) = -a_z(t)\dfrac{df}{d\theta}\\ \implies \dfrac{dL}{d\theta} = -\int_{t_1}^{t_0}{a_z(t)\dfrac{df}{d\theta}dt}
where, as in the paper, the integral runs backward from the end time t_1 (with the boundary condition a_\theta(t_1) = 0) down to the start time t_0.
QED
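As a quick sanity check (my own toy example, not from the paper): take the scalar ODE f(z, \theta, t) = -\theta z with loss L = \frac{1}{2}z(t_1)^2, for which the closed form is \dfrac{dL}{d\theta} = -t_1 z(t_1)^2. Integrating the augmented adjoint backward with the adjoint_dynamics sketch above recovers that value:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy problem (illustrative): dz/dt = -theta * z, loss L = 0.5 * z(t1)^2.
z0, theta, t0, t1 = np.array([2.0]), np.array([1.5]), 0.0, 1.0

f         = lambda z, th, t: -th * z
df_dz     = lambda z, th, t: np.array([[-th[0]]])   # shape (1, 1)
df_dtheta = lambda z, th, t: np.array([[-z[0]]])    # shape (1, 1)
df_dt     = lambda z, th, t: np.array([0.0])        # shape (1,)

# Forward pass: compute z(t1).
fwd = solve_ivp(lambda t, z: f(z, theta, t), (t0, t1), z0, rtol=1e-10, atol=1e-10)
z1 = fwd.y[:, -1]

# Backward pass: integrate [z, a_z, a_theta, a_t] from t1 back to t0,
# with a_z(t1) = dL/dz(t1) = z(t1) and a_theta(t1) = 0 (a_t is carried
# along but not checked here).
s1 = np.concatenate([z1, z1, [0.0], [0.0]])
bwd = solve_ivp(lambda t, s: adjoint_dynamics(t, s, f, df_dz, df_dtheta, df_dt,
                                              theta, n_z=1),
                (t1, t0), s1, rtol=1e-10, atol=1e-10)

print(bwd.y[2, -1], -t1 * z1[0] ** 2)   # both approximately -0.1991: they agree
```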
A PDF version of this post (with better formatting) is available at this link.