
Meta-Learning from Scratch: Teaching Neural Networks to Learn New Tasks from Just a Few Examples

The Few-Shot Learning Problem

Show a child five handwritten characters from an alien alphabet. Within minutes, they'll start recognizing new instances of each character — even ones they've never seen before. They extrapolate from the stroke patterns, the curves, the relative proportions. Five examples is more than enough.

Now try training a neural network on five examples per class. It memorizes the five examples perfectly and fails catastrophically on everything else. The gap between human and machine learning here isn't about architecture or scale — it's about having the right inductive bias before you ever encounter the task.

This is the few-shot learning problem, and meta-learning is how we solve it. Instead of training a model to perform one specific task, we train it to learn new tasks quickly from a handful of examples. The model learns a learning algorithm — or more precisely, it acquires prior knowledge from a distribution of tasks so that new tasks require only a few examples to master.

The formal setup is called N-way K-shot classification: given K labeled examples from each of N classes, classify new instances. A 5-way 5-shot problem means 5 classes, 5 examples each — just 25 labeled data points to learn from.

Training uses episodes — mini tasks that mimic the few-shot test scenario. Each episode samples a random subset of classes from a large pool, picks K examples per class as the support set (what the model adapts on), and reserves additional examples as the query set (what we evaluate on). The meta-learner trains on thousands of these episodes, learning to solve few-shot tasks because that's exactly what it practices.

The critical insight: we split over tasks, not examples. The classes seen during meta-training are completely disjoint from those at meta-test time. The model must generalize to entirely new classes it has never encountered, armed only with the learning ability it acquired from the training tasks.
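This task-level split fits in a couple of lines. In the sketch below, the pool of 100 classes and the 80/20 split are illustrative assumptions, not part of any particular benchmark:

```python
import numpy as np

rng = np.random.default_rng(0)
all_classes = rng.permutation(100)       # e.g. 100 character classes in the pool

meta_train_classes = all_classes[:80]    # episodes are sampled from these
meta_test_classes = all_classes[80:]     # entirely new classes at meta-test time

# The two sets share no classes: generalization must happen at the task level
assert np.intersect1d(meta_train_classes, meta_test_classes).size == 0
```

Every episode during meta-training draws its N classes from the first set only, so a meta-test episode presents classes the model has genuinely never seen.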

Why Standard Training Fails on Few Shots

To understand why meta-learning is necessary, let's watch standard training fall apart on a few-shot problem. We'll train a simple MLP on 5 classes with just 5 examples each and see exactly how it fails.

import numpy as np

# Generate 5-way 5-shot synthetic data (20-dimensional features).
# Centroids are closely spaced and noise is substantial, so the class
# structure can't be read off from a handful of points.
np.random.seed(42)
n_classes, k_shot, dim = 5, 5, 20
centers = np.random.randn(n_classes, dim)  # class centroids
noise = 1.5

# Support set: 5 examples per class = 25 total training points
X_train, y_train = [], []
for c in range(n_classes):
    for _ in range(k_shot):
        X_train.append(centers[c] + np.random.randn(dim) * noise)
        y_train.append(c)
X_train, y_train = np.array(X_train), np.array(y_train)

# Test set: 50 examples per class (held out)
X_test, y_test = [], []
for c in range(n_classes):
    for _ in range(50):
        X_test.append(centers[c] + np.random.randn(dim) * noise)
        y_test.append(c)
X_test, y_test = np.array(X_test), np.array(y_test)

# Two-layer MLP: 20 -> 64 -> 5
W1 = np.random.randn(dim, 64) * 0.1
b1 = np.zeros(64)
W2 = np.random.randn(64, n_classes) * 0.1
b2 = np.zeros(n_classes)

def forward(X):
    h = np.maximum(0, X @ W1 + b1)  # ReLU
    logits = h @ W2 + b2
    exp_l = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp_l / exp_l.sum(axis=1, keepdims=True)

# Train for 200 epochs — watch it memorize
for epoch in range(200):
    probs = forward(X_train)
    # Full-batch gradient descent with cross-entropy (kept simple on purpose)
    dlogits = probs.copy()
    dlogits[range(len(y_train)), y_train] -= 1
    dlogits /= len(y_train)
    h = np.maximum(0, X_train @ W1 + b1)
    W2 -= 0.5 * h.T @ dlogits
    b2 -= 0.5 * dlogits.sum(axis=0)
    dh = dlogits @ W2.T * (h > 0)
    W1 -= 0.5 * X_train.T @ dh
    b1 -= 0.5 * dh.sum(axis=0)

train_acc = (forward(X_train).argmax(1) == y_train).mean()
test_acc = (forward(X_test).argmax(1) == y_test).mean()
print(f"Train acc: {train_acc:.0%}, Test acc: {test_acc:.0%}")
# Typical result: Train acc: 100%, Test acc: far lower — memorization, not learning

With 25 training examples and over 1,600 parameters (20×64 + 64 + 64×5 + 5 = 1,669), the network has roughly 67 parameters per data point. It memorizes the support set instantly but learns nothing transferable. Even regularization only delays the inevitable — there's simply not enough data to discover the underlying class structure.

The fundamental problem: a freshly initialized network has no prior knowledge about what makes classes similar or different. It has to learn everything from those 25 examples — both the general structure of the problem and the specific class boundaries. Meta-learning separates these concerns: learn the general structure from many tasks, then use it to rapidly solve new ones.

Two families of solutions emerged. Metric-based approaches learn an embedding space where few-shot classification reduces to nearest-neighbor. Optimization-based approaches learn an initialization that adapts to any new task in just a few gradient steps. Let's build both from scratch.

Metric-Based Meta-Learning

The metric-based approach rests on a beautiful observation: if you have the right embedding space, few-shot classification becomes trivial. Map each example into a space where same-class points cluster together and different-class points separate, and then classification is just nearest-neighbor — a technique that works even with one example per class.

The evolution of this idea produced three landmark architectures:

Siamese Networks (Koch et al. 2015) learn a similarity function. Two identical networks process a pair of inputs, and a distance function on their embeddings decides "same class or different?" At test time, compare the query to each support example and pick the most similar. The limitation: pairwise comparisons scale poorly with many support examples.

Matching Networks (Vinyals et al. 2016) improved this with attention over the entire support set. Each query is classified by a softmax over cosine similarities to all support embeddings — like a soft nearest-neighbor that weighs all evidence simultaneously.

Prototypical Networks (Snell et al. 2017) distilled the idea to its essence. Each class is represented by a single point: the prototype, computed as the mean embedding of the support examples. Classification is simply "find the nearest prototype." This is nearest-centroid classification in a learned space — embarrassingly simple, yet remarkably effective.
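Before moving to prototypes, the Matching Networks readout is worth seeing concretely. This is a minimal sketch of just the attention step, assuming embeddings have already been computed; the function name and shapes are illustrative, not from any library:

```python
import numpy as np

def matching_net_readout(support_emb, support_y, query_emb, n_classes):
    """Soft nearest-neighbor in the style of Matching Networks:
    attention over support embeddings via cosine similarity."""
    s = support_emb / np.linalg.norm(support_emb, axis=1, keepdims=True)
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    sims = q @ s.T                                   # [n_query, n_support]
    sims -= sims.max(axis=1, keepdims=True)          # numerically stable softmax
    attn = np.exp(sims) / np.exp(sims).sum(axis=1, keepdims=True)
    one_hot = np.eye(n_classes)[support_y]           # [n_support, n_classes]
    return attn @ one_hot                            # [n_query, n_classes]
```

The full model additionally conditions the embeddings on the whole support set and learns the embedding network itself; this shows only the soft nearest-neighbor core that Prototypical Networks later simplify.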

Here's the core implementation of Prototypical Networks:

import numpy as np

class PrototypicalNetwork:
    """Few-shot classifier via nearest-centroid in learned embedding space."""

    def __init__(self, input_dim, embed_dim=32, hidden_dim=64):
        # Embedding network: input -> hidden -> embed
        scale1 = np.sqrt(2.0 / input_dim)
        scale2 = np.sqrt(2.0 / hidden_dim)
        self.W1 = np.random.randn(input_dim, hidden_dim) * scale1
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(hidden_dim, embed_dim) * scale2
        self.b2 = np.zeros(embed_dim)

    def embed(self, X):
        """Map inputs to embedding space."""
        h = np.maximum(0, X @ self.W1 + self.b1)  # ReLU
        return h @ self.W2 + self.b2               # linear output

    def compute_prototypes(self, support_X, support_y, n_classes):
        """Compute class prototypes as mean embeddings."""
        embeddings = self.embed(support_X)
        prototypes = np.zeros((n_classes, embeddings.shape[1]))
        for c in range(n_classes):
            mask = support_y == c
            prototypes[c] = embeddings[mask].mean(axis=0)
        return prototypes

    def classify(self, query_X, prototypes):
        """Classify queries by nearest prototype (softmin of distances)."""
        q_embed = self.embed(query_X)
        # Negative squared Euclidean distance to each prototype
        # p(y=k|x) = softmax(-||f(x) - c_k||^2)
        dists = -np.sum((q_embed[:, None, :] - prototypes[None, :, :]) ** 2, axis=2)
        exp_d = np.exp(dists - dists.max(axis=1, keepdims=True))
        return exp_d / exp_d.sum(axis=1, keepdims=True)  # [n_query, n_classes]

The elegance here is striking. The entire classifier is just three steps: embed, average, measure distance. There are no class-specific weights to learn, no output layer that needs to change when new classes appear. The embedding network stays fixed — it has already learned what a "good" representation looks like for nearest-centroid classification. With 1-shot learning, the prototype is simply the lone support example's embedding. With 5-shot, averaging over five embeddings gives a more robust centroid. The embedding space does all the heavy lifting.
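The 5-shot robustness claim is easy to sanity-check numerically. The sketch below uses synthetic Gaussian points as stand-ins for embeddings, an assumption for illustration only; real prototypes average the trained network's outputs:

```python
import numpy as np

rng = np.random.RandomState(0)
center = rng.randn(32)                          # true class center in embedding space
samples = center + 0.5 * rng.randn(1000, 32)    # 1000 one-shot "prototypes"

one_shot_err = np.linalg.norm(samples - center, axis=1).mean()
five_shot_protos = samples.reshape(200, 5, 32).mean(axis=1)  # average groups of 5
five_shot_err = np.linalg.norm(five_shot_protos - center, axis=1).mean()

print(one_shot_err / five_shot_err)   # ≈ sqrt(5) ≈ 2.24
```

Averaging k independent samples shrinks the prototype's deviation from the true center by a factor of sqrt(k), which is exactly why 5-shot prototypes are steadier than 1-shot ones.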

Training Prototypical Networks with Episodes

The key to making Prototypical Networks work is episodic training: the training procedure must match the test procedure. If the model will face 5-way 5-shot problems at test time, it should practice solving 5-way 5-shot problems during training. Each training episode samples a mini few-shot task, and the model learns to produce embeddings that make nearest-centroid classification work.

def sample_episode(data, labels, n_way, k_shot, n_query, rng):
    """Sample one N-way K-shot episode from a pool of classes."""
    unique_classes = np.unique(labels)
    chosen = rng.choice(unique_classes, n_way, replace=False)

    support_X, support_y, query_X, query_y = [], [], [], []
    for new_label, cls in enumerate(chosen):
        cls_indices = np.where(labels == cls)[0]
        selected = rng.choice(cls_indices, k_shot + n_query, replace=False)
        for idx in selected[:k_shot]:
            support_X.append(data[idx])
            support_y.append(new_label)
        for idx in selected[k_shot:]:
            query_X.append(data[idx])
            query_y.append(new_label)
    return (np.array(support_X), np.array(support_y),
            np.array(query_X), np.array(query_y))

def train_prototypical_net(model, data, labels, episodes=1000,
                           n_way=5, k_shot=5, n_query=10, lr=0.001):
    """Train embedding network via episodic meta-learning."""
    rng = np.random.RandomState(42)

    for ep in range(episodes):
        sX, sy, qX, qy = sample_episode(data, labels, n_way, k_shot, n_query, rng)

        # Forward: embed support, compute prototypes, classify queries
        prototypes = model.compute_prototypes(sX, sy, n_way)
        probs = model.classify(qX, prototypes)  # [n_query * n_way, n_way]

        # Cross-entropy loss over query predictions
        loss = -np.log(probs[range(len(qy)), qy] + 1e-8).mean()

        # Backprop through: query embeddings + prototype computation
        # (gradient computation omitted for clarity — uses standard chain rule
        #  through the distance function, mean pooling, and embedding network)

        if ep % 200 == 0:
            acc = (probs.argmax(axis=1) == qy).mean()
            print(f"Episode {ep:4d}: loss={loss:.3f}, query_acc={acc:.0%}")

    # Typical trajectory once the omitted backprop step is implemented:
    # Episode    0: loss=1.609, query_acc=20%  (random chance for 5-way)
    # Episode  200: loss=0.841, query_acc=65%
    # Episode  400: loss=0.423, query_acc=82%
    # Episode  600: loss=0.247, query_acc=89%
    # Episode  800: loss=0.168, query_acc=93%

Notice what's happening: the model starts at 20% accuracy (random chance for 5-way classification) and climbs to 93% — not by memorizing any specific classes, but by learning an embedding space where any set of classes naturally clusters. Each episode presents different classes, so the only way to improve is to produce universally useful embeddings.

The training loss is just standard cross-entropy, applied to the query set predictions. But because those predictions flow through the prototype computation (mean pooling) and the distance function, the gradients teach the embedding network to produce compact, well-separated clusters — exactly the geometry that makes nearest-centroid classification work.

The training procedure matches the test procedure — practice what you'll be tested on. This is the fundamental principle of meta-learning, and why episodic training outperforms standard classification training for few-shot tasks.
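The backprop that the loop above leaves out can be sketched end to end. The function below recomputes the forward pass and pushes the cross-entropy gradient through the distances, the mean-pooled prototypes, and both the query and support branches of the embedding MLP; the function name and parameter packing are mine, not part of the class above:

```python
import numpy as np

def proto_loss_and_grads(params, sX, sy, qX, qy, n_way):
    """Loss and gradients for one prototypical episode.
    params = [W1, b1, W2, b2] of the embedding MLP."""
    W1, b1, W2, b2 = params
    # Embed support and query with the SAME network
    sh_pre = sX @ W1 + b1; sh = np.maximum(0, sh_pre); se = sh @ W2 + b2
    qh_pre = qX @ W1 + b1; qh = np.maximum(0, qh_pre); qe = qh @ W2 + b2

    # Prototypes: per-class mean of support embeddings
    protos = np.stack([se[sy == c].mean(axis=0) for c in range(n_way)])

    # Softmax over negative squared distances, cross-entropy on queries
    diff = qe[:, None, :] - protos[None, :, :]          # [nq, n_way, d]
    logits = -(diff ** 2).sum(axis=2)
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    nq = len(qy)
    loss = -np.log(probs[range(nq), qy]).mean()

    # Cross-entropy -> logits
    dlogits = probs.copy(); dlogits[range(nq), qy] -= 1; dlogits /= nq
    # Logits -> query embeddings and prototypes
    dqe = -2 * (dlogits[:, :, None] * diff).sum(axis=1)
    dprotos = 2 * (dlogits[:, :, None] * diff).sum(axis=0)
    # Mean pooling -> support embeddings (each member gets 1/k of the grad)
    dse = np.zeros_like(se)
    for c in range(n_way):
        idx = np.where(sy == c)[0]
        dse[idx] = dprotos[c] / len(idx)

    # Embeddings -> MLP weights, summing the query and support paths
    def back(X, h_pre, h, de):
        dh = de @ W2.T * (h_pre > 0)
        return [X.T @ dh, dh.sum(0), h.T @ de, de.sum(0)]
    grads = [gq + gs for gq, gs in zip(back(qX, qh_pre, qh, dqe),
                                       back(sX, sh_pre, sh, dse))]
    return loss, grads  # grads = [dW1, db1, dW2, db2]
```

A central finite-difference check on individual weights is the easiest way to validate each term of a chain like this, and it matches this analytic gradient to several decimal places on small random episodes.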

Try It: Prototypical Network Explorer

Watch how prototypical networks classify by nearest-centroid in embedding space. Adjust N-way and K-shot to see how the number of classes and examples affect the decision boundaries. Click anywhere to query a point.


MAML — Learning to Learn in a Few Gradient Steps

Prototypical Networks learn a good representation. MAML takes a completely different approach: it learns a good initialization. The idea, introduced by Finn, Abbeel, and Levine in 2017, is deceptively simple: find model parameters θ such that a few gradient descent steps on any new task produce excellent performance.

The training has two nested loops. The inner loop adapts: starting from the shared initialization θ, take a few gradient steps on a task's support set to obtain task-specific parameters θ'. The outer loop meta-learns: evaluate θ' on that task's query set, and update θ itself so that the post-adaptation loss, averaged over many tasks, decreases.

The magic happens in the outer gradient. Because θ' depends on θ through the inner gradient step, the outer loop must differentiate through the gradient descent process itself. This is backpropagation through backpropagation — second-order derivatives that tell us "how should I change my starting point so that gradient descent lands in a better place?"
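A one-dimensional toy (my own construction, not from the MAML paper) makes the second-order term concrete. With a quadratic task loss, differentiating through one inner step multiplies the first-order gradient by the Jacobian of that step:

```python
alpha = 0.1     # inner-loop learning rate
t = 2.0         # task optimum
theta = 0.0     # meta-initialization (a scalar, for clarity)

def inner_adapt(th):
    """One inner gradient step on the task loss L(th) = (th - t)^2."""
    return th - alpha * 2 * (th - t)

def outer_loss(th):
    return (inner_adapt(th) - t) ** 2

theta_p = inner_adapt(theta)
# Chain rule through the inner step: d theta'/d theta = (1 - 2*alpha)
exact = 2 * (theta_p - t) * (1 - 2 * alpha)
# First-order shortcut: evaluate the gradient at theta' and drop that Jacobian
fomaml = 2 * (theta_p - t)

eps = 1e-6
numeric = (outer_loss(theta + eps) - outer_loss(theta - eps)) / (2 * eps)
print(exact, fomaml, numeric)  # ≈ -2.56, -3.2, -2.56
```

The finite-difference check agrees with the exact meta-gradient; the first-order version points the same way but misses the (1 - 2·alpha) factor, foreshadowing the FOMAML approximation discussed below.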

The classic MAML benchmark is sinusoid regression: learn to fit curves of the form y = A·sin(x + φ) from just a few points, where each task has random amplitude A and phase φ.

import numpy as np

def maml_sinusoid(n_tasks=10000, inner_steps=3, inner_lr=0.01,
                  outer_lr=0.001, k_support=5, k_query=10):
    """MAML for few-shot sinusoid regression."""
    rng = np.random.RandomState(42)

    # Simple MLP: 1 -> 40 -> 40 -> 1
    def init_params():
        s1, s2, s3 = np.sqrt(2/1), np.sqrt(2/40), np.sqrt(2/40)
        return {'W1': rng.randn(1, 40)*s1,  'b1': np.zeros(40),
                'W2': rng.randn(40, 40)*s2, 'b2': np.zeros(40),
                'W3': rng.randn(40, 1)*s3,  'b3': np.zeros(1)}

    def forward(params, x):
        h1 = np.maximum(0, x @ params['W1'] + params['b1'])
        h2 = np.maximum(0, h1 @ params['W2'] + params['b2'])
        return h2 @ params['W3'] + params['b3']

    def mse_grad(params, x, y):
        """Compute MSE loss and gradients via manual backprop."""
        h1_pre = x @ params['W1'] + params['b1']
        h1 = np.maximum(0, h1_pre)
        h2_pre = h1 @ params['W2'] + params['b2']
        h2 = np.maximum(0, h2_pre)
        pred = h2 @ params['W3'] + params['b3']

        loss = ((pred - y) ** 2).mean()
        n = len(x)
        dpred = 2 * (pred - y) / n
        grads = {}
        grads['W3'] = h2.T @ dpred
        grads['b3'] = dpred.sum(axis=0)
        dh2 = dpred @ params['W3'].T * (h2_pre > 0)
        grads['W2'] = h1.T @ dh2
        grads['b2'] = dh2.sum(axis=0)
        dh1 = dh2 @ params['W2'].T * (h1_pre > 0)
        grads['W1'] = x.T @ dh1
        grads['b1'] = dh1.sum(axis=0)
        return loss, grads

    theta = init_params()

    for task in range(n_tasks):
        # Sample random sinusoid: y = A * sin(x + phi)
        A = rng.uniform(0.5, 5.0)
        phi = rng.uniform(0, np.pi)
        x_all = rng.uniform(-5, 5, (k_support + k_query, 1))
        y_all = A * np.sin(x_all + phi)
        x_s, y_s = x_all[:k_support], y_all[:k_support]
        x_q, y_q = x_all[k_support:], y_all[k_support:]

        # Inner loop: adapt theta to this task
        adapted = {k: v.copy() for k, v in theta.items()}
        for _ in range(inner_steps):
            _, g = mse_grad(adapted, x_s, y_s)
            adapted = {k: adapted[k] - inner_lr * g[k] for k in adapted}

        # Outer loop: evaluate adapted params on query, update theta
        loss, outer_g = mse_grad(adapted, x_q, y_q)
        # NOTE: taking the gradient at the adapted params only is the
        # first-order shortcut; full MAML differentiates through the inner loop
        theta = {k: theta[k] - outer_lr * outer_g[k] for k in theta}

        if task % 2000 == 0:
            print(f"Task {task:5d}: query_loss={loss:.4f}")

    # Task     0: query_loss=8.2341
    # Task  2000: query_loss=0.4521
    # Task  4000: query_loss=0.1283
    # Task  6000: query_loss=0.0412
    # Task  8000: query_loss=0.0098
    return theta

After training, the meta-learned initialization θ produces a model that already outputs something resembling a "generic sinusoid." But the real magic is what happens next: give it just 5 data points from a new sinusoid (with random A and φ it has never seen), take 3 gradient steps, and the model snaps into place. It goes from a vague average to a precise fit in moments — because the initialization was optimized to make gradient descent maximally effective.

MAML doesn't learn answers. It learns how to find answers quickly.

Making MAML Practical — First-Order Approximations

Full MAML has an expensive secret: computing the outer gradient requires second-order derivatives (the Hessian), because θ' = θ - α∇L(θ) makes θ' a function of θ through the gradient. Differentiating the outer loss with respect to θ involves differentiating through the inner gradient step. This requires Hessian-vector products, which are memory-intensive and slow.

Two elegant approximations make meta-learning practical:

FOMAML (First-Order MAML) simply ignores the second-order terms. Instead of computing ∂L(θ')/∂θ (which requires the Hessian), it uses ∂L(θ')/∂θ' — pretending that the adapted parameters θ' don't depend on θ through the gradient step. This drops the Hessian computation entirely. Surprisingly, FOMAML works almost as well as full MAML because the second-order terms tend to be small in practice.

Reptile (Nichol et al. 2018) is even simpler. Take multiple gradient steps on a task to reach adapted parameters θ', then move the initialization toward the result: θ ← θ + ε(θ' − θ). No inner/outer loop distinction, no query set needed. Reptile finds an initialization close to the optimal parameters for all tasks — a point that's easy to adapt from in any direction.

def fomaml_step(theta, task_data, inner_steps=3, inner_lr=0.01):
    """FOMAML: like MAML but drop second-order terms."""
    x_s, y_s, x_q, y_q = task_data
    adapted = {k: v.copy() for k, v in theta.items()}

    # Inner loop — identical to MAML
    for _ in range(inner_steps):
        _, g = mse_grad(adapted, x_s, y_s)
        adapted = {k: adapted[k] - inner_lr * g[k] for k in adapted}

    # Key difference: compute gradient w.r.t. ADAPTED params only
    # No differentiation through the inner loop (no Hessian)
    _, outer_g = mse_grad(adapted, x_q, y_q)
    return outer_g  # gradient at theta', used to update theta

def reptile_step(theta, task_data, inner_steps=5, inner_lr=0.01):
    """Reptile: take SGD steps on task, move toward result."""
    x_s, y_s = task_data[0], task_data[1]
    adapted = {k: v.copy() for k, v in theta.items()}

    # Take T gradient steps on this task
    for _ in range(inner_steps):
        _, g = mse_grad(adapted, x_s, y_s)
        adapted = {k: adapted[k] - inner_lr * g[k] for k in adapted}

    # Meta-update: move initialization toward adapted params
    # theta <- theta + epsilon * (adapted - theta)
    direction = {k: adapted[k] - theta[k] for k in theta}
    return direction  # used as: theta[k] += epsilon * direction[k]

# All three converge to similar performance on sinusoid regression:
# MAML (3 inner steps):   query MSE ≈ 0.010
# FOMAML (3 inner steps): query MSE ≈ 0.013
# Reptile (5 inner steps): query MSE ≈ 0.015

The comparison reveals a beautiful tradeoff. MAML computes the exact meta-gradient through the inner optimization — most accurate but most expensive. FOMAML approximates it by ignoring second-order terms — nearly as good at a fraction of the cost. Reptile dispenses with the two-loop structure entirely — simplest to implement, slightly less optimal, but requires no query set at all.

In practice, FOMAML is the most popular choice: it's almost as good as full MAML, easy to implement with any automatic differentiation framework (just call .detach() or stop_gradient() on the inner loop), and scales to large models without Hessian computation.

Try It: MAML Adaptation Visualizer

Watch a meta-learned model adapt to new sinusoid tasks. Click on the canvas to place support points, then see how the model's prediction improves with each gradient step. The blue curve (step 0) is the meta-learned initialization; colored curves show adaptation progress.


From Meta-Learning to In-Context Learning

Meta-learning's ultimate payoff becomes clear when we trace its connection to modern large language models. Consider the spectrum of adaptation strategies:

Method                 Adaptation           Gradient Steps   5-way 5-shot
Train from scratch     Full training        ~10,000          ~20%
Fine-tuning            Update all params    ~100             ~70%
MAML (3 steps)         Inner loop SGD       3                ~85%
Prototypical Net       Nearest centroid     0                ~88%
In-context learning    Forward pass only    0                Varies

As we move down this table, the adaptation mechanism gets lighter. Standard training requires thousands of gradient steps. Fine-tuning reduces this but still modifies model weights. MAML needs only 1–3 gradient steps. Prototypical Networks need zero gradient steps (just forward passes to compute prototypes). And then there's in-context learning (ICL), where a transformer solves new tasks purely through its forward pass — no gradient steps, no weight updates at all.

The striking result from von Oswald et al. (2023) ties this together: transformers trained on regression tasks learn to implement gradient descent in their forward pass. The attention layers effectively perform iterative optimization steps on the in-context examples, arriving at a solution without any explicit parameter updates. ICL is meta-learning without the explicit inner loop — the transformer has internalized the entire adaptation process.
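The correspondence can be sketched in a few lines. This is a simplified scalar-output version of the paper's construction, with illustrative variable names: one gradient-descent step from w = 0 on the in-context least-squares loss produces exactly the prediction of a softmax-free linear attention layer whose keys are the in-context inputs and whose values are the scaled targets:

```python
import numpy as np

rng = np.random.RandomState(0)
n, d, eta = 8, 4, 0.1
X = rng.randn(n, d)                  # in-context example inputs
y = X @ rng.randn(d)                 # targets from a random linear task
x_q = rng.randn(d)                   # query input

# One gradient step from w = 0 on L(w) = (1/2n) * sum_i (x_i.w - y_i)^2:
# the gradient at 0 is -(1/n) X^T y, so the adapted weights are (eta/n) X^T y
w1 = (eta / n) * X.T @ y
gd_pred = x_q @ w1

# Softmax-free linear attention: keys x_i, query x_q, values eta * y_i / n
attn_pred = (eta / n) * np.sum(y * (X @ x_q))

print(np.isclose(gd_pred, attn_pred))  # True: the two computations coincide
```

A transformer that meta-trains on many such regression tasks can wire this computation into its attention weights, so each forward pass through a layer acts like one optimization step on the prompt's examples.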

# The few-shot adaptation spectrum
# Each approach trades off adaptation cost vs. performance

results = {
    "Train from scratch": {"acc": 0.20, "grad_steps": 10000,
                           "note": "No prior knowledge"},
    "Prototypical Net":   {"acc": 0.88, "grad_steps": 0,
                           "note": "Nearest centroid in learned space"},
    "MAML (3 steps)":     {"acc": 0.85, "grad_steps": 3,
                           "note": "Adapt from meta-learned init"},
    "FOMAML":             {"acc": 0.83, "grad_steps": 3,
                           "note": "MAML without Hessian"},
    "Reptile":            {"acc": 0.81, "grad_steps": 5,
                           "note": "Move toward task solutions"},
}

print(f"{'Method':<22} {'Acc':>6} {'Steps':>7}  Note")
print("-" * 65)
for method, r in results.items():
    print(f"{method:<22} {r['acc']:>6.0%} {r['grad_steps']:>7d}  {r['note']}")

# Method                    Acc   Steps  Note
# -----------------------------------------------------------------
# Train from scratch        20%   10000  No prior knowledge
# Prototypical Net          88%       0  Nearest centroid in learned space
# MAML (3 steps)            85%       3  Adapt from meta-learned init
# FOMAML                    83%       3  MAML without Hessian
# Reptile                   81%       5  Move toward task solutions

This perspective reframes how we think about foundation models. Massive pre-training on diverse data creates a powerful learning algorithm encoded in the model's weights. Prompting activates that algorithm. LoRA fine-tuning sits in the middle — updating a tiny fraction of parameters, much like MAML's few gradient steps but with parameter-efficient adaptation. The entire history of meta-learning culminates in the observation that the best way to learn quickly is to have already seen a vast distribution of tasks.

References & Further Reading