Neural Processes from Scratch
1. The Problem: Fast Uncertainty from Context
Gaussian Processes give beautiful uncertainty estimates — they genuinely know what they don't know. Show a GP five data points and it'll give you a smooth prediction that passes through them, flanked by uncertainty bands that widen honestly in regions where it has no data. The problem? GPs require inverting an n×n matrix, making them O(n³) — fine for 100 points, painful for 1,000, and impossible for 100,000. They also need you to hand-design a kernel function that encodes your assumptions about the data's structure.
Neural networks sit at the opposite extreme. They handle any dataset size, learn representations automatically, and run inference in milliseconds. But ask a standard neural network "how confident are you?" and it'll shrug. It gives you a point prediction with no honest measure of uncertainty.
What if we could get both? A model that takes in a handful of observed points — a context set — and instantly outputs predictions with calibrated uncertainty at any new location, just like a GP, but using learned representations instead of hand-crafted kernels and running in O(n) instead of O(n³)?
That's exactly what Neural Processes do. Introduced by Garnelo et al. in 2018, they frame prediction as a meta-learning problem: given a context set C = {(xi, yi)} of observed input-output pairs and a new target input x*, predict the distribution p(y* | x*, C). The model learns across thousands of different functions during training, so at test time it can adapt to a completely new function from just a few context points — no retraining, no matrix inversion, no kernel engineering.
2. The Conditional Neural Process (CNP)
The simplest member of the Neural Process family is the Conditional Neural Process (Garnelo et al., 2018). Its architecture has three clean components:
- Encoder: An MLP that maps each context pair (xi, yi) independently to a representation vector ri. Each context point gets its own encoding — the encoder doesn't see the other points.
- Aggregator: Mean pooling across all representations: r = (1/n) ∑ ri. This produces a single fixed-size vector summarizing the entire context set, regardless of how many points it contains.
- Decoder: An MLP that takes the aggregated representation r concatenated with a target input x* and outputs the parameters of a Gaussian distribution: μ* and σ*.
The training loss is the negative log-likelihood of the target points given the context-derived representation:
Loss = ∑t [(yt − μt)² / (2σt²) + log(σt)]
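As a quick standalone check of what this loss rewards: a confident correct prediction scores lower (better) than an unconfident correct one. The arrays below are made-up toy values, not outputs from any model in this post:

```python
import numpy as np

def gaussian_nll(y, mu, sigma):
    """Per-point Gaussian negative log-likelihood (constant term dropped)."""
    return (y - mu) ** 2 / (2 * sigma ** 2) + np.log(sigma)

y = np.array([0.0, 1.0])
mu = np.array([0.0, 1.0])  # perfect mean predictions
nll_tight = gaussian_nll(y, mu, np.full(2, 0.1)).mean()  # confident
nll_wide = gaussian_nll(y, mu, np.full(2, 1.0)).mean()   # unconfident
print(nll_tight, nll_wide)  # tight sigma gives the lower loss
```

The log(σ) term penalizes needless uncertainty, while the squared-error term penalizes overconfidence about a wrong mean — the balance is what trains calibrated σ.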
Here's a minimal CNP implementation that trains on sine functions with varying amplitude and phase:
```python
import numpy as np

class CNP:
    def __init__(self, h_dim=64):
        # Encoder: (x, y) -> r_i
        self.W1 = np.random.randn(2, h_dim) * 0.1
        self.b1 = np.zeros(h_dim)
        self.W2 = np.random.randn(h_dim, h_dim) * 0.1
        self.b2 = np.zeros(h_dim)
        # Decoder: (r, x*) -> (mu, log_sigma)
        self.W3 = np.random.randn(h_dim + 1, h_dim) * 0.1
        self.b3 = np.zeros(h_dim)
        self.W4 = np.random.randn(h_dim, 2) * 0.1
        self.b4 = np.zeros(2)

    def encode(self, x_ctx, y_ctx):
        """Encode each context pair, then mean-pool."""
        pairs = np.column_stack([x_ctx, y_ctx])       # (n, 2)
        h = np.maximum(0, pairs @ self.W1 + self.b1)  # ReLU
        r_all = np.maximum(0, h @ self.W2 + self.b2)  # (n, h_dim)
        return r_all.mean(axis=0)                     # mean pool

    def decode(self, r, x_target):
        """Predict mu, sigma at each target x."""
        inp = np.column_stack([np.tile(r, (len(x_target), 1)),
                               x_target])             # (m, h_dim+1)
        h = np.maximum(0, inp @ self.W3 + self.b3)
        out = h @ self.W4 + self.b4                   # (m, 2)
        mu = out[:, 0]
        sigma = np.exp(out[:, 1]) + 1e-4              # positive via exp
        return mu, sigma

    def predict(self, x_ctx, y_ctx, x_target):
        r = self.encode(x_ctx, y_ctx)
        return self.decode(r, x_target)

# Generate a random sine task: y = A*sin(x + phase)
rng = np.random.default_rng(42)
A, phase = rng.uniform(0.5, 2.0), rng.uniform(0, 2 * np.pi)
x_all = rng.uniform(-3, 3, size=(20, 1))
y_all = A * np.sin(x_all + phase)

# Split into 5 context + 15 target
x_ctx, y_ctx = x_all[:5], y_all[:5]
x_tgt, y_tgt = x_all[5:], y_all[5:]

cnp = CNP()
mu, sigma = cnp.predict(x_ctx, y_ctx, x_tgt)
print("Predictions at first 3 targets:")
for i in range(3):
    print(f"  x={x_tgt[i,0]:.2f} true={y_tgt[i,0]:.2f}"
          f" pred={mu[i]:.2f} +/- {sigma[i]:.2f}")
```
The output from an untrained CNP will be random, but the architecture is correct: encode each context pair, mean-pool, decode at target locations. Training adjusts the weights so that the model learns to interpolate through context points with appropriate uncertainty.
3. Why Permutation Invariance Matters
The context set is a set, not a sequence. If someone shows you the points {(0, 1), (2, 3), (4, 2)}, the prediction at x=3 should be the same regardless of whether the points were presented in that order or as {(4, 2), (0, 1), (2, 3)}. This is permutation invariance.
The Deep Sets theorem (Zaheer et al., 2017) tells us that any permutation-invariant function on sets can be decomposed as ρ(∑ φ(xi)) — encode each element independently with φ, sum (or mean) the results, then post-process with ρ. This is exactly the CNP's encoder-aggregator-decoder pattern.
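A toy φ/ρ pair with random weights makes the decomposition concrete; the shapes and weights here are arbitrary illustration choices, not anything from the CNP above:

```python
import numpy as np

rng = np.random.default_rng(0)
W_phi = rng.standard_normal((2, 8))  # phi: per-element encoder
W_rho = rng.standard_normal((8, 1))  # rho: post-pooling map

def set_fn(points):
    """rho(mean(phi(x_i))): permutation-invariant by construction."""
    phi = np.maximum(0, points @ W_phi)  # encode each element independently
    pooled = phi.mean(axis=0)            # order-independent pooling
    return pooled @ W_rho                # post-process the summary

S = np.array([[0., 1.], [2., 3.], [4., 2.]])
out1 = set_fn(S)
out2 = set_fn(S[[2, 0, 1]])  # same set, shuffled
print(np.allclose(out1, out2))  # True: order cannot affect the output
```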
Mean pooling is the simplest permutation-invariant aggregation. It guarantees that shuffling the context produces identical predictions:
```python
def verify_permutation_invariance(cnp, x_ctx, y_ctx, x_target):
    """Show that shuffling context gives identical predictions."""
    mu1, sigma1 = cnp.predict(x_ctx, y_ctx, x_target)
    # Shuffle context points
    perm = np.random.permutation(len(x_ctx))
    x_shuf, y_shuf = x_ctx[perm], y_ctx[perm]
    mu2, sigma2 = cnp.predict(x_shuf, y_shuf, x_target)
    print(f"Original order mu[:3]: {mu1[:3].round(4)}")
    print(f"Shuffled order mu[:3]: {mu2[:3].round(4)}")
    print(f"Max difference: {np.max(np.abs(mu1 - mu2)):.2e}")
    # Max difference: 0.00e+00

verify_permutation_invariance(cnp, x_ctx, y_ctx, x_tgt)
```
An RNN encoder, by contrast, would produce different hidden states for different orderings — violating the set semantics. This is why the encode-then-aggregate design is fundamental to Neural Processes.
4. The Underfitting Problem
The CNP has a critical limitation that becomes obvious when you look at its predictions closely: the mean prediction often doesn't pass through the context points. Even with perfectly noise-free observations, the CNP's predicted mean can miss the data it was given.
Why? Because the decoder produces each target prediction independently. It gets the same global representation r regardless of which target x* it's predicting at. There's no mechanism forcing the model to be consistent across predictions — and crucially, no mechanism forcing predictions at context locations to match the observed values.
This independence creates two problems:
- No coherent function samples. If you sample from the predicted Gaussian at 100 target points, each sample is drawn independently. The result is jagged noise, not a smooth function. A GP, by contrast, produces correlated samples that look like plausible functions.
- No global uncertainty. If the underlying function is globally shifted upward, all predictions should shift together. But the CNP has no way to represent this — it models each target as a separate, unrelated prediction.
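The jaggedness problem is easy to reproduce without any trained model: compare independent per-point draws against a correlated draw from an RBF covariance, standing in for CNP-style and GP-style sampling respectively. The mean and scale below are made-up stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3, 3, 100)
mu = np.sin(x)  # stand-in predictive mean
sigma = 0.3     # stand-in predictive std

# CNP-style: each point sampled independently -> jagged noise around mu
indep = mu + sigma * rng.standard_normal(100)

# GP-style: one correlated draw from an RBF covariance -> smooth function
K = sigma**2 * np.exp(-0.5 * (x[:, None] - x[None, :])**2)
corr = rng.multivariate_normal(mu, K + 1e-8 * np.eye(100))

rough = lambda s: np.abs(np.diff(s)).mean()  # mean step-to-step jump
print(f"independent roughness: {rough(indep):.3f}")
print(f"correlated  roughness: {rough(corr):.3f}")  # far smoother
```

The independent draw jumps by roughly σ between neighboring points; the correlated draw moves smoothly because nearby points share covariance.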
These limitations motivate the latent Neural Process.
5. The Latent Neural Process — Adding Global Uncertainty
The fix is elegant: add a global latent variable z that captures function-level properties. Instead of collapsing all context information into a deterministic representation r, the latent NP encodes context into a distribution over z — a Gaussian with learned mean μz and standard deviation σz. Sampling different values of z produces different plausible functions, all consistent with the context.
The architecture now has two paths:
- Deterministic path: Encoder → mean pool → r (same as CNP)
- Latent path: Encoder → mean pool → (μz, σz) → sample z
The decoder conditions on both: g(r, z, x*) → (μ*, σ*). The latent z carries global information (amplitude, frequency, smoothness) while the deterministic r carries local context details.
Training uses the Evidence Lower Bound (ELBO), the same variational objective used in VAEs:
ELBO = E_q(z|C,T)[log p(yT | xT, z, r)] − KL(q(z|C,T) ‖ q(z|C))
The first term says: "z should help predict the targets." The second term says: "the posterior over z when you see both context AND targets shouldn't be too different from the posterior using context alone" — which encourages the model to extract useful information from context without relying on peeking at targets. We use the reparameterization trick to backpropagate through the sampling step: z = μz + σz ⊙ ε, where ε ~ N(0, I).
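The KL term has a closed form when both distributions are diagonal Gaussians, as they are here. A standalone sketch (the example vectors are arbitrary):

```python
import numpy as np

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, diag sigma_q^2) || N(mu_p, diag sigma_p^2)), closed form."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
            - 0.5).sum()

mu_q, sigma_q = np.zeros(4), np.ones(4)
print(kl_diag_gaussians(mu_q, sigma_q, mu_q, sigma_q))        # 0.0: identical
print(kl_diag_gaussians(mu_q + 0.5, sigma_q, mu_q, sigma_q))  # > 0: shifted mean
```

During training, q(z|C,T) plays the role of q and q(z|C) the role of p, so the penalty grows as the target-informed posterior drifts from the context-only one.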
```python
class LatentNP:
    def __init__(self, h_dim=64, z_dim=16):
        self.h_dim, self.z_dim = h_dim, z_dim
        # Shared encoder: (x, y) -> hidden
        self.enc_W1 = np.random.randn(2, h_dim) * 0.1
        self.enc_b1 = np.zeros(h_dim)
        # Deterministic path: hidden -> r
        self.det_W = np.random.randn(h_dim, h_dim) * 0.1
        self.det_b = np.zeros(h_dim)
        # Latent path: hidden -> (mu_z, log_sigma_z)
        self.lat_W = np.random.randn(h_dim, z_dim * 2) * 0.1
        self.lat_b = np.zeros(z_dim * 2)
        # Decoder: (r, z, x*) -> (mu, log_sigma)
        self.dec_W1 = np.random.randn(h_dim + z_dim + 1, h_dim) * 0.1
        self.dec_b1 = np.zeros(h_dim)
        self.dec_W2 = np.random.randn(h_dim, 2) * 0.1
        self.dec_b2 = np.zeros(2)

    def encode_context(self, x_ctx, y_ctx):
        pairs = np.column_stack([x_ctx, y_ctx])
        h = np.maximum(0, pairs @ self.enc_W1 + self.enc_b1)
        # Deterministic representation
        r = np.maximum(0, h @ self.det_W + self.det_b).mean(axis=0)
        # Latent distribution parameters
        lat = (h @ self.lat_W + self.lat_b).mean(axis=0)
        mu_z = lat[:self.z_dim]
        sigma_z = np.exp(lat[self.z_dim:]) + 1e-4
        return r, mu_z, sigma_z

    def sample_z(self, mu_z, sigma_z, rng):
        """Reparameterization trick: z = mu + sigma * epsilon."""
        eps = rng.standard_normal(self.z_dim)
        return mu_z + sigma_z * eps

    def decode(self, r, z, x_target):
        n = len(x_target)
        inp = np.column_stack([np.tile(r, (n, 1)),
                               np.tile(z, (n, 1)), x_target])
        h = np.maximum(0, inp @ self.dec_W1 + self.dec_b1)
        out = h @ self.dec_W2 + self.dec_b2
        return out[:, 0], np.exp(out[:, 1]) + 1e-4

    def predict_samples(self, x_ctx, y_ctx, x_target, n_samples=5):
        """Draw multiple coherent function samples."""
        r, mu_z, sigma_z = self.encode_context(x_ctx, y_ctx)
        rng = np.random.default_rng(0)
        samples = []
        for _ in range(n_samples):
            z = self.sample_z(mu_z, sigma_z, rng)
            mu, _ = self.decode(r, z, x_target)
            samples.append(mu)
        return np.array(samples)  # (n_samples, n_targets)

lnp = LatentNP()
samples = lnp.predict_samples(x_ctx, y_ctx, x_tgt, n_samples=5)
print(f"5 function samples, each at {samples.shape[1]} targets")
print(f"Sample std across functions: {samples.std(axis=0).mean():.3f}")
# Different z values produce different but coherent predictions
```
The key difference from the CNP: each sample is a coherent function — smooth and globally consistent, not independently sampled noise. Different z values produce functions with different amplitudes, frequencies, or offsets, but each individual sample is a plausible function that respects the context.
6. The Attentive Neural Process
The CNP and latent NP both use mean pooling to summarize context. This means every context point contributes equally to the representation, regardless of the target query location. But when predicting at x* = 5.0, a context point at x = 4.9 should matter far more than one at x = 0.1.
The Attentive Neural Process (Kim et al., 2019) replaces mean pooling with cross-attention: each target query attends to all context points, weighting nearby or relevant ones more heavily. The math follows the standard attention mechanism from transformers:
attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
Queries come from target inputs, keys and values come from context representations. Each target gets its own query-specific context summary rather than a single global vector.
```python
class AttentiveNP:
    def __init__(self, h_dim=64):
        self.h_dim = h_dim
        # Encoder
        self.enc_W1 = np.random.randn(2, h_dim) * 0.1
        self.enc_b1 = np.zeros(h_dim)
        # Attention projections (single-head for simplicity)
        self.W_Q = np.random.randn(1, h_dim) * 0.1      # query from x*
        self.W_K = np.random.randn(h_dim, h_dim) * 0.1  # key from r_i
        self.W_V = np.random.randn(h_dim, h_dim) * 0.1  # value from r_i
        # Decoder
        self.dec_W1 = np.random.randn(h_dim + 1, h_dim) * 0.1
        self.dec_b1 = np.zeros(h_dim)
        self.dec_W2 = np.random.randn(h_dim, 2) * 0.1
        self.dec_b2 = np.zeros(2)

    def cross_attend(self, x_target, ctx_repr):
        """Each target attends to context representations."""
        Q = x_target @ self.W_Q                 # (m, h_dim)
        K = ctx_repr @ self.W_K                 # (n, h_dim)
        V = ctx_repr @ self.W_V                 # (n, h_dim)
        # Scaled dot-product attention (scale by key dimension)
        scores = Q @ K.T / np.sqrt(self.h_dim)  # (m, n)
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)  # softmax
        return weights @ V                      # (m, h_dim)

    def predict(self, x_ctx, y_ctx, x_target):
        pairs = np.column_stack([x_ctx, y_ctx])
        ctx_repr = np.maximum(0, pairs @ self.enc_W1 + self.enc_b1)
        attended = self.cross_attend(x_target, ctx_repr)  # (m, h_dim)
        inp = np.column_stack([attended, x_target])
        h = np.maximum(0, inp @ self.dec_W1 + self.dec_b1)
        out = h @ self.dec_W2 + self.dec_b2
        return out[:, 0], np.exp(out[:, 1]) + 1e-4

anp = AttentiveNP()
mu, sigma = anp.predict(x_ctx, y_ctx, x_tgt)
print("ANP predictions at first 3 targets:")
for i in range(3):
    print(f"  x={x_tgt[i,0]:.2f} pred={mu[i]:.2f} +/- {sigma[i]:.2f}")
```
The cross-attention mechanism gives the ANP a crucial advantage: predictions near context clusters are sharp and accurate (high attention weights on nearby context), while predictions far from any context point have wide uncertainty (attention is spread diffusely). This is exactly the behavior we want — and exactly what GPs give naturally through their posterior variance.
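The concentration effect can be seen from the softmax weights alone. The toy below scores context points by negative squared distance to the query — a hand-picked stand-in for the ANP's learned dot-product scores, used only to illustrate how the weights behave:

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

x_ctx = np.array([0.1, 2.0, 4.9])  # context locations
queries = np.array([5.0, 0.0])     # target locations

# Score = negative squared distance (illustrative choice, not learned)
scores = -(queries[:, None] - x_ctx[None, :]) ** 2
weights = softmax(scores)
print(weights.round(3))
# The query at x*=5.0 puts nearly all weight on the context point at 4.9;
# the query at x*=0.0 concentrates on the point at 0.1.
```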
7. Training Neural Processes — The Meta-Learning Loop
Neural Processes are trained with episodic meta-learning. Each training episode presents the model with a different function — one might be a gentle sine wave, the next a steep quadratic, the next a wiggly polynomial. For each function, we split the observed points into context (what the model sees) and target (what it must predict). Over thousands of episodes, the model learns not to memorize any single function, but to rapidly adapt to new functions from a handful of context points.
```python
def train_cnp(cnp, n_episodes=2000, lr=1e-3):
    """Episodic meta-learning training loop."""
    rng = np.random.default_rng(42)
    for ep in range(n_episodes):
        # Step 1: Sample a random function (sine with random params)
        amp = rng.uniform(0.5, 2.0)
        phase = rng.uniform(0, 2 * np.pi)
        freq = rng.uniform(0.5, 2.0)
        f = lambda x: amp * np.sin(freq * x + phase)
        # Step 2: Sample points from this function
        x_all = rng.uniform(-4, 4, size=(20, 1))
        y_all = f(x_all) + rng.normal(0, 0.05, size=x_all.shape)
        # Step 3: Random context/target split, 3 to 15 context points
        # (rng.integers excludes the upper bound, hence 16)
        n_ctx = rng.integers(3, 16)
        idx = rng.permutation(20)
        x_ctx, y_ctx = x_all[idx[:n_ctx]], y_all[idx[:n_ctx]]
        x_tgt, y_tgt = x_all[idx[n_ctx:]], y_all[idx[n_ctx:]]
        # Step 4: Forward pass + NLL loss
        mu, sigma = cnp.predict(x_ctx, y_ctx, x_tgt)
        nll = ((y_tgt.ravel() - mu)**2 / (2 * sigma**2)
               + np.log(sigma)).mean()
        # Step 5: Update (in practice, use autograd + Adam)
        # Here we just track the loss for demonstration
        if ep % 500 == 0:
            print(f"Episode {ep:4d} NLL: {nll:.3f}")

train_cnp(cnp)
```
The key insight: by varying the number of context points randomly (between 3 and 15 here), the model learns to make reasonable predictions with little data AND to sharpen its predictions when given more data — just like a GP that becomes more confident as it observes more points. The variable-size context also prevents the model from overfitting to any particular context size.
8. NPs vs GPs — When to Use Which
| Feature | GP | CNP | Latent NP | ANP |
|---|---|---|---|---|
| Test-time cost | O(n³) | O(n) | O(n) | O(n·m) |
| Kernels | Hand-designed | Learned | Learned | Learned |
| Uncertainty | Exact | Approximate | Approximate | Approximate |
| Coherent samples | Yes | No | Yes | Yes |
| Passes through context | Yes | Often not | Better | Best |
Here's a side-by-side comparison, using a GP with an RBF kernel alongside a trained NP on the same context set:
```python
import time

def gp_predict(x_ctx, y_ctx, x_test, l=1.0, sigma_n=0.1):
    """Standard GP posterior with RBF kernel."""
    def rbf(x1, x2):
        return np.exp(-0.5 * ((x1 - x2.T) / l) ** 2)
    K = rbf(x_ctx, x_ctx) + sigma_n**2 * np.eye(len(x_ctx))
    K_s = rbf(x_ctx, x_test)
    K_ss = rbf(x_test, x_test)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_ctx))
    mu = K_s.T @ alpha
    v = np.linalg.solve(L, K_s)
    sigma = np.sqrt(np.diag(K_ss - v.T @ v).clip(0))
    return mu.ravel(), sigma

x_test = np.linspace(-4, 4, 200).reshape(-1, 1)

# GP inference (exact, O(n^3))
t0 = time.perf_counter()
gp_mu, gp_sigma = gp_predict(x_ctx, y_ctx, x_test)
gp_time = (time.perf_counter() - t0) * 1000

# NP inference (approximate, O(n))
t0 = time.perf_counter()
np_mu, np_sigma = cnp.predict(x_ctx, y_ctx, x_test)
np_time = (time.perf_counter() - t0) * 1000

print(f"GP: {gp_time:.1f}ms | NP: {np_time:.1f}ms")
print(f"GP uncertainty range: [{gp_sigma.min():.3f}, {gp_sigma.max():.3f}]")
print(f"NP uncertainty range: [{np_sigma.min():.3f}, {np_sigma.max():.3f}]")
```
Use GPs when you have a small dataset (<1,000 points), know what kernel to use, and need exact uncertainty. Use NPs when you need fast adaptation to new tasks, have complex structured inputs (images, graphs), or need to scale to large context sets. NPs shine in scenarios where you face many related but different prediction tasks — weather at different stations, drug response across patients, completion of partially observed images.
9. Conclusion
Neural Processes sit at a beautiful intersection: the principled uncertainty of Gaussian Processes meets the scalability and learned representations of neural networks, wrapped in a meta-learning training procedure. The progression tells a clean story — CNPs show that the basic idea works but suffer from independent predictions and underfitting; latent NPs fix coherence by adding a global latent variable; attentive NPs fix underfitting by replacing pooling with cross-attention.
The NP family continues to evolve. Convolutional CNPs (Gordon et al., 2020) add translation equivariance for spatial data. Transformer Neural Processes replace the entire architecture with transformer blocks. Gaussian Neural Processes enforce exact GP behavior in the limit of infinite context. Each variant pushes the boundary of what "learning to be a GP" can mean.
The core insight, though, remains the same: if you train a neural network to predict from context sets across thousands of different functions, it learns to do something remarkable — it learns the process of going from observations to uncertainty-aware predictions. Not a single function, but the mapping itself.
References & Further Reading
- Garnelo et al. — Conditional Neural Processes (ICML 2018) — the original CNP: encoder-aggregator-decoder for set-conditioned prediction
- Garnelo et al. — Neural Processes (ICML Workshop 2018) — adds latent variables for coherent function-level uncertainty
- Kim et al. — Attentive Neural Processes (ICLR 2019) — cross-attention for query-specific context aggregation
- Zaheer et al. — Deep Sets (NeurIPS 2017) — theoretical foundation for permutation-invariant aggregation
- Dubois et al. — Neural Process Family Tutorial — comprehensive guide with code and visualizations
- Gordon et al. — Convolutional Conditional Neural Processes (ICLR 2020) — translation equivariance for spatial data
- Foong et al. — Meta-Learning Stationary Stochastic Process Prediction — analyzing uncertainty calibration in NPs
- Rasmussen & Williams — Gaussian Processes for Machine Learning (2006) — the GP textbook for comparison context