
Neural ODEs from Scratch

1. Ordinary Differential Equations from First Principles

A ResNet computes h1 = h0 + f(h0), then h2 = h1 + f(h1), and so on through dozens of layers. Stare at that long enough and something clicks: this is Euler's method for numerically solving the differential equation dh/dt = f(h(t)), with a step size of 1. The forward pass of a deep residual network is an ODE solver. Neural ODEs take this observation seriously.

An ordinary differential equation describes how a state h evolves over a continuous variable t: dh/dt = f(h, t). Given an initial condition h(0) = h0, the solution traces a trajectory through state space. This is called an initial value problem (IVP).

The simplest numerical solver is Euler's method: advance by a small step Δt using the current derivative. It's fast but only first-order accurate — the global error shrinks like O(Δt). The workhorse of scientific computing is RK4 (Runge-Kutta 4th order), which evaluates the derivative four times per step and achieves O(Δt⁴) accuracy. Modern solvers like Dormand-Prince (dopri5) go further — they adaptively choose their step size based on local error estimates, effectively letting the ODE decide its own "depth."

Here's Euler and RK4 applied to a 2D spiral ODE. Watch how RK4 tracks the true solution far more faithfully:

import numpy as np

def spiral_ode(h, t):
    """dh/dt for a 2D spiral: rotation + mild contraction."""
    x, y = h
    return np.array([-0.1 * x - y, x - 0.1 * y])

def euler_solve(f, h0, t_span, n_steps):
    """Forward Euler: h_{k+1} = h_k + dt * f(h_k, t_k)."""
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    traj = [h0]
    for i in range(n_steps):
        traj.append(traj[-1] + dt * f(traj[-1], ts[i]))
    return ts, np.array(traj)

def rk4_solve(f, h0, t_span, n_steps):
    """Classic RK4: four evaluations per step, O(dt^4) accuracy."""
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    traj = [h0]
    for i in range(n_steps):
        t, h = ts[i], traj[-1]
        k1 = f(h, t)
        k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(h + dt * k3, t + dt)
        traj.append(h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4))
    return ts, np.array(traj)

h0 = np.array([2.0, 0.0])
_, euler_traj = euler_solve(spiral_ode, h0, (0, 6), n_steps=15)
_, rk4_traj   = rk4_solve(spiral_ode, h0, (0, 6), n_steps=15)
# Euler spirals outward (unstable); RK4 hugs the true inward spiral

With only 15 steps, Euler drifts wildly while RK4 stays on track. This accuracy gap matters enormously when the ODE solver is your neural network.
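The gap is easy to quantify, because this spiral has a closed-form solution: writing z = x + iy turns the system into dz/dt = (−0.1 + i)z, so z(t) = z0·e^((−0.1+i)t). The following self-contained check (it re-declares compact versions of the solvers above) measures each solver's final-point error against that solution:

```python
import numpy as np

def spiral_ode(h, t):
    x, y = h
    return np.array([-0.1 * x - y, x - 0.1 * y])

def euler_solve(f, h0, t_span, n_steps):
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    h = h0
    for i in range(n_steps):
        h = h + dt * f(h, ts[i])
    return h

def rk4_solve(f, h0, t_span, n_steps):
    ts = np.linspace(t_span[0], t_span[1], n_steps + 1)
    dt = ts[1] - ts[0]
    h = h0
    for i in range(n_steps):
        t = ts[i]
        k1 = f(h, t)
        k2 = f(h + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(h + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(h + dt * k3, t + dt)
        h = h + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
    return h

# analytic solution: z(t) = z0 * exp((-0.1 + 1j) * t), z = x + iy
h0, T = np.array([2.0, 0.0]), 6.0
z_true = (h0[0] + 1j * h0[1]) * np.exp((-0.1 + 1j) * T)
h_true = np.array([z_true.real, z_true.imag])

euler_err = np.linalg.norm(euler_solve(spiral_ode, h0, (0, T), 15) - h_true)
rk4_err   = np.linalg.norm(rk4_solve(spiral_ode, h0, (0, T), 15) - h_true)
```

Running this shows the Euler error larger than RK4's by roughly three orders of magnitude at the same step count.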

2. From ResNets to Neural ODEs

In 2016, He et al. showed that residual connections — h_{t+1} = h_t + fθ(h_t) — let networks train hundreds of layers deep. Two years later, Chen et al. (2018) made the conceptual leap that earned a NeurIPS Best Paper award: what if we take the limit?

As the number of layers T → ∞ and the step size Δt → 0, the residual update rule converges to a continuous ODE:

dh/dt = fθ(h(t), t)

The forward pass is now an IVP: start from input h(0) = x, solve the ODE to time T, and read off h(T) as the output. Any black-box ODE solver (Euler, RK4, dopri5) can be plugged in. A single network fθ defines the dynamics at every "layer" — it's shared across all of continuous time, giving massive parameter efficiency.

But the real revolution is in memory. Standard backpropagation through T layers requires storing all T intermediate activations — O(T) memory. Neural ODEs use the adjoint method to compute gradients with O(1) memory, regardless of how many solver steps were taken. We'll build that next.

import numpy as np

class NeuralODE:
    """A tiny Neural ODE: dh/dt = W2 @ tanh(W1 @ h + b1) + b2."""
    def __init__(self, dim, hidden=16):
        scale = np.sqrt(2 / dim)
        self.W1 = np.random.randn(hidden, dim) * scale
        self.b1 = np.zeros(hidden)
        self.W2 = np.random.randn(dim, hidden) * scale
        self.b2 = np.zeros(dim)

    def f(self, h, t):
        """The learned dynamics f_theta(h, t)."""
        z = np.tanh(self.W1 @ h + self.b1)
        return self.W2 @ z + self.b2

    def forward(self, h0, t0=0.0, t1=1.0, steps=20):
        """Solve the IVP from t0 to t1 using RK4."""
        dt = (t1 - t0) / steps
        h = h0.copy()
        trajectory = [h.copy()]
        for i in range(steps):
            t = t0 + i * dt
            k1 = self.f(h, t)
            k2 = self.f(h + 0.5*dt*k1, t + 0.5*dt)
            k3 = self.f(h + 0.5*dt*k2, t + 0.5*dt)
            k4 = self.f(h + dt*k3, t + dt)
            h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
            trajectory.append(h.copy())
        return h, trajectory

# The forward pass IS solving an ODE - no discrete layers
node = NeuralODE(dim=2)
h_final, path = node.forward(np.array([1.0, 0.5]))
print(f"Input: [1.0, 0.5] -> Output: [{h_final[0]:.3f}, {h_final[1]:.3f}]")

Notice what's absent: no layer count, no sequential module list. The network's "depth" is determined by the ODE solver at inference time. With an adaptive solver like dopri5 (the RK4 above is fixed-step), easy inputs take fewer steps and hard inputs take more — the network decides its own depth per input.
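Adaptive stepping is what makes that per-input depth concrete. Here is a minimal, self-contained sketch of the idea — Heun's method with an embedded Euler error estimate (a toy 2(1) pair, far simpler than dopri5's 5(4) pair) plus a standard step-size controller; the `gentle`/`stiff` dynamics are illustrative stand-ins:

```python
import numpy as np

def adaptive_heun(f, h0, t0, t1, tol=1e-4, dt0=0.1):
    """Heun (RK2) steps with an embedded Euler error estimate.
    Accepts/rejects each step and rescales dt: a toy version of
    what dopri5 does with its 5th/4th-order pair."""
    t, h, dt = t0, np.asarray(h0, dtype=float), dt0
    n_accepted = 0
    while t1 - t > 1e-10:
        dt = min(dt, t1 - t)
        k1 = f(h, t)
        k2 = f(h + dt * k1, t + dt)
        h_trial = h + 0.5 * dt * (k1 + k2)        # 2nd-order (Heun)
        err = 0.5 * dt * np.linalg.norm(k2 - k1)  # Heun minus Euler
        if err < tol:                             # accept the step
            t, h = t + dt, h_trial
            n_accepted += 1
        # controller for an order-2 method: dt ~ (tol/err)^(1/2)
        dt *= min(2.0, max(0.2, 0.9 * (tol / max(err, 1e-12)) ** 0.5))
    return h, n_accepted

# the same solver spends few steps on gentle dynamics, many on stiff ones
gentle = lambda h, t: -0.5 * h
stiff  = lambda h, t: -20.0 * h
h_gentle, n_gentle = adaptive_heun(gentle, np.array([1.0]), 0.0, 1.0)
h_stiff,  n_stiff  = adaptive_heun(stiff,  np.array([1.0]), 0.0, 1.0)
```

The stiffer dynamics force smaller steps, so n_stiff comes out much larger than n_gentle: the solver spends its "layers" where the problem demands them.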

3. The Adjoint Method — Backpropagation in Continuous Time

Here's the memory problem: if we naively backpropagate through every RK4 step, we need to store all intermediate states — just like a deep ResNet. For a high-resolution ODE with thousands of steps, this is prohibitive.

The adjoint method solves this elegantly. Define the adjoint state a(t) = dL/dh(t) — how the loss changes with respect to the hidden state at time t. It satisfies its own ODE, running backward in time:

da/dt = −aᵀ (∂fθ/∂h)

Starting from the terminal condition a(T) = dL/dh(T), we integrate this ODE backwards to t = 0. Along the way, we accumulate the parameter gradients:

dL/dθ = ∫₀ᵀ a(t)ᵀ (∂fθ/∂θ) dt

The beauty: we never store intermediate activations. We reconstruct h(t) on the fly by solving the forward ODE backward alongside the adjoint. Total memory: O(1), regardless of solver steps. This is actually Pontryagin's Maximum Principle from optimal control theory (1962) — the math community solved "backpropagation" sixty years before deep learning needed it.

def adjoint_gradients(node, h0, target, t0=0.0, t1=1.0, steps=20):
    """Compute dL/d(params) via the adjoint method (simplified)."""
    # Forward solve: store trajectory for reconstruction
    h_final, trajectory = node.forward(h0, t0, t1, steps)

    # Loss: L = 0.5 * ||h(T) - target||^2
    loss = 0.5 * np.sum((h_final - target) ** 2)

    # Terminal condition: a(T) = dL/dh(T)
    a = h_final - target

    # Backward solve: integrate adjoint ODE from T to 0
    dt = (t1 - t0) / steps
    grad_W1 = np.zeros_like(node.W1)
    grad_W2 = np.zeros_like(node.W2)

    for i in range(steps - 1, -1, -1):
        h = trajectory[i]
        z = np.tanh(node.W1 @ h + node.b1)
        dz = 1 - z ** 2  # tanh derivative

        # Jacobian: df/dh = W2 @ diag(dz) @ W1
        df_dh = node.W2 @ np.diag(dz) @ node.W1

        # Accumulate parameter gradients
        grad_W2 += np.outer(a, z) * dt
        grad_W1 += np.outer(node.W2.T @ a * dz, h) * dt

        # Adjoint update (Euler step backward)
        a = a + dt * (df_dh.T @ a)

    return loss, {'W1': grad_W1, 'W2': grad_W2}

In practice, storing the full trajectory defeats the O(1) purpose. Production implementations (like torchdiffeq) reconstruct h(t) by solving the forward ODE backward alongside the adjoint — a single augmented ODE for [h, a, ∇θL] integrated from T to 0. Our simplified version stores the forward trajectory for clarity, but the gradient computation itself is identical.
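Here is a minimal sketch of that augmented backward pass, under simplifying assumptions: Euler integration in both directions, a small random network standing in for fθ (the W1/W2 below are illustrative stand-ins, not the class above), and only the W1 gradient accumulated. The forward pass keeps only the final state; the backward pass re-derives h(t) as it goes:

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 2)) * 0.3   # illustrative stand-in weights
b1 = np.zeros(8)
W2 = rng.normal(size=(2, 8)) * 0.3
b2 = np.zeros(2)

def f(h):
    return W2 @ np.tanh(W1 @ h + b1) + b2

def jac(h):
    dz = 1 - np.tanh(W1 @ h + b1) ** 2
    return W2 @ (dz[:, None] * W1)   # df/dh = W2 diag(dz) W1

# forward solve (Euler), keeping ONLY the final state
h0 = np.array([1.0, 0.5])
T, steps = 1.0, 500
dt = T / steps
h = h0.copy()
for _ in range(steps):
    h = h + dt * f(h)

target = np.zeros(2)
a = h - target                        # a(T) = dL/dh(T), L = 0.5||h - y||^2
gW1 = np.zeros_like(W1)

# backward solve: run the forward ODE in reverse to reconstruct h(t)
# while integrating the adjoint -- memory is O(1) in the step count
for _ in range(steps):
    z = np.tanh(W1 @ h + b1)
    gW1 += np.outer((W2.T @ a) * (1 - z**2), h) * dt  # dL/dW1 piece
    a = a + dt * (jac(h).T @ a)       # adjoint ODE, stepped backward
    h = h - dt * f(h)                 # forward ODE, stepped backward

# h has now been carried back to (approximately) the input h0
```

Reconstructing h backward accumulates O(dt) error, which is why production solvers pair this trick with adaptive, higher-order integration.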

4. Continuous Normalizing Flows

Normalizing flows transform a simple distribution (like a Gaussian) into a complex one through a sequence of invertible maps. Each step requires computing the determinant of the Jacobian — an O(d³) operation that severely constrains architecture choices.

Continuous Normalizing Flows (CNFs) replace discrete invertible steps with an ODE: dz/dt = fθ(z, t). The density changes according to the instantaneous change-of-variables formula (Liouville's equation):

d log p(z)/dt = −tr(∂fθ/∂z)

The trace is O(d), not O(d3). Even better, the Hutchinson trace estimator gives an unbiased estimate using random vector-Jacobian products — no need to form the full Jacobian at all. This is exactly what FFJORD (Grathwohl et al. 2019) exploits, and it directly connects to flow matching: Conditional Flow Matching learns the velocity field fθ for this very ODE, but with a simulation-free training objective.

def hutchinson_trace(f, h, t, dim, n_probes=5):
    """Estimate tr(df/dh) via Hutchinson's trick: E[v^T J v]."""
    eps = 1e-5
    trace_est = 0.0
    for _ in range(n_probes):
        v = np.random.randn(dim)
        jvp = (f(h + eps * v, t) - f(h - eps * v, t)) / (2 * eps)
        trace_est += np.dot(v, jvp)
    return trace_est / n_probes

def cnf_forward(node, z0, t0=0.0, t1=1.0, steps=50):
    """Solve the CNF: dz/dt = f(z,t), d log p/dt = -tr(df/dz)."""
    dt = (t1 - t0) / steps
    z = z0.copy()
    log_p_change = 0.0

    for i in range(steps):
        t = t0 + i * dt
        dz = node.f(z, t)
        trace = hutchinson_trace(node.f, z, t, len(z))
        log_p_change -= trace * dt  # Liouville's equation
        z = z + dt * dz

    return z, log_p_change

# log p(x) = log p(z0) + accumulated log-density change
# No invertibility constraint, no determinant computation!

The Hutchinson estimator uses random probe vectors v and computes vᵀJv via finite differences. In expectation, this equals the trace. With just 5 probes, the variance is usually low enough for training.
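One way to convince yourself the estimator works: apply it to a linear map f(h) = A·h, whose Jacobian is exactly A, and compare with np.trace(A). This standalone check re-declares the estimator with many more probes than training would use:

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.normal(size=(5, 5))
f_lin = lambda h, t: A @ h            # Jacobian of f_lin is exactly A

def hutchinson_trace(f, h, t, dim, n_probes=1000):
    """E[v^T J v] over random probes v; JVPs via central differences."""
    eps = 1e-5
    est = 0.0
    for _ in range(n_probes):
        v = rng.normal(size=dim)
        jvp = (f(h + eps * v, t) - f(h - eps * v, t)) / (2 * eps)
        est += v @ jvp
    return est / n_probes

h = rng.normal(size=5)
est = hutchinson_trace(f_lin, h, 0.0, 5)
# est matches np.trace(A) up to Monte Carlo noise (std roughly 0.2 here)
```

For the linear map the finite-difference JVP is exact, so all remaining error is Monte Carlo variance, which shrinks like 1/√n_probes.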

5. Augmented Neural ODEs — Breaking the Topology Barrier

Neural ODEs in 2D have a fundamental limitation rooted in the uniqueness theorem: ODE trajectories cannot cross. This means the transformation from h(0) to h(T) is a homeomorphism — it preserves topology. In two dimensions, concentric circles stay topologically nested no matter how you deform them. A standard 2D Neural ODE literally cannot make concentric circles linearly separable, just as a single-layer perceptron cannot solve XOR.

Augmented Neural ODEs (Dupont et al. 2019) fix this by lifting the input to a higher-dimensional space: [h; 0] ∈ ℝ^(d+p). The extra dimensions give trajectories "room" to go around each other, much like lifting a 2D knot into 3D where it can be untied. This is exactly what ResNets do naturally — the channel dimension provides the augmentation that makes deep classification work.

class AugmentedNeuralODE:
    """Lift input to higher-dim space, then solve ODE."""
    def __init__(self, data_dim, aug_dim=3, hidden=32):
        total = data_dim + aug_dim
        self.aug_dim = aug_dim
        scale = np.sqrt(2 / total)
        self.W1 = np.random.randn(hidden, total) * scale
        self.b1 = np.zeros(hidden)
        self.W2 = np.random.randn(total, hidden) * scale
        self.b2 = np.zeros(total)

    def f(self, h, t):
        z = np.tanh(self.W1 @ h + self.b1)
        return self.W2 @ z + self.b2

    def forward(self, x, t1=1.0, steps=20):
        # Augment: pad input with zeros
        h = np.concatenate([x, np.zeros(self.aug_dim)])
        dt = t1 / steps
        for i in range(steps):
            t = i * dt  # these dynamics are autonomous, but pass t anyway
            k1 = self.f(h, t)
            k2 = self.f(h + 0.5*dt*k1, t + 0.5*dt)
            k3 = self.f(h + 0.5*dt*k2, t + 0.5*dt)
            k4 = self.f(h + dt*k3, t + dt)
            h = h + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
        return h[:len(x)]  # project back to data dimension

# Standard 2D NODE: can't untangle concentric circles
# Augmented NODE (2D + 3 extra dims): trajectories detour through 5D
# where they CAN pass around each other -> linearly separable output
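The no-crossing property behind all this is easiest to see in 1D, where it means the flow preserves the ordering of points. A quick self-contained sketch, with arbitrary smooth dynamics (the random-feature f below is just an illustrative stand-in):

```python
import numpy as np

rng = np.random.default_rng(1)
w, c = rng.normal(size=8), rng.normal(size=8)

def f(h, t):
    """Arbitrary smooth 1D dynamics (illustrative random features)."""
    return 0.5 * np.sum(w * np.sin(c * h))

def solve(h0, T=5.0, steps=2000):
    h, dt = h0, T / steps
    for i in range(steps):
        h = h + dt * f(h, i * dt)   # small-step Euler stays faithful
    return h

# two particles that start in a fixed order...
a, b = solve(-1.0), solve(1.0)
# ...can move anywhere, but can never swap places: a stays below b
```

For small enough dt the map h ↦ h + dt·f(h) is monotone, so two particles can never swap — exactly why a topology-limited Neural ODE cannot relabel nested regions without augmentation.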

Try It: Continuous Depth Classifier

Watch how points flow through continuous time. Standard NODE preserves topology (concentric circles stay nested). Augmented NODE lifts to higher dimensions, allowing separation.


6. Latent ODEs for Irregular Time Series

Most real-world time series are irregularly sampled: medical records have sparse checkups, IoT sensors report at variable rates, financial trades cluster during market hours. Standard RNNs assume fixed time steps. Resampling to a grid wastes computation on gaps and can distort the signal.

Latent ODEs (Rubanova et al. 2019) handle this naturally. The architecture has three parts: an ODE-RNN encoder that processes observations in reverse (the hidden state evolves via a Neural ODE between observations), a latent ODE that defines smooth dynamics in a learned latent space, and a decoder that maps latent states back to observation space.

The key insight: between observations, the hidden state follows smooth ODE dynamics. At each observation, new information updates the state. The ODE handles any spacing — 1 millisecond or 1 year — without resampling. Interpolation between observations and extrapolation beyond them emerge naturally from the continuous dynamics.

class LatentODE:
    """Simplified Latent ODE: encode -> latent dynamics -> decode."""
    def __init__(self, obs_dim=1, latent_dim=4):
        self.latent_dim = latent_dim
        # Encoder: observations -> latent initial state
        self.enc_W = np.random.randn(latent_dim, obs_dim + 1) * 0.5
        # Dynamics: dz/dt = tanh(W_dyn @ z)
        self.dyn_W = np.random.randn(latent_dim, latent_dim) * 0.3
        # Decoder: latent state -> observation
        self.dec_W = np.random.randn(obs_dim, latent_dim) * 0.5

    def encode(self, obs_times, obs_values):
        """Simple mean encoding (real version uses ODE-RNN)."""
        features = np.column_stack([obs_times, obs_values])
        return np.tanh(self.enc_W @ features.mean(axis=0))

    def latent_dynamics(self, z0, eval_times, steps_per_unit=10):
        """Solve latent ODE at arbitrary evaluation times."""
        results = {}
        z = z0.copy()
        t_cur = 0.0
        for t_target in sorted(eval_times):
            n = max(1, int((t_target - t_cur) * steps_per_unit))
            dt = (t_target - t_cur) / n
            for _ in range(n):
                z = z + dt * np.tanh(self.dyn_W @ z)
            results[t_target] = z.copy()
            t_cur = t_target
        return results

    def decode(self, z):
        return self.dec_W @ z

# Irregular observations? Just solve the ODE between whatever times
# you have. No resampling, no fixed grid, no wasted computation.

Try It: Irregular Time Series with Latent ODE

Click anywhere on the canvas to add observations. Shift-click to remove the nearest one. The fitted curve and uncertainty band update automatically.


7. The Deep Learning – Dynamical Systems Connection

Neural ODEs reveal a deep bridge between deep learning and dynamical systems theory. Every ResNet block is an Euler step. Backpropagation is the adjoint method. Skip connections keep the dynamics near the identity (eigenvalues ≈ 1), which is exactly why they prevent vanishing gradients — the system stays at the edge of stability.

This perspective places modern architectures on a spectrum of continuous dynamics:

Architecture     | Dynamics Type            | Memory | Key Property
ResNet           | Discretized ODE (Δt = 1) | O(T)   | Fixed depth, fast
Neural ODE       | Nonlinear ODE            | O(1)   | Adaptive depth, continuous
SSM (S4/Mamba)   | Linear ODE + input       | O(1)   | Parallelizable, long-range
Diffusion Model  | Stochastic ODE/SDE      | O(T)   | Generation via reverse process
Flow Matching    | CNF (learned velocity)  | O(1)   | Simulation-free training

The eigenvalues of the Jacobian ∂f/∂h tell the stability story: negative real parts mean contracting dynamics (information loss), positive real parts mean expanding dynamics (gradient explosion), and real parts near zero put the system at the edge of chaos, where useful computation happens. Neural ODEs make this analysis precise because the dynamics are continuous — no discrete jumps to worry about.
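A quick numerical illustration of that stability story, on linear dynamics where the eigenvalues are known by construction (the matrices below are illustrative):

```python
import numpy as np

# continuous-time stability: dh/dt = A h contracts iff Re(eig(A)) < 0
A_contract = np.array([[-0.5, 1.0], [-1.0, -0.5]])   # Re(lambda) = -0.5
A_expand   = np.array([[ 0.3, 1.0], [-1.0,  0.3]])   # Re(lambda) = +0.3

def flow_norm(A, T=5.0, steps=5000):
    """|h(T)| after Euler-integrating dh/dt = A h from h(0) = e1."""
    h, dt = np.array([1.0, 0.0]), T / steps
    for _ in range(steps):
        h = h + dt * (A @ h)
    return np.linalg.norm(h)

# contracting dynamics forget their input; expanding ones amplify it
norm_c, norm_e = flow_norm(A_contract), flow_norm(A_expand)
```

Both matrices rotate, but only the sign of the real part decides whether the input is remembered, lost, or blown up.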

Looking forward, the Neural ODE framework keeps expanding. Neural SDEs add stochastic noise for uncertainty quantification. Neural CDEs (controlled differential equations) handle streaming data elegantly. And the connection to flow matching — where you train a CNF's velocity field with a simple regression loss — has become the backbone of modern generative models.

References & Further Reading