
Neural Processes from Scratch

1. The Problem: Fast Uncertainty from Context

Gaussian Processes give beautiful uncertainty estimates — they genuinely know what they don't know. Show a GP five data points and it'll give you a smooth prediction that passes through them, flanked by uncertainty bands that widen honestly in regions where it has no data. The problem? GPs require inverting an n×n matrix, making them O(n³) — fine for 100 points, painful for 1,000, and impossible for 100,000. They also need you to hand-design a kernel function that encodes your assumptions about the data's structure.

Neural networks sit at the opposite extreme. They handle any dataset size, learn representations automatically, and run inference in milliseconds. But ask a standard neural network "how confident are you?" and it'll shrug. It gives you a point prediction with no honest measure of uncertainty.

What if we could get both? A model that takes in a handful of observed points — a context set — and instantly outputs predictions with calibrated uncertainty at any new location, just like a GP, but using learned representations instead of hand-crafted kernels and running in O(n) instead of O(n³)?

That's exactly what Neural Processes do. Introduced by Garnelo et al. in 2018, they frame prediction as a meta-learning problem: given a context set C = {(xi, yi)} of observed input-output pairs and a new target input x*, predict the distribution p(y* | x*, C). The model learns across thousands of different functions during training, so at test time it can adapt to a completely new function from just a few context points — no retraining, no matrix inversion, no kernel engineering.

2. The Conditional Neural Process (CNP)

The simplest member of the Neural Process family is the Conditional Neural Process (Garnelo et al., 2018). Its architecture has three clean components: an encoder that maps each context pair (xi, yi) to a representation ri, a mean-pooling aggregator that collapses the ri into a single fixed-size vector r, and a decoder that takes r together with a target input x* and outputs a predictive mean μ and standard deviation σ.

The training loss is the negative log-likelihood of the target points given the context-derived representation:

Loss = ∑t [(yt − μt)² / (2σt²) + log(σt)]
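In code, that per-point term is only a few lines. The helper below is a hypothetical standalone sketch, dropping the constant ½·log(2π) that doesn't affect gradients:

```python
import numpy as np

def gaussian_nll(y, mu, sigma):
    """Mean negative log-likelihood of y under N(mu, sigma^2),
    with the constant 0.5*log(2*pi) dropped."""
    return ((y - mu) ** 2 / (2 * sigma ** 2) + np.log(sigma)).mean()

# A perfect prediction with unit variance scores exactly 0 under this convention:
print(gaussian_nll(np.array([1.0]), np.array([1.0]), np.array([1.0])))  # 0.0
```

Note the two ways the model can lower this loss: move μ toward y, or (when it can't) inflate σ to pay the squared-error term less, at the cost of the log(σ) penalty. That tension is what produces calibrated uncertainty.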

Here's a minimal CNP implementation that trains on sine functions with varying amplitude and phase:

import numpy as np

class CNP:
    def __init__(self, h_dim=64):
        # Encoder: (x, y) -> r_i
        self.W1 = np.random.randn(2, h_dim) * 0.1
        self.b1 = np.zeros(h_dim)
        self.W2 = np.random.randn(h_dim, h_dim) * 0.1
        self.b2 = np.zeros(h_dim)
        # Decoder: (r, x*) -> (mu, log_sigma)
        self.W3 = np.random.randn(h_dim + 1, h_dim) * 0.1
        self.b3 = np.zeros(h_dim)
        self.W4 = np.random.randn(h_dim, 2) * 0.1
        self.b4 = np.zeros(2)

    def encode(self, x_ctx, y_ctx):
        """Encode each context pair, then mean-pool."""
        pairs = np.column_stack([x_ctx, y_ctx])        # (n, 2)
        h = np.maximum(0, pairs @ self.W1 + self.b1)   # ReLU
        r_all = np.maximum(0, h @ self.W2 + self.b2)   # (n, h_dim)
        return r_all.mean(axis=0)                       # mean pool

    def decode(self, r, x_target):
        """Predict mu, sigma at each target x."""
        inp = np.column_stack([np.tile(r, (len(x_target), 1)),
                               x_target])               # (m, h_dim+1)
        h = np.maximum(0, inp @ self.W3 + self.b3)
        out = h @ self.W4 + self.b4                      # (m, 2)
        mu = out[:, 0]
        sigma = np.exp(out[:, 1]) + 1e-4                 # positive via exp
        return mu, sigma

    def predict(self, x_ctx, y_ctx, x_target):
        r = self.encode(x_ctx, y_ctx)
        return self.decode(r, x_target)

# Generate a random sine task: y = A*sin(x + phase)
rng = np.random.default_rng(42)
A, phase = rng.uniform(0.5, 2.0), rng.uniform(0, 2 * np.pi)
x_all = rng.uniform(-3, 3, size=(20, 1))
y_all = A * np.sin(x_all + phase)

# Split into 5 context + 15 target
x_ctx, y_ctx = x_all[:5], y_all[:5]
x_tgt, y_tgt = x_all[5:], y_all[5:]

cnp = CNP()
mu, sigma = cnp.predict(x_ctx, y_ctx, x_tgt)
print(f"Predictions at first 3 targets:")
for i in range(3):
    print(f"  x={x_tgt[i,0]:.2f}  true={y_tgt[i,0]:.2f}"
          f"  pred={mu[i]:.2f} +/- {sigma[i]:.2f}")

The output from an untrained CNP will be random, but the architecture is correct: encode each context pair, mean-pool, decode at target locations. Training adjusts the weights so that the model learns to interpolate through context points with appropriate uncertainty.

3. Why Permutation Invariance Matters

The context set is a set, not a sequence. If someone shows you the points {(0, 1), (2, 3), (4, 2)}, the prediction at x=3 should be the same regardless of whether the points were presented in that order or as {(4, 2), (0, 1), (2, 3)}. This is permutation invariance.

The Deep Sets theorem (Zaheer et al., 2017) tells us that any permutation-invariant function on sets can be decomposed as ρ(∑ φ(xi)) — encode each element independently with φ, sum (or mean) the results, then post-process with ρ. This is exactly the CNP's encoder-aggregator-decoder pattern.
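A toy instantiation makes the theorem concrete. The φ and ρ below are small random networks invented for illustration, unrelated to the CNP's weights:

```python
import numpy as np

rng = np.random.default_rng(0)
W_phi = rng.standard_normal((2, 8)) * 0.5  # phi: per-element encoder
W_rho = rng.standard_normal(8) * 0.5       # rho: post-processor

def deep_set(points):
    """rho(mean(phi(x_i))): permutation-invariant by construction."""
    phi = np.tanh(points @ W_phi)   # encode each element independently
    pooled = phi.mean(axis=0)       # symmetric aggregation
    return np.tanh(pooled @ W_rho)  # post-process the pooled vector

pts = np.array([[0., 1.], [2., 3.], [4., 2.]])
print(f"{deep_set(pts):.6f}  {deep_set(pts[::-1]):.6f}")  # identical up to rounding
```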

Mean pooling is the simplest permutation-invariant aggregation. It guarantees that shuffling the context produces identical predictions:

def verify_permutation_invariance(cnp, x_ctx, y_ctx, x_target):
    """Show that shuffling context gives identical predictions."""
    mu1, sigma1 = cnp.predict(x_ctx, y_ctx, x_target)

    # Shuffle context points
    perm = np.random.permutation(len(x_ctx))
    x_shuf, y_shuf = x_ctx[perm], y_ctx[perm]
    mu2, sigma2 = cnp.predict(x_shuf, y_shuf, x_target)

    print(f"Original order  mu[:3]: {mu1[:3].round(4)}")
    print(f"Shuffled order  mu[:3]: {mu2[:3].round(4)}")
    print(f"Max difference: {np.max(np.abs(mu1 - mu2)):.2e}")
    # Max difference: 0 (identical up to floating-point rounding of the mean-pool)

verify_permutation_invariance(cnp, x_ctx, y_ctx, x_tgt)

An RNN encoder, by contrast, would produce different hidden states for different orderings — violating the set semantics. This is why the encode-then-aggregate design is fundamental to Neural Processes.
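A quick sketch of that failure mode, using a toy tanh RNN with random weights (nothing from the models above):

```python
import numpy as np

rng = np.random.default_rng(0)
W_h = rng.standard_normal((8, 8)) * 0.5  # hidden-to-hidden weights
W_x = rng.standard_normal((2, 8)) * 0.5  # input-to-hidden weights

def rnn_encode(pairs):
    """Toy tanh RNN that treats the context set as a sequence."""
    h = np.zeros(8)
    for p in pairs:
        h = np.tanh(h @ W_h + p @ W_x)  # order of iteration leaks into h
    return h

pairs = np.array([[0., 1.], [2., 3.], [4., 2.]])
diff = np.max(np.abs(rnn_encode(pairs) - rnn_encode(pairs[::-1])))
print(f"max |h_fwd - h_rev| = {diff:.3f}")  # nonzero: the encoding depends on order
```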

4. The Underfitting Problem

The CNP has a critical limitation that becomes obvious when you look at its predictions closely: the mean prediction often doesn't pass through the context points. Even with perfectly noise-free observations, the CNP's predicted mean can miss the data it was given.

Why? Because the decoder produces each target prediction independently. It gets the same global representation r regardless of which target x* it's predicting at. There's no mechanism forcing the model to be consistent across predictions — and crucially, no mechanism forcing predictions at context locations to match the observed values.
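Concretely, predictions at different targets are conditionally independent given r. A toy illustration (stand-in marginals, not an actual CNP) shows what that costs when sampling:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 50)
mu = np.sin(x)                # stand-in predicted means at each target
sigma = np.full_like(x, 0.3)  # stand-in predicted stds

# 1000 "function samples", each target drawn independently from its marginal
draws = mu + sigma * rng.standard_normal((1000, x.size))
resid = draws - mu
corr = np.corrcoef(resid[:, 24], resid[:, 25])[0, 1]
print(f"correlation between adjacent targets: {corr:+.3f}")  # close to 0
```

Neighboring targets are uncorrelated, so every draw is jagged noise around the mean rather than a smooth candidate function.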

This independence creates two problems. First, the model can't produce coherent function samples: each target is drawn from its own marginal Gaussian, so "samples" are jagged, uncorrelated noise rather than smooth candidate functions. Second, the fixed-size mean-pooled r is a lossy bottleneck, so the decoder underfits: even its predictions at the context locations can miss the observed values.

These limitations motivate the latent Neural Process.

5. The Latent Neural Process — Adding Global Uncertainty

The fix is elegant: add a global latent variable z that captures function-level properties. Instead of collapsing all context information into a deterministic representation r, the latent NP encodes context into a distribution over z — a Gaussian with learned mean μz and standard deviation σz. Sampling different values of z produces different plausible functions, all consistent with the context.

The architecture now has two paths: a deterministic path that mean-pools context encodings into r, exactly as in the CNP, and a latent path that maps the same encodings to the parameters (μz, σz) of a Gaussian over z.

The decoder conditions on both: g(r, z, x*) → (μ*, σ*). The latent z carries global information (amplitude, frequency, smoothness) while the deterministic r carries local context details.

Training uses the Evidence Lower Bound (ELBO), the same variational objective used in VAEs:

ELBO = Eq(z|C,T)[log p(yT | xT, z, r)] − KL(q(z|C,T) || q(z|C))

The first term says: "z should help predict the targets." The second term says: "the posterior over z when you see both context AND targets shouldn't be too different from the posterior using context alone" — which encourages the model to extract useful information from context without relying on peeking at targets. We use the reparameterization trick to backpropagate through the sampling step: z = μz + σz ⊙ ε, where ε ~ N(0, I).
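The KL term between two diagonal Gaussians has a closed form. Here's a small helper, a hypothetical sketch rather than part of the class below; a full training step would call it with q(z|C,T) and q(z|C) parameters:

```python
import numpy as np

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) between diagonal Gaussians, summed over dimensions."""
    return np.sum(np.log(sigma_p / sigma_q)
                  + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
                  - 0.5)

# Identical distributions have zero divergence:
print(kl_diag_gaussians(np.zeros(4), np.ones(4), np.zeros(4), np.ones(4)))  # 0.0
```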

class LatentNP:
    def __init__(self, h_dim=64, z_dim=16):
        self.h_dim, self.z_dim = h_dim, z_dim
        # Shared encoder: (x, y) -> hidden
        self.enc_W1 = np.random.randn(2, h_dim) * 0.1
        self.enc_b1 = np.zeros(h_dim)
        # Deterministic path: hidden -> r
        self.det_W = np.random.randn(h_dim, h_dim) * 0.1
        self.det_b = np.zeros(h_dim)
        # Latent path: hidden -> (mu_z, log_sigma_z)
        self.lat_W = np.random.randn(h_dim, z_dim * 2) * 0.1
        self.lat_b = np.zeros(z_dim * 2)
        # Decoder: (r, z, x*) -> (mu, log_sigma)
        self.dec_W1 = np.random.randn(h_dim + z_dim + 1, h_dim) * 0.1
        self.dec_b1 = np.zeros(h_dim)
        self.dec_W2 = np.random.randn(h_dim, 2) * 0.1
        self.dec_b2 = np.zeros(2)

    def encode_context(self, x_ctx, y_ctx):
        pairs = np.column_stack([x_ctx, y_ctx])
        h = np.maximum(0, pairs @ self.enc_W1 + self.enc_b1)
        # Deterministic representation
        r = np.maximum(0, h @ self.det_W + self.det_b).mean(axis=0)
        # Latent distribution parameters
        lat = (h @ self.lat_W + self.lat_b).mean(axis=0)
        mu_z = lat[:self.z_dim]
        sigma_z = np.exp(lat[self.z_dim:]) + 1e-4
        return r, mu_z, sigma_z

    def sample_z(self, mu_z, sigma_z, rng):
        """Reparameterization trick: z = mu + sigma * epsilon."""
        eps = rng.standard_normal(self.z_dim)
        return mu_z + sigma_z * eps

    def decode(self, r, z, x_target):
        n = len(x_target)
        inp = np.column_stack([np.tile(r, (n, 1)),
                               np.tile(z, (n, 1)), x_target])
        h = np.maximum(0, inp @ self.dec_W1 + self.dec_b1)
        out = h @ self.dec_W2 + self.dec_b2
        return out[:, 0], np.exp(out[:, 1]) + 1e-4

    def predict_samples(self, x_ctx, y_ctx, x_target, n_samples=5):
        """Draw multiple coherent function samples."""
        r, mu_z, sigma_z = self.encode_context(x_ctx, y_ctx)
        rng = np.random.default_rng(0)
        samples = []
        for _ in range(n_samples):
            z = self.sample_z(mu_z, sigma_z, rng)
            mu, _ = self.decode(r, z, x_target)
            samples.append(mu)
        return np.array(samples)  # (n_samples, n_targets)

lnp = LatentNP()
samples = lnp.predict_samples(x_ctx, y_ctx, x_tgt, n_samples=5)
print(f"5 function samples, each at {samples.shape[1]} targets")
print(f"Sample std across functions: {samples.std(axis=0).mean():.3f}")
# Different z values produce different but coherent predictions

The key difference from the CNP: each sample is a coherent function — smooth and globally consistent, not independently sampled noise. Different z values produce functions with different amplitudes, frequencies, or offsets, but each individual sample is a plausible function that respects the context.

6. The Attentive Neural Process

The CNP and latent NP both use mean pooling to summarize context. This means every context point contributes equally to the representation, regardless of the target query location. But when predicting at x* = 5.0, a context point at x = 4.9 should matter far more than one at x = 0.1.

The Attentive Neural Process (Kim et al., 2019) replaces mean pooling with cross-attention: each target query attends to all context points, weighting nearby or relevant ones more heavily. The math follows the standard attention mechanism from transformers:

attention(Q, K, V) = softmax(Q·Kᵀ / √dk)·V

Queries come from target inputs, keys and values come from context representations. Each target gets its own query-specific context summary rather than a single global vector.

class AttentiveNP:
    def __init__(self, h_dim=64):
        self.h_dim = h_dim
        self.d_k = h_dim  # key dimension, used to scale attention scores
        # Encoder
        self.enc_W1 = np.random.randn(2, h_dim) * 0.1
        self.enc_b1 = np.zeros(h_dim)
        # Attention projections (single-head for clarity; the ANP paper uses multi-head)
        self.W_Q = np.random.randn(1, h_dim) * 0.1   # query from x*
        self.W_K = np.random.randn(h_dim, h_dim) * 0.1  # key from r_i
        self.W_V = np.random.randn(h_dim, h_dim) * 0.1  # value from r_i
        # Decoder
        self.dec_W1 = np.random.randn(h_dim + 1, h_dim) * 0.1
        self.dec_b1 = np.zeros(h_dim)
        self.dec_W2 = np.random.randn(h_dim, 2) * 0.1
        self.dec_b2 = np.zeros(2)

    def cross_attend(self, x_target, ctx_repr):
        """Each target attends to context representations."""
        Q = x_target @ self.W_Q          # (m, h_dim)
        K = ctx_repr @ self.W_K           # (n, h_dim)
        V = ctx_repr @ self.W_V           # (n, h_dim)
        # Scaled dot-product attention
        scores = Q @ K.T / np.sqrt(self.d_k)  # (m, n)
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)  # softmax
        return weights @ V                 # (m, h_dim)

    def predict(self, x_ctx, y_ctx, x_target):
        pairs = np.column_stack([x_ctx, y_ctx])
        ctx_repr = np.maximum(0, pairs @ self.enc_W1 + self.enc_b1)
        attended = self.cross_attend(x_target, ctx_repr)  # (m, h_dim)
        inp = np.column_stack([attended, x_target])
        h = np.maximum(0, inp @ self.dec_W1 + self.dec_b1)
        out = h @ self.dec_W2 + self.dec_b2
        return out[:, 0], np.exp(out[:, 1]) + 1e-4

anp = AttentiveNP()
mu, sigma = anp.predict(x_ctx, y_ctx, x_tgt)
print(f"ANP predictions at first 3 targets:")
for i in range(3):
    print(f"  x={x_tgt[i,0]:.2f}  pred={mu[i]:.2f} +/- {sigma[i]:.2f}")

The cross-attention mechanism gives the ANP a crucial advantage: predictions near context clusters are sharp and accurate (high attention weights on nearby context), while predictions far from any context point have wide uncertainty (attention is spread diffusely). This is exactly the behavior we want — and exactly what GPs give naturally through their posterior variance.

7. Training Neural Processes — The Meta-Learning Loop

Neural Processes are trained with episodic meta-learning. Each training episode presents the model with a different function — one might be a gentle sine wave, the next a steep quadratic, the next a wiggly polynomial. For each function, we split the observed points into context (what the model sees) and target (what it must predict). Over thousands of episodes, the model learns not to memorize any single function, but to rapidly adapt to new functions from a handful of context points.

def train_cnp(cnp, n_episodes=2000):
    """Episodic meta-learning loop (computes the loss only; see Step 5)."""
    rng = np.random.default_rng(42)
    for ep in range(n_episodes):
        # Step 1: Sample a random function (sine with random params)
        amp = rng.uniform(0.5, 2.0)
        phase = rng.uniform(0, 2 * np.pi)
        freq = rng.uniform(0.5, 2.0)
        f = lambda x: amp * np.sin(freq * x + phase)

        # Step 2: Sample points from this function
        x_all = rng.uniform(-4, 4, size=(20, 1))
        y_all = f(x_all) + rng.normal(0, 0.05, size=x_all.shape)

        # Step 3: Random context/target split
        n_ctx = rng.integers(3, 15)
        idx = rng.permutation(20)
        x_ctx, y_ctx = x_all[idx[:n_ctx]], y_all[idx[:n_ctx]]
        x_tgt, y_tgt = x_all[idx[n_ctx:]], y_all[idx[n_ctx:]]

        # Step 4: Forward pass + NLL loss
        mu, sigma = cnp.predict(x_ctx, y_ctx, x_tgt)
        nll = ((y_tgt.ravel() - mu)**2 / (2 * sigma**2)
               + np.log(sigma)).mean()

        # Step 5: Update (in practice, use autograd + Adam)
        # Here we just track the loss for demonstration
        if ep % 500 == 0:
            print(f"Episode {ep:4d}  NLL: {nll:.3f}")

train_cnp(cnp)

The key insight: by varying the number of context points randomly (between 3 and 15 here), the model learns to make reasonable predictions with little data AND to sharpen its predictions when given more data — just like a GP that becomes more confident as it observes more points. The variable-size context also prevents the model from overfitting to any particular context size.

8. NPs vs GPs — When to Use Which

Feature                  GP             CNP          Latent NP    ANP
Test-time cost           O(n³)          O(n)         O(n)         O(n·m)
Kernels                  Hand-designed  Learned      Learned      Learned
Uncertainty              Exact          Approximate  Approximate  Approximate
Coherent samples         Yes            No           Yes          Yes
Passes through context   Yes            Often not    Better       Best

Here's a side-by-side comparison, using a GP with an RBF kernel alongside the CNP on the same context set (the interesting part here is cost and interface; a fully trained CNP would also produce calibrated predictions):

def gp_predict(x_ctx, y_ctx, x_test, l=1.0, sigma_n=0.1):
    """Standard GP posterior with RBF kernel."""
    def rbf(x1, x2):
        return np.exp(-0.5 * ((x1 - x2.T) / l) ** 2)

    K = rbf(x_ctx, x_ctx) + sigma_n**2 * np.eye(len(x_ctx))
    K_s = rbf(x_ctx, x_test)
    K_ss = rbf(x_test, x_test)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_ctx))
    mu = K_s.T @ alpha
    v = np.linalg.solve(L, K_s)
    sigma = np.sqrt(np.diag(K_ss - v.T @ v).clip(0))
    return mu.ravel(), sigma

import time
x_test = np.linspace(-4, 4, 200).reshape(-1, 1)

# GP inference (exact, O(n^3))
t0 = time.perf_counter()
gp_mu, gp_sigma = gp_predict(x_ctx, y_ctx, x_test)
gp_time = (time.perf_counter() - t0) * 1000

# NP inference (approximate, O(n))
t0 = time.perf_counter()
np_mu, np_sigma = cnp.predict(x_ctx, y_ctx, x_test)
np_time = (time.perf_counter() - t0) * 1000

print(f"GP: {gp_time:.1f}ms | NP: {np_time:.1f}ms")
print(f"GP uncertainty range: [{gp_sigma.min():.3f}, {gp_sigma.max():.3f}]")
print(f"NP uncertainty range: [{np_sigma.min():.3f}, {np_sigma.max():.3f}]")

Use GPs when you have a small dataset (<1,000 points), know what kernel to use, and need exact uncertainty. Use NPs when you need fast adaptation to new tasks, have complex structured inputs (images, graphs), or need to scale to large context sets. NPs shine in scenarios where you face many related but different prediction tasks — weather at different stations, drug response across patients, completion of partially observed images.


9. Conclusion

Neural Processes sit at a beautiful intersection: the principled uncertainty of Gaussian Processes meets the scalability and learned representations of neural networks, wrapped in a meta-learning training procedure. The progression tells a clean story — CNPs show that the basic idea works but suffer from independent predictions and underfitting; latent NPs fix coherence by adding a global latent variable; attentive NPs fix underfitting by replacing pooling with cross-attention.

The NP family continues to evolve. Convolutional CNPs (Gordon et al., 2020) add translation equivariance for spatial data. Transformer Neural Processes replace the entire architecture with transformer blocks. Gaussian Neural Processes enforce exact GP behavior in the limit of infinite context. Each variant pushes the boundary of what "learning to be a GP" can mean.

The core insight, though, remains the same: if you train a neural network to predict from context sets across thousands of different functions, it learns to do something remarkable — it learns the process of going from observations to uncertainty-aware predictions. Not a single function, but the mapping itself.

References & Further Reading