Neural ODEs from Scratch
1. Ordinary Differential Equations from First Principles
A ResNet computes h1 = h0 + f(h0), then h2 = h1 + f(h1), and so on through dozens of layers. Stare at that long enough and something clicks: this is Euler's method for numerically solving the differential equation dh/dt = f(h(t)), with a step size of 1. The forward pass of a deep residual network is an ODE solver. Neural ODEs take this observation seriously.
An ordinary differential equation describes how a state h evolves over a continuous variable t: dh/dt = f(h, t). Given an initial condition h(0) = h0, the solution traces a trajectory through state space. This is called an initial value problem (IVP).
The simplest numerical solver is Euler's method: advance by a small step Δt using the current derivative. It's fast but inaccurate — its global error shrinks only linearly with the step size, O(Δt). The workhorse of scientific computing is RK4 (Runge-Kutta 4th order), which evaluates the derivative four times per step and achieves O(Δt⁴) accuracy. Modern solvers like Dormand-Prince (dopri5) go further — they adaptively choose their step size based on local error estimates, effectively letting the ODE decide its own "depth."
Here's Euler and RK4 applied to a 2D spiral ODE. Watch how RK4 tracks the true solution far more faithfully:
```python
import numpy as np

def spiral_ode(h, t):
    """dh/dt for a 2D spiral: rotation + mild contraction."""
    x, y = h
    return np.array([-0.1 * x - y, x - 0.1 * y])

def euler_solve(f, h0, t_span, n_steps):
    """Forward Euler: h_{k+1} = h_k + dt * f(h_k, t_k)."""
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    traj = [h0]
    for i in range(n_steps):
        traj.append(traj[-1] + dt * f(traj[-1], ts[i]))
    return ts, np.array(traj)

def rk4_solve(f, h0, t_span, n_steps):
    """Classic RK4: four evaluations per step, O(dt^4) accuracy."""
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    traj = [h0]
    for i in range(n_steps):
        t, h = ts[i], traj[-1]
        k1 = f(h, t)
        k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(h + dt * k3, t + dt)
        traj.append(h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4))
    return ts, np.array(traj)

h0 = np.array([2.0, 0.0])
_, euler_traj = euler_solve(spiral_ode, h0, (0, 6), n_steps=15)
_, rk4_traj = rk4_solve(spiral_ode, h0, (0, 6), n_steps=15)
# Euler spirals outward (unstable); RK4 hugs the true inward spiral
```
With only 15 steps, Euler drifts wildly while RK4 stays on track. This accuracy gap matters enormously when the ODE solver is your neural network.
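The adaptive solvers mentioned above can be sketched with step doubling: take one full RK4 step and two half steps, compare the results as a local error estimate, and stretch or shrink dt accordingly. This spiral ODE has the closed form h(t) = 2e^(−0.1t)(cos t, sin t) for h0 = (2, 0), so we can check the answer. A minimal sketch — the growth/shrink factors and initial step are illustrative choices, not dopri5's actual coefficients:

```python
import numpy as np

def spiral_ode(h, t):
    return np.array([-0.1 * h[0] - h[1], h[0] - 0.1 * h[1]])

def rk4_step(f, h, t, dt):
    k1 = f(h, t)
    k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = f(h + dt * k3, t + dt)
    return h + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

def adaptive_rk4(f, h0, t0, t1, tol=1e-6):
    """Step doubling: one full step vs. two half steps as an error probe."""
    t, h, dt = t0, h0.copy(), 0.1
    n_accepted = 0
    while t < t1:
        dt = min(dt, t1 - t)
        full = rk4_step(f, h, t, dt)
        half = rk4_step(f, rk4_step(f, h, t, dt / 2), t + dt / 2, dt / 2)
        if np.max(np.abs(full - half)) < tol:
            t, h = t + dt, half     # accept the more accurate result
            dt *= 1.5               # easy region: stretch the step
            n_accepted += 1
        else:
            dt *= 0.5               # too much local error: retry smaller
    return h, n_accepted

h, n_steps = adaptive_rk4(spiral_ode, np.array([2.0, 0.0]), 0.0, 6.0)
exact = 2 * np.exp(-0.6) * np.array([np.cos(6.0), np.sin(6.0)])
print(h, exact, n_steps)   # the solver chose its own "depth" (n_steps)
```

The solver spends small steps only where the local error demands them — exactly the "adaptive depth" idea that carries over to Neural ODEs.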
2. From ResNets to Neural ODEs
In 2016, He et al. showed that residual connections — h_{t+1} = h_t + fθ(h_t) — let networks train hundreds of layers deep. Two years later, Chen et al. (2018) made the conceptual leap that earned a NeurIPS Best Paper award: what if we take the limit?
As the number of layers T → ∞ and the step size Δt → 0, the residual update rule converges to a continuous ODE:
dh/dt = fθ(h(t), t)
The forward pass is now an IVP: start from input h(0) = x, solve the ODE to time T, and read off h(T) as the output. Any black-box ODE solver (Euler, RK4, dopri5) can be plugged in. A single network fθ defines the dynamics at every "layer" — it's shared across all of continuous time, giving massive parameter efficiency.
But the real revolution is in memory. Standard backpropagation through T layers requires storing all T intermediate activations — O(T) memory. Neural ODEs use the adjoint method to compute gradients with O(1) memory, regardless of how many solver steps were taken. We'll build that next.
```python
import numpy as np

class NeuralODE:
    """A tiny Neural ODE: dh/dt = W2 @ tanh(W1 @ h + b1) + b2."""

    def __init__(self, dim, hidden=16):
        scale = np.sqrt(2 / dim)
        self.W1 = np.random.randn(hidden, dim) * scale
        self.b1 = np.zeros(hidden)
        self.W2 = np.random.randn(dim, hidden) * scale
        self.b2 = np.zeros(dim)

    def f(self, h, t):
        """The learned dynamics f_theta(h, t)."""
        z = np.tanh(self.W1 @ h + self.b1)
        return self.W2 @ z + self.b2

    def forward(self, h0, t0=0.0, t1=1.0, steps=20):
        """Solve the IVP from t0 to t1 using RK4."""
        dt = (t1 - t0) / steps
        h = h0.copy()
        trajectory = [h.copy()]
        for i in range(steps):
            t = t0 + i * dt
            k1 = self.f(h, t)
            k2 = self.f(h + 0.5*dt*k1, t + 0.5*dt)
            k3 = self.f(h + 0.5*dt*k2, t + 0.5*dt)
            k4 = self.f(h + dt*k3, t + dt)
            h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
            trajectory.append(h.copy())
        return h, trajectory

# The forward pass IS solving an ODE - no discrete layers
node = NeuralODE(dim=2)
h_final, path = node.forward(np.array([1.0, 0.5]))
print(f"Input: [1.0, 0.5] -> Output: [{h_final[0]:.3f}, {h_final[1]:.3f}]")
```
Notice what's absent: no layer count, no sequential module list. The network's "depth" is determined by the ODE solver at inference time. Feed it easy data and the adaptive solver takes fewer steps. Feed it hard data and it takes more. The network decides its own depth per input.
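A way to see that solver steps are an accuracy knob rather than architecture: solve the same fixed dynamics at two very different "depths" and compare. The weights below are random stand-ins for trained ones, chosen small enough that the dynamics are smooth:

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.standard_normal((16, 2)) * 0.3   # illustrative fixed weights,
W2 = rng.standard_normal((2, 16)) * 0.3   # standing in for trained ones

def f(h, t):
    return W2 @ np.tanh(W1 @ h)

def integrate(h0, steps, t1=1.0):
    h, dt = h0.copy(), t1 / steps
    for i in range(steps):
        t = i * dt
        k1 = f(h, t)
        k2 = f(h + 0.5*dt*k1, t + 0.5*dt)
        k3 = f(h + 0.5*dt*k2, t + 0.5*dt)
        k4 = f(h + dt*k3, t + dt)
        h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
    return h

h0 = np.array([1.0, 0.5])
coarse = integrate(h0, steps=10)
fine = integrate(h0, steps=200)
# Different "depths", same map: both approximate the same h(1), so the
# step count is an accuracy/compute knob, not part of the architecture.
print(np.max(np.abs(coarse - fine)))
```

A 10-step and a 200-step solve of the same fθ land at (nearly) the same output — something no discrete 10-layer vs. 200-layer network could do.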
3. The Adjoint Method — Backpropagation in Continuous Time
Here's the memory problem: if we naively backpropagate through every RK4 step, we need to store all intermediate states — just like a deep ResNet. For a high-resolution ODE with thousands of steps, this is prohibitive.
The adjoint method solves this elegantly. Define the adjoint state a(t) = dL/dh(t) — how the loss changes with respect to the hidden state at time t. It satisfies its own ODE, running backward in time:
da/dt = −a(t)ᵀ · (∂fθ/∂h)
Starting from the terminal condition a(T) = dL/dh(T), we integrate this ODE backwards to t = 0. Along the way, we accumulate the parameter gradients:
dL/dθ = ∫₀ᵀ a(t)ᵀ (∂fθ/∂θ) dt
The beauty: we never store intermediate activations. We reconstruct h(t) on the fly by solving the forward ODE backward alongside the adjoint. Total memory: O(1), regardless of solver steps. This is the adjoint equation from optimal control theory, the same machinery behind Pontryagin's Maximum Principle (1962) — the control community solved "backpropagation" decades before deep learning needed it.
```python
def adjoint_gradients(node, h0, target, t0=0.0, t1=1.0, steps=20):
    """Compute dL/d(params) via the adjoint method."""
    # Forward solve (simplified: we store the trajectory; the true
    # O(1)-memory variant reconstructs h(t) during the backward solve)
    h_final, trajectory = node.forward(h0, t0, t1, steps)
    # Loss: L = 0.5 * ||h(T) - target||^2
    loss = 0.5 * np.sum((h_final - target) ** 2)
    # Terminal condition: a(T) = dL/dh(T)
    a = h_final - target
    # Backward solve: integrate adjoint ODE from T to 0
    dt = (t1 - t0) / steps
    grad_W1 = np.zeros_like(node.W1)
    grad_b1 = np.zeros_like(node.b1)
    grad_W2 = np.zeros_like(node.W2)
    grad_b2 = np.zeros_like(node.b2)
    for i in range(steps - 1, -1, -1):
        h = trajectory[i]
        z = np.tanh(node.W1 @ h + node.b1)
        dz = 1 - z ** 2  # tanh derivative
        # Jacobian: df/dh = W2 @ diag(dz) @ W1
        df_dh = node.W2 @ np.diag(dz) @ node.W1
        # Accumulate parameter gradients: dL/dtheta = int a^T df/dtheta dt
        grad_W2 += np.outer(a, z) * dt
        grad_b2 += a * dt
        grad_W1 += np.outer(node.W2.T @ a * dz, h) * dt
        grad_b1 += (node.W2.T @ a) * dz * dt
        # Adjoint update (Euler step backward in time)
        a = a + dt * (df_dh.T @ a)
    return loss, {'W1': grad_W1, 'b1': grad_b1, 'W2': grad_W2, 'b2': grad_b2}
```
In practice, storing the full trajectory defeats the O(1) purpose. Production implementations (like torchdiffeq) reconstruct h(t) by solving the forward ODE backward alongside the adjoint — a single augmented ODE for [h, a, ∇θL] integrated from T to 0. Our simplified version stores the forward trajectory for clarity, but the gradient computation itself is identical.
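To make the trajectory-free variant concrete, here is a sketch for linear dynamics f(h) = A·h: the augmented state [h, a, dL/dA] is integrated from T back to 0 in a single sweep, with h re-derived by running the ODE in reverse rather than stored. Euler steps and a finite-difference check are used for brevity; a production implementation would use an adaptive solver:

```python
import numpy as np

A = np.array([[-0.1, -1.0], [1.0, -0.1]])   # linear dynamics dh/dt = A @ h
h0 = np.array([2.0, 0.0])
target = np.array([0.5, -0.5])
T, steps = 1.0, 2000
dt = T / steps

def forward_loss(A):
    """Euler forward solve plus squared-error loss at time T."""
    h = h0.copy()
    for _ in range(steps):
        h = h + dt * (A @ h)
    return 0.5 * np.sum((h - target) ** 2), h

loss, hT = forward_loss(A)

# Single backward sweep over the augmented state [h, a, dL/dA]:
# no intermediate activations are kept from the forward pass.
h, a = hT.copy(), hT - target            # a(T) = dL/dh(T)
grad_A = np.zeros_like(A)
for _ in range(steps):
    h = h - dt * (A @ h)                 # forward ODE, integrated backward
    a = a + dt * (A.T @ a)               # adjoint ODE: da/dt = -A^T a
    grad_A += np.outer(a, h) * dt        # dL/dA = integral of a h^T dt

# Finite-difference check on one entry of dL/dA
eps = 1e-5
Ap, Am = A.copy(), A.copy()
Ap[0, 1] += eps
Am[0, 1] -= eps
fd = (forward_loss(Ap)[0] - forward_loss(Am)[0]) / (2 * eps)
print(grad_A[0, 1], fd)   # close agreement, up to Euler discretization error
```

The only memory that grows with problem size here is the augmented state itself — solver step count never enters.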
4. Continuous Normalizing Flows
Normalizing flows transform a simple distribution (like a Gaussian) into a complex one through a sequence of invertible maps. Each step requires computing the determinant of the Jacobian — an O(d³) operation that severely constrains architecture choices.
Continuous Normalizing Flows (CNFs) replace discrete invertible steps with an ODE: dz/dt = fθ(z, t). The density changes according to the instantaneous change-of-variables formula (Liouville's equation):
d log p(z)/dt = −tr(∂fθ/∂z)
The trace is O(d), not O(d3). Even better, the Hutchinson trace estimator gives an unbiased estimate using random vector-Jacobian products — no need to form the full Jacobian at all. This is exactly what FFJORD (Grathwohl et al. 2019) exploits, and it directly connects to flow matching: Conditional Flow Matching learns the velocity field fθ for this very ODE, but with a simulation-free training objective.
```python
def hutchinson_trace(f, h, t, dim, n_probes=5):
    """Estimate tr(df/dh) via Hutchinson's trick: E[v^T J v]."""
    eps = 1e-5
    trace_est = 0.0
    for _ in range(n_probes):
        v = np.random.randn(dim)
        jvp = (f(h + eps * v, t) - f(h - eps * v, t)) / (2 * eps)
        trace_est += np.dot(v, jvp)
    return trace_est / n_probes

def cnf_forward(node, z0, t0=0.0, t1=1.0, steps=50):
    """Solve the CNF: dz/dt = f(z,t), d log p/dt = -tr(df/dz)."""
    dt = (t1 - t0) / steps
    z = z0.copy()
    log_p_change = 0.0
    for i in range(steps):
        t = t0 + i * dt
        dz = node.f(z, t)
        trace = hutchinson_trace(node.f, z, t, len(z))
        log_p_change -= trace * dt  # Liouville's equation
        z = z + dt * dz
    return z, log_p_change

# log p(x) = log p(z0) + accumulated log-density change
# No invertibility constraint, no determinant computation!
```
The Hutchinson estimator uses random probe vectors v and computes vᵀJv via finite differences. In expectation, this equals the trace. With just 5 probes, the variance is usually low enough for training.
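A quick sanity check of the estimator against a known Jacobian: for a linear map f(h) = A·h the Jacobian is A itself, so a seeded variant of the estimator above should converge to tr(A) as the probe count grows:

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.standard_normal((5, 5))

def f(h, t):
    return A @ h   # Jacobian is exactly A, so the true trace is tr(A)

def hutchinson_trace(f, h, t, dim, n_probes=4000, eps=1e-5):
    est = 0.0
    for _ in range(n_probes):
        v = rng.standard_normal(dim)
        jvp = (f(h + eps * v, t) - f(h - eps * v, t)) / (2 * eps)  # J @ v
        est += v @ jvp
    return est / n_probes

h = rng.standard_normal(5)
est = hutchinson_trace(f, h, 0.0, 5)
print(est, np.trace(A))
# The estimate is noisy per probe but unbiased, converging at a
# 1/sqrt(n_probes) rate toward the exact trace.
```

4000 probes is overkill for training (FFJORD typically uses one probe per step, amortizing the noise over SGD), but it makes the convergence visible.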
5. Augmented Neural ODEs — Breaking the Topology Barrier
Neural ODEs in 2D have a fundamental limitation rooted in the uniqueness theorem: ODE trajectories cannot cross. This means the transformation from h(0) to h(T) is a homeomorphism — it preserves topology. In two dimensions, concentric circles stay topologically nested no matter how you deform them. A standard 2D Neural ODE literally cannot make concentric circles linearly separable, just as a single-layer perceptron cannot solve XOR.
Augmented Neural ODEs (Dupont et al. 2019) fix this by lifting the input to a higher-dimensional space: [h; 0] ∈ ℝ^(d+p). The extra dimensions give trajectories "room" to go around each other, much like lifting a 2D knot into 3D where it can be untied. This is exactly what ResNets do naturally — the channel dimension provides the augmentation that makes deep classification work.
```python
class AugmentedNeuralODE:
    """Lift input to higher-dim space, then solve ODE."""

    def __init__(self, data_dim, aug_dim=3, hidden=32):
        total = data_dim + aug_dim
        self.aug_dim = aug_dim
        scale = np.sqrt(2 / total)
        self.W1 = np.random.randn(hidden, total) * scale
        self.b1 = np.zeros(hidden)
        self.W2 = np.random.randn(total, hidden) * scale
        self.b2 = np.zeros(total)

    def f(self, h, t):
        z = np.tanh(self.W1 @ h + self.b1)
        return self.W2 @ z + self.b2

    def forward(self, x, t1=1.0, steps=20):
        # Augment: pad input with zeros
        h = np.concatenate([x, np.zeros(self.aug_dim)])
        dt = t1 / steps
        for i in range(steps):
            t = i * dt
            k1 = self.f(h, t)
            k2 = self.f(h + 0.5*dt*k1, t + 0.5*dt)
            k3 = self.f(h + 0.5*dt*k2, t + 0.5*dt)
            k4 = self.f(h + dt*k3, t + dt)
            h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
        return h[:len(x)]  # project back to data dimension

# Standard 2D NODE: can't untangle concentric circles
# Augmented NODE (2D + 3 extra dims): trajectories detour through 5D
# where they CAN pass around each other -> linearly separable output
```
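The non-crossing property is easiest to see in one dimension: for any smooth f, two trajectories that start in a given order keep that order forever, so no 1D Neural ODE can swap two points. A numerical illustration with a deliberately strong contraction (the dynamics are hand-picked, not learned):

```python
import numpy as np

def f(h, t):
    return -5.0 * np.tanh(h)   # strongly contracting 1D dynamics

def rk4(h, t1=2.0, steps=400):
    dt = t1 / steps
    for i in range(steps):
        t = i * dt
        k1 = f(h, t)
        k2 = f(h + 0.5*dt*k1, t + 0.5*dt)
        k3 = f(h + 0.5*dt*k2, t + 0.5*dt)
        k4 = f(h + dt*k3, t + dt)
        h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
    return h

a = rk4(1.0)
b = rk4(-1.0)
# Both points collapse toward 0, yet a > b still holds: the 1D flow map
# is strictly monotone, so no choice of f can swap the two points.
print(a, b)
```

However hard the flow squeezes, the two trajectories can approach but never touch — that is the uniqueness theorem at work, and exactly what an extra dimension circumvents.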
6. Latent ODEs for Irregular Time Series
Most real-world time series are irregularly sampled: medical records have sparse checkups, IoT sensors report at variable rates, financial trades cluster during market hours. Standard RNNs assume fixed time steps. Resampling to a grid wastes computation on gaps and can distort the signal.
Latent ODEs (Rubanova et al. 2019) handle this naturally. The architecture has three parts: an ODE-RNN encoder that processes observations in reverse (the hidden state evolves via a Neural ODE between observations), a latent ODE that defines smooth dynamics in a learned latent space, and a decoder that maps latent states back to observation space.
The key insight: between observations, the hidden state follows smooth ODE dynamics. At each observation, new information updates the state. The ODE handles any spacing — 1 millisecond or 1 year — without resampling. Interpolation between observations and extrapolation beyond them emerge naturally from the continuous dynamics.
```python
class LatentODE:
    """Simplified Latent ODE: encode -> latent dynamics -> decode."""

    def __init__(self, obs_dim=1, latent_dim=4):
        self.latent_dim = latent_dim
        # Encoder: observations -> latent initial state
        self.enc_W = np.random.randn(latent_dim, obs_dim + 1) * 0.5
        # Dynamics: dz/dt = tanh(W_dyn @ z)
        self.dyn_W = np.random.randn(latent_dim, latent_dim) * 0.3
        # Decoder: latent state -> observation
        self.dec_W = np.random.randn(obs_dim, latent_dim) * 0.5

    def encode(self, obs_times, obs_values):
        """Simple mean encoding (real version uses ODE-RNN)."""
        features = np.column_stack([obs_times, obs_values])
        return np.tanh(self.enc_W @ features.mean(axis=0))

    def latent_dynamics(self, z0, eval_times, steps_per_unit=10):
        """Solve latent ODE at arbitrary evaluation times."""
        results = {}
        z = z0.copy()
        t_cur = 0.0
        for t_target in sorted(eval_times):
            n = max(1, int((t_target - t_cur) * steps_per_unit))
            dt = (t_target - t_cur) / n
            for _ in range(n):
                z = z + dt * np.tanh(self.dyn_W @ z)
            results[t_target] = z.copy()
            t_cur = t_target
        return results

    def decode(self, z):
        return self.dec_W @ z

# Irregular observations? Just solve the ODE between whatever times
# you have. No resampling, no fixed grid, no wasted computation.
```
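A minimal sketch of the "solve between whatever times you have" idea, with hand-picked rotation dynamics rather than learned ones so the answer is checkable: the latent state rotates at unit angular velocity and a decoder that reads off the first coordinate should report cos(t) at any query time t:

```python
import numpy as np

W = np.array([[0.0, -1.0], [1.0, 0.0]])  # dz/dt = W @ z: pure rotation

def solve_between(z, t_from, t_to, steps_per_unit=100):
    """RK4 from t_from to t_to -- any gap size, no fixed grid."""
    n = max(1, int((t_to - t_from) * steps_per_unit))
    dt = (t_to - t_from) / n
    for _ in range(n):
        k1 = W @ z
        k2 = W @ (z + 0.5 * dt * k1)
        k3 = W @ (z + 0.5 * dt * k2)
        k4 = W @ (z + dt * k3)
        z = z + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
    return z

z = np.array([1.0, 0.0])
t_cur, decoded = 0.0, {}
for t in [0.05, 0.07, 1.9, 5.0]:         # wildly irregular spacing
    z = solve_between(z, t_cur, t)
    decoded[t] = z[0]                     # "decoder": first coordinate
    t_cur = t
# z(t) is z(0) rotated by angle t, so decoded[t] should equal cos(t)
print({t: round(v, 4) for t, v in decoded.items()})
```

Gaps of 0.02 and 3.1 time units are handled by the same code path — the solver simply takes however many internal steps the gap requires.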
7. The Deep Learning – Dynamical Systems Connection
Neural ODEs reveal a bridge between deep learning and dynamical systems theory. Every ResNet block is an Euler step. Backpropagation is the adjoint method. Skip connections keep the per-layer map near the identity (Jacobian eigenvalues ≈ 1), which is exactly why they prevent vanishing gradients — the system stays at the edge of stability.
This perspective places modern architectures on a spectrum of continuous dynamics:
| Architecture | Dynamics Type | Memory | Key Property |
|---|---|---|---|
| ResNet | Discretized ODE (Δt = 1) | O(T) | Fixed depth, fast |
| Neural ODE | Nonlinear ODE | O(1) | Adaptive depth, continuous |
| SSM (S4/Mamba) | Linear ODE + input | O(1) | Parallelizable, long-range |
| Diffusion Model | Reverse-time SDE | O(T) | Generation via reverse process |
| Flow Matching | CNF (learned velocity) | O(1) | Simulation-free training |
The eigenvalues of ∂f/∂h tell the stability story: negative real parts mean contracting dynamics (information loss), positive real parts mean expanding dynamics (gradient explosion), and real parts near zero put the system at the edge of chaos where useful computation happens. Neural ODEs make this analysis precise because the dynamics are continuous — no discrete jumps to worry about.
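The Jacobian of the dynamics can be probed numerically. For the MLP dynamics f(h) = W2 @ tanh(W1 @ h) used throughout this post, the Jacobian at h = 0 is exactly W2 @ W1 (since tanh'(0) = 1), which makes a convenient correctness check before inspecting the eigenvalues. A sketch with illustrative random weights:

```python
import numpy as np

rng = np.random.default_rng(7)
W1 = rng.standard_normal((8, 3)) * 0.4   # illustrative weights
W2 = rng.standard_normal((3, 8)) * 0.4

def f(h):
    return W2 @ np.tanh(W1 @ h)

def numerical_jacobian(f, h, eps=1e-6):
    """Central-difference Jacobian, one column per input dimension."""
    d = len(h)
    J = np.zeros((d, d))
    for j in range(d):
        e = np.zeros(d)
        e[j] = eps
        J[:, j] = (f(h + e) - f(h - e)) / (2 * eps)
    return J

J = numerical_jacobian(f, np.zeros(3))
print(np.max(np.abs(J - W2 @ W1)))   # ~0: matches the analytic Jacobian
eigs = np.linalg.eigvals(J)
# Negative real parts => locally contracting flow; positive => expanding
print(sorted(eigs.real))
```

The same probe evaluated along a trajectory shows how the local stability of a trained Neural ODE changes as the state moves.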
Looking forward, the Neural ODE framework keeps expanding. Neural SDEs add stochastic noise for uncertainty quantification. Neural CDEs (controlled differential equations) handle streaming data elegantly. And the connection to flow matching — where you train a CNF's velocity field with a simple regression loss — has become the backbone of modern generative models.
References & Further Reading
- Chen et al. — Neural Ordinary Differential Equations (NeurIPS 2018) — the foundational paper, NeurIPS Best Paper award
- Grathwohl et al. — FFJORD: Free-form Continuous Dynamics (ICLR 2019) — continuous normalizing flows with Hutchinson estimator
- Dupont et al. — Augmented Neural ODEs (NeurIPS 2019) — breaking the homeomorphism barrier
- Rubanova et al. — Latent ODEs for Irregularly-Sampled Time Series (NeurIPS 2019) — the "killer app" for medical/sensor data
- Kidger — On Neural Differential Equations (PhD thesis, 2022) — the most comprehensive reference on Neural DEs
- He et al. — Deep Residual Learning (CVPR 2016) — the ResNet paper that started the ODE connection
- Haber & Ruthotto — Stable Architectures for Deep Networks (2017) — ODE stability perspective on ResNet design
- torchdiffeq — PyTorch implementation of Neural ODE solvers by Chen et al.
- DiffEqML — differentiable ODE solvers for machine learning