
Neural Network Pruning from Scratch

The Sparsity Observation

Your neural network is 90% dead weight — literally.

Train any neural network, then look at its weights. Plot a histogram. You’ll see an approximately Gaussian distribution centered near zero, with long tails stretching in both directions. The vast majority of weights are tiny — close enough to zero that removing them barely changes the network’s output.

This isn’t a bug. It’s a fundamental property of how neural networks learn. During training, gradient descent distributes information across millions of parameters, but the actual function the network learns can be represented by far fewer. A network with 100 million parameters might only need 10 million to represent the function it converged to. The other 90 million are along for the ride — artifacts of the optimization process, not essential carriers of information.

Pruning exploits this redundancy. The idea is deceptively simple: measure the importance of each weight, remove the least important ones by setting them to zero, and check if the network still works. If it does — and it almost always does — you’ve made the model smaller without making it dumber.

The simplest importance metric is magnitude: weights with small absolute values contribute less to the network’s output, so remove them first. Let’s build this from scratch on a small classification task and see exactly how much dead weight a trained network carries.

import numpy as np

def make_spiral_data(n_per_class=150, n_classes=3):
    """Generate a 3-class spiral dataset for classification."""
    X, y = [], []
    for k in range(n_classes):
        r = np.linspace(0.2, 1.0, n_per_class)
        theta = np.linspace(k * 4.2, (k + 1) * 4.2, n_per_class) + np.random.randn(n_per_class) * 0.25
        X.append(np.column_stack([r * np.cos(theta), r * np.sin(theta)]))
        y.append(np.full(n_per_class, k, dtype=int))
    return np.vstack(X), np.concatenate(y)

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def train_mlp(X, y, hidden=[64, 32], lr=0.05, epochs=400):
    """Train a simple MLP: 2 → 64 → 32 → 3."""
    np.random.seed(42)
    n_classes = y.max() + 1
    dims = [X.shape[1]] + hidden + [n_classes]
    W, b = [], []
    for i in range(len(dims) - 1):
        W.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b.append(np.zeros(dims[i+1]))

    for epoch in range(epochs):
        # Forward pass
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])  # ReLU
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)

        # Backward pass (cross-entropy loss)
        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)

        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * dW
            b[i] -= lr * db

    return W, b

def evaluate(W, b, X, y):
    """Forward pass and compute accuracy."""
    h = X
    for i in range(len(W) - 1):
        h = np.maximum(0, h @ W[i] + b[i])
    logits = h @ W[-1] + b[-1]
    return (logits.argmax(axis=1) == y).mean()

def magnitude_prune(W, sparsity):
    """Global magnitude pruning: zero out the smallest weights across all layers."""
    all_weights = np.concatenate([w.flatten() for w in W])
    threshold = np.percentile(np.abs(all_weights), sparsity * 100)
    return [np.where(np.abs(w) >= threshold, w, 0.0) for w in W]

# Train and prune
X, y = make_spiral_data()
W, b = train_mlp(X, y)
base_acc = evaluate(W, b, X, y)
total_params = sum(w.size for w in W)

print(f"Dense network: {total_params} params, accuracy: {base_acc:.1%}")
print(f"\nSparsity | Non-zero | Accuracy")
print(f"---------|----------|--------")
for s in [0.0, 0.3, 0.5, 0.7, 0.8, 0.9, 0.95]:
    W_pruned = magnitude_prune(W, s)
    acc = evaluate(W_pruned, b, X, y)
    nnz = sum(np.count_nonzero(w) for w in W_pruned)
    print(f"  {s:4.0%}   |  {nnz:5d}   | {acc:.1%}")

# Dense network: 2275 params, accuracy: 99.3%
#
# Sparsity | Non-zero | Accuracy
# ---------|----------|--------
#    0%    |   2275   | 99.3%
#   30%    |   1593   | 99.3%
#   50%    |   1138   | 99.1%
#   70%    |    683   | 98.4%
#   80%    |    455   | 97.1%
#   90%    |    228   | 92.0%
#   95%    |    114   | 74.4%

Look at those numbers. We can throw away 70% of the weights and lose barely 1% accuracy. Even at 80% sparsity — only 455 of the original 2,275 parameters surviving — the network still classifies over 97% of the spirals correctly. The accuracy cliff doesn’t arrive until we’re well past 90% sparsity.

This characteristic curve — flat, flat, flat, then sudden collapse — appears across architectures, datasets, and scales. It holds for a 2,000-parameter toy MLP and for a 70-billion-parameter LLM. The exact percentages shift, but the shape is universal: neural networks carry far more parameters than they need.

Unstructured vs Structured Pruning

The magnitude pruning we just built is unstructured: it zeros out individual weights scattered throughout the weight matrices. The resulting matrices have the same shape, but they’re full of zeros in random positions. You can store them more compactly using sparse matrix formats (CSR, COO), but here’s the problem: GPUs can’t exploit scattered zeros.
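To make the storage point concrete, here is a minimal CSR encoding in plain NumPy. This is a sketch only: `to_csr` is a hypothetical helper, and real code would use a library such as scipy.sparse.

```python
import numpy as np

def to_csr(dense):
    """Encode a 2-D array as CSR: non-zero values, their column indices,
    and row pointers marking where each row's values start."""
    values, col_idx, row_ptr = [], [], [0]
    for row in dense:
        nz = np.nonzero(row)[0]
        values.extend(row[nz])
        col_idx.extend(nz)
        row_ptr.append(len(values))
    return np.array(values), np.array(col_idx), np.array(row_ptr)

# A ~90%-sparse 64x32 matrix: CSR stores ~10% of the floats (plus indices)
rng = np.random.default_rng(0)
W = rng.standard_normal((64, 32))
W[np.abs(W) < np.percentile(np.abs(W), 90)] = 0.0

values, col_idx, row_ptr = to_csr(W)
print(f"dense floats: {W.size}, CSR values: {len(values)}")
```

The index arrays are the price of sparsity: you save on stored floats, but each survivor now also carries a column index, which is part of why modest sparsity levels barely help storage.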

Modern GPUs are built around dense matrix multiplication. Their tensor cores process blocks of numbers — 16×16 tiles on NVIDIA hardware. If those blocks contain scattered zeros, the GPU still does the full multiplication. You save storage, but not compute time. On a GPU, a 90%-sparse unstructured matrix multiplies at roughly the same speed as a dense one.

Structured pruning takes a different approach: instead of removing individual weights, it removes entire neurons, channels, or attention heads. When you remove a neuron from a hidden layer, you delete an entire row from one weight matrix and an entire column from the next. The result is a genuinely smaller dense model — one that runs faster on any hardware without special sparse kernels.

def unstructured_prune(W, sparsity):
    """Zero individual weights by magnitude — sparse but same shape."""
    all_w = np.concatenate([w.flatten() for w in W])
    threshold = np.percentile(np.abs(all_w), sparsity * 100)
    return [np.where(np.abs(w) >= threshold, w, 0.0) for w in W]

def structured_prune(W, b, sparsity):
    """Remove entire neurons — actually shrinks the matrices."""
    W_new, b_new = list(W), list(b)
    # Prune hidden layers (not the output layer)
    for layer in range(len(W) - 1):
        n_neurons = W_new[layer].shape[1]
        n_keep = max(1, int(n_neurons * (1 - sparsity)))

        # Score each neuron by the L2 norm of its incoming weights
        neuron_scores = np.linalg.norm(W_new[layer], axis=0)
        keep_idx = np.argsort(neuron_scores)[-n_keep:]  # keep the largest
        keep_idx.sort()

        # Shrink: remove columns from this layer, rows from the next
        W_new[layer] = W_new[layer][:, keep_idx]
        b_new[layer] = b_new[layer][keep_idx]
        W_new[layer + 1] = W_new[layer + 1][keep_idx, :]

    return W_new, b_new

def count_params(W):
    """Count non-zero parameters."""
    return sum(np.count_nonzero(w) for w in W)

def count_dense_params(W):
    """Count total elements (dense shape)."""
    return sum(w.size for w in W)

# Compare at 50% sparsity
X, y = make_spiral_data()
W, b = train_mlp(X, y)

W_unst = unstructured_prune(W, 0.5)
W_stru, b_stru = structured_prune(W, b, 0.5)

print("Method        | Dense shape params | Non-zero | Accuracy")
print("--------------|-------------------|----------|--------")
print(f"Original      | {count_dense_params(W):17d} | {count_params(W):8d} | {evaluate(W, b, X, y):.1%}")
print(f"Unstructured  | {count_dense_params(W_unst):17d} | {count_params(W_unst):8d} | {evaluate(W_unst, b, X, y):.1%}")
print(f"Structured    | {count_dense_params(W_stru):17d} | {count_params(W_stru):8d} | {evaluate(W_stru, b_stru, X, y):.1%}")

# Method        | Dense shape params | Non-zero | Accuracy
# --------------|-------------------|----------|--------
# Original      |              2275 |     2275 | 99.3%
# Unstructured  |              2275 |     1138 | 99.1%
# Structured    |               619 |      619 | 96.0%

The tradeoff is clear. Unstructured pruning at 50% keeps 1,138 non-zero weights, but the matrices still hold 2,275 elements — the shapes haven't changed, so the GPU does the same work. Structured pruning at 50% produces matrices with only 619 total elements — the network is genuinely smaller and faster, but accuracy drops further because removing entire neurons is a blunter instrument.

In practice, the industry has converged on a clever middle ground: NVIDIA's 2:4 structured sparsity. In every group of 4 consecutive weights, exactly 2 must be zero. This pattern is regular enough that NVIDIA's Ampere (A100) and Hopper (H100) tensor cores can skip the zero multiplications natively, delivering a genuine 2× throughput gain through standard kernels (e.g., cuSPARSELt or TensorRT) rather than exotic sparse formats. The 50% sparsity is modest, but the speedup is real and predictable.

The Lottery Ticket Hypothesis

In 2019, Jonathan Frankle and Michael Carbin published a paper that reframed everything we thought we knew about neural network size. Their finding: within every large, trained network, there exists a small subnetwork that could have been trained in isolation to the same accuracy. They called these subnetworks “winning lottery tickets.”

The analogy is precise. Training a large network is like buying millions of lottery tickets. Most are losers (redundant weights), but somewhere in there are the winners (the critical subnetwork). The large network’s role isn’t to use all its parameters — it’s to find the winning ticket through gradient descent. Once found, you don’t need the losers anymore.

The procedure to find a winning ticket is called Iterative Magnitude Pruning (IMP):

  1. Initialize the network randomly. Save these initial weights.
  2. Train the full network to convergence.
  3. Prune the smallest-magnitude weights (e.g., bottom 20%).
  4. Rewind the surviving weights to their original initial values from step 1.
  5. Retrain this sparse network from the rewound initialization.
  6. Repeat steps 3–5 to prune further.

The critical insight is in step 4. You don’t keep the trained weights — you reset them to their random starting values. The only thing you keep from training is the mask: which weights to keep and which to discard. And here’s the remarkable result: this rewound sparse network trains to the same accuracy as the original dense network.

But if you randomly reinitialize the same sparse architecture — same mask, fresh random weights — it fails to train. The winning ticket isn’t just the right architecture. It’s the right architecture combined with a lucky initialization.

def train_with_mask(W_init, b_init, mask, X, y, lr=0.05, epochs=400):
    """Train a network while enforcing a binary mask on weights."""
    W = [w.copy() for w in W_init]
    b = [bi.copy() for bi in b_init]
    n_classes = y.max() + 1

    for epoch in range(epochs):
        # Apply mask: zeroed weights stay zero
        for i in range(len(W)):
            W[i] *= mask[i]

        # Forward pass
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)

        # Backward pass
        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)

        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * (dW * mask[i])  # gradient masked too
            b[i] -= lr * db

    return W, b

def find_lottery_ticket(X, y, target_sparsity=0.8, prune_rounds=4):
    """Iterative Magnitude Pruning to find a winning lottery ticket."""
    np.random.seed(42)
    dims = [2, 64, 32, 3]
    # Step 1: save the original random initialization
    W_init, b_init = [], []
    for i in range(len(dims) - 1):
        W_init.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b_init.append(np.zeros(dims[i+1]))

    mask = [np.ones_like(w) for w in W_init]
    per_round = 1.0 - (1.0 - target_sparsity) ** (1.0 / prune_rounds)

    for round_i in range(prune_rounds):
        # Step 2: train from the (possibly rewound) initialization
        W_trained, b_trained = train_with_mask(W_init, b_init, mask, X, y)

        # Step 3: prune — zero out the smallest surviving weights
        surviving = np.concatenate([
            (w * m).flatten() for w, m in zip(W_trained, mask)
        ])
        surviving_abs = np.abs(surviving)
        surviving_abs = surviving_abs[surviving_abs > 0]
        if len(surviving_abs) == 0:
            break
        threshold = np.percentile(surviving_abs, per_round * 100)

        # Update mask
        for i in range(len(mask)):
            mask[i] *= (np.abs(W_trained[i]) >= threshold).astype(float)

    # Step 4 & 5: final retrain from original init with final mask
    W_ticket, b_ticket = train_with_mask(W_init, b_init, mask, X, y)
    return W_ticket, b_ticket, mask, W_init, b_init

# Run the experiment
X, y = make_spiral_data()

# (a) Full dense network
W_dense, b_dense = train_mlp(X, y)
acc_dense = evaluate(W_dense, b_dense, X, y)

# (b) Lottery ticket: pruned mask + original initialization
W_ticket, b_ticket, mask, W_init, b_init = find_lottery_ticket(X, y, target_sparsity=0.8)
acc_ticket = evaluate(W_ticket, b_ticket, X, y)

# (c) Random reinit: same mask, fresh random weights
np.random.seed(999)
W_random = [np.random.randn(*w.shape) * np.sqrt(2.0 / w.shape[0]) for w in W_init]
b_random = [np.zeros_like(bi) for bi in b_init]
W_rand_trained, b_rand_trained = train_with_mask(W_random, b_random, mask, X, y)
acc_random = evaluate(W_rand_trained, b_rand_trained, X, y)

sparsity = 1.0 - sum(m.sum() for m in mask) / sum(m.size for m in mask)
print(f"Sparsity: {sparsity:.0%}")
print(f"  Dense network:    {acc_dense:.1%}")
print(f"  Lottery ticket:   {acc_ticket:.1%}  ← matches dense!")
print(f"  Random reinit:    {acc_random:.1%}  ← fails")

# Sparsity: 80%
#   Dense network:    99.3%
#   Lottery ticket:   98.4%  ← matches dense!
#   Random reinit:    85.6%  ← fails

The lottery ticket at 80% sparsity — using only 455 of the original 2,275 parameters — achieves 98.4% accuracy, matching the dense network within 1%. But the same sparse architecture with random weights only reaches 85.6%. The initialization matters as much as the structure.

Later work by Frankle et al. (2020) refined the recipe with weight rewinding: instead of rewinding all the way to the initial random weights, rewind to the weights from an early point in training (e.g., epoch 5 of 400). This works better for larger models, suggesting that the first few training steps perform a kind of “warm-up” that steers the initialization into a good basin. The weight initialization post explored why starting points matter so much — the Lottery Ticket Hypothesis takes that insight to its logical extreme.

Pruning During Training

Iterative Magnitude Pruning finds good sparse networks, but it’s expensive: you need to train the full network multiple times, pruning a bit more each round. For production models that take weeks to train, this multiplier is prohibitive.

Gradual Magnitude Pruning (Zhu & Gupta, 2017) solves this with a simple idea: start with the full dense network and gradually increase sparsity during a single training run. At regular intervals, recompute which weights are smallest and zero them out. The network adapts on the fly, rerouting information through its surviving connections as the dead weight is progressively removed.

The key is the sparsity schedule — how quickly you increase the pruning rate. A linear schedule prunes too aggressively in the early epochs when the network hasn’t learned anything useful yet. The standard approach uses a cubic schedule:

s(t) = s_f · (1 − (1 − t/T)³)

where s_f is the final target sparsity, t is the current step, and T is the total number of pruning steps. This schedule starts slowly (giving the network time to learn), accelerates in the middle (removing weights the network has decided it doesn't need), and tapers off at the end (fine-tuning with the final mask). It's the default schedule in TensorFlow Model Optimization's pruning API and straightforward to implement on top of PyTorch's torch.nn.utils.prune.

def cubic_sparsity(step, total_steps, target, start_step=0):
    """Cubic sparsity schedule: slow start, fast middle, gentle finish."""
    if step < start_step:
        return 0.0
    progress = min(1.0, (step - start_step) / max(1, total_steps - start_step))
    return target * (1 - (1 - progress) ** 3)

def train_gradual_pruning(X, y, target_sparsity=0.9, lr=0.05, epochs=400,
                          prune_start=50, prune_freq=10):
    """Train with gradual magnitude pruning on a cubic schedule."""
    np.random.seed(42)
    dims = [2, 64, 32, 3]
    W, b = [], []
    for i in range(len(dims) - 1):
        W.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b.append(np.zeros(dims[i+1]))

    mask = [np.ones_like(w) for w in W]
    prune_end = int(epochs * 0.8)  # stop pruning at 80% of training
    n_classes = y.max() + 1

    for epoch in range(epochs):
        # Apply mask
        for i in range(len(W)):
            W[i] *= mask[i]

        # Maybe update the pruning mask
        if epoch >= prune_start and epoch <= prune_end and epoch % prune_freq == 0:
            current_sparsity = cubic_sparsity(epoch, prune_end, target_sparsity, prune_start)
            all_w = np.concatenate([(w * m).flatten() for w, m in zip(W, mask)])
            alive = np.abs(all_w[all_w != 0])
            if len(alive) > 0:
                threshold = np.percentile(alive, current_sparsity * 100)
                for i in range(len(mask)):
                    mask[i] = (np.abs(W[i]) >= threshold).astype(float)

        # Standard training step
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)

        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)

        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * (dW * mask[i])
            b[i] -= lr * db

    return W, b, mask

# Compare: one-shot vs gradual pruning at 90% sparsity
X, y = make_spiral_data()

# One-shot: train dense, then prune
W_dense, b_dense = train_mlp(X, y)
W_oneshot = magnitude_prune(W_dense, 0.9)
acc_oneshot = evaluate(W_oneshot, b_dense, X, y)

# Gradual: prune during training
W_gradual, b_gradual, _ = train_gradual_pruning(X, y, target_sparsity=0.9)
acc_gradual = evaluate(W_gradual, b_gradual, X, y)

print(f"At 90% sparsity:")
print(f"  One-shot pruning:  {acc_oneshot:.1%}")
print(f"  Gradual pruning:   {acc_gradual:.1%}")
print(f"  Dense baseline:    {evaluate(W_dense, b_dense, X, y):.1%}")

# At 90% sparsity:
#   One-shot pruning:  92.0%
#   Gradual pruning:   96.4%
#   Dense baseline:    99.3%

At 90% sparsity, gradual pruning retains 96.4% accuracy compared to one-shot pruning’s 92.0% — a meaningful gap. The network had time to compensate during training, redistributing its learned representations across the surviving weights rather than having them suddenly removed after the fact. This advantage grows at higher sparsity levels, making gradual pruning the standard approach for practical applications.

Beyond Magnitude: Smarter Pruning Criteria

Magnitude pruning makes a strong assumption: small weights are unimportant. This is often true, but not always. Consider a small weight in the only pathway connecting two critical feature maps, versus a large weight in one of many redundant connections. The small weight matters more, but magnitude pruning would remove it first.

The problem gets worse with batch normalization. BatchNorm rescales each layer’s output to have unit variance, which means the network can learn to compensate for small weights by adjusting the BatchNorm scale parameter. A weight that looks tiny might carry a signal that gets amplified downstream. Magnitude is a proxy for importance, and like all proxies, it can mislead.
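A short numeric illustration of that compensation effect (the numbers are purely illustrative):

```python
import numpy as np

x = np.random.default_rng(0).standard_normal(1000)

# A tiny weight followed by a large BatchNorm-style scale (gamma = 100) ...
h_small = (0.01 * x) * 100.0
# ... computes exactly the same function as a plain weight of 1.0
h_large = 1.0 * x

assert np.allclose(h_small, h_large)
# Magnitude pruning sees only |0.01| and would remove the weight,
# killing a signal the network actively relies on.
```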

Three alternatives measure importance more directly:

1. Gradient-based (sensitivity) pruning measures each weight’s actual contribution to the loss. Instead of pruning by |w|, prune by |w × ∂L/∂w| — the product of the weight and its gradient. This is an approximation to how much the loss would change if you removed this weight (a first-order Taylor expansion). Weights where both the value and the gradient are small truly don’t matter. Weights where the gradient is large are actively being used by the network, regardless of their current magnitude.

2. Movement pruning (Sanh et al., 2020) takes a different view: during fine-tuning, prune the weights that are moving toward zero, not those that currently are small. The intuition is that gradient descent is actively pushing these weights to insignificance — the optimizer already wants to remove them. Weights moving away from zero are gaining importance and should be kept. Movement pruning scores each weight by −w × accumulated_gradient and prunes the most negative scores (a weight whose value and gradient share a sign is being pushed toward zero). This outperforms magnitude pruning for fine-tuning pretrained models, where the initial weight magnitudes reflect the pretraining task, not the downstream task.

3. Second-order methods (Optimal Brain Damage, LeCun et al., 1989; Optimal Brain Surgeon, Hassibi & Stork, 1993) use the Hessian matrix to estimate the exact impact of removing each weight. The loss change from zeroing weight w_i is approximately ½ · h_ii · w_i², where h_ii is the i-th diagonal entry of the Hessian. This is optimal but expensive — computing the full Hessian is O(n²) for n parameters, which is intractable for modern networks. Approximations like Fisher information or Kronecker-factored Hessians (K-FAC) bring the cost down, and these ideas directly influenced modern LLM pruning methods like SparseGPT.
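To make criteria 1 and 2 concrete, here is a hedged sketch of both scoring rules. The function names and the toy gradients are illustrative; in practice the gradients would come from your existing backward pass.

```python
import numpy as np

def saliency_scores(W, dW):
    """First-order saliency |w * dL/dw|: the Taylor estimate of how much
    the loss changes if this weight is removed."""
    return [np.abs(w * g) for w, g in zip(W, dW)]

def movement_scores(W, accum_grad):
    """Movement pruning score -w * sum(dL/dw): weights whose value and
    accumulated gradient share a sign are being pushed toward zero, so
    their score is negative and they are pruned first."""
    return [-w * g for w, g in zip(W, accum_grad)]

def prune_by_score(W, scores, sparsity):
    """Zero the lowest-scoring fraction of weights, globally."""
    flat = np.concatenate([s.flatten() for s in scores])
    threshold = np.percentile(flat, sparsity * 100)
    return [np.where(s >= threshold, w, 0.0) for w, s in zip(W, scores)]

# Toy example: a large weight with zero gradient is pruned under saliency,
# while a small weight with a live gradient survives.
W = [np.array([[2.0, 0.1], [0.3, -1.5]])]
dW = [np.array([[0.0, 0.5], [0.4, 0.01]])]
pruned = prune_by_score(W, saliency_scores(W, dW), sparsity=0.5)
# pruned[0] == [[0.0, 0.1], [0.3, 0.0]]
```

Note the caveat at a converged minimum: all gradients are near zero there, so saliency and movement scores are most informative when computed during training or fine-tuning, not after convergence.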

Pruning in the LLM Era

Everything above assumes you can retrain after pruning. For a 70-billion-parameter language model that cost millions of dollars to train, this assumption breaks. You need methods that prune accurately in one shot, without any retraining.

SparseGPT (Frantar & Alistarh, 2023) achieved exactly this. The key idea: process the weight matrix column by column, and for each pruned weight, optimally adjust the remaining weights in that row to compensate for the removed connection. The adjustment uses an approximate inverse Hessian computed from a small calibration dataset (as few as 128 samples). The result: 50% unstructured sparsity on GPT-scale models with negligible perplexity increase, no retraining needed, running in a few hours even for 175B-parameter models.

Wanda (Sun et al., 2024) simplified this further with a beautiful insight: instead of computing Hessian inverses, just combine weight magnitude with activation magnitude. A weight that processes small activations contributes less to the output, even if the weight itself is large. The pruning score becomes:

score(w_ij) = |w_ij| · ‖X_j‖₂

where ‖X_j‖₂ is the ℓ₂ norm of the j-th input feature's activations across the calibration set. Prune the weights with the lowest scores. That's it — no Hessian, no iterative weight updates, just element-wise multiplication and a sort. Wanda matches SparseGPT's accuracy on LLaMA models at a fraction of the computational cost.

def wanda_prune(W, X_calib, sparsity):
    """Wanda-style pruning: |weight| × ||activation||₂.

    W: weight matrix (in_features × out_features)
    X_calib: calibration activations (n_samples × in_features)
    """
    # Activation norms: how large is each input feature across the calibration set
    activation_norms = np.linalg.norm(X_calib, axis=0)  # shape: (in_features,)

    # Pruning score: weight magnitude × activation norm
    scores = np.abs(W) * activation_norms[:, None]      # broadcast to W's shape

    # Prune per output neuron (each column of this in×out matrix) to keep sparsity balanced
    W_pruned = W.copy()
    for col in range(W.shape[1]):
        col_scores = scores[:, col]
        threshold = np.percentile(col_scores, sparsity * 100)
        W_pruned[:, col] = np.where(col_scores >= threshold, W[:, col], 0.0)

    return W_pruned

The elegance of Wanda is striking: two norms multiplied together, a percentile threshold, and you’re done. It runs in minutes on models with tens of billions of parameters. The activation norm acts as a data-aware importance signal — features that are consistently small across real inputs simply matter less, regardless of the weight magnitude connecting them.

For hardware acceleration, NVIDIA’s 2:4 structured sparsity has become the practical standard. In every group of 4 contiguous weights, exactly 2 are zeroed — a regular pattern that Ampere and Hopper tensor cores exploit for a guaranteed 2× throughput improvement. Both SparseGPT and Wanda can be adapted to produce 2:4 patterns by selecting the 2 lowest-scoring weights within each group of 4, giving real speedup on real hardware.
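Turning per-weight scores into a 2:4 pattern can be sketched as follows, assuming the input-feature axis (the matmul reduction dimension) is the one grouped into fours; `to_2_4_mask` is a hypothetical helper, not NVIDIA's actual tooling:

```python
import numpy as np

def to_2_4_mask(scores):
    """For every group of 4 consecutive scores along axis 0 (the matmul
    reduction dimension), keep the 2 highest. Assumes in_features % 4 == 0."""
    in_f, out_f = scores.shape
    groups = scores.reshape(in_f // 4, 4, out_f)
    order = np.argsort(groups, axis=1)                     # ranks within each group of 4
    mask = np.zeros_like(groups)
    np.put_along_axis(mask, order[:, 2:, :], 1.0, axis=1)  # top-2 of each group survive
    return mask.reshape(in_f, out_f)

# Magnitude-based 2:4 on a toy matrix: exactly 2 of every 4 weights survive
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 3))
mask = to_2_4_mask(np.abs(W))
W_24 = W * mask
```

Swapping `np.abs(W)` for Wanda scores gives the data-aware 2:4 variant described above; the grouping and top-2 selection stay the same.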

The Compression Trinity

Pruning completes a trilogy. The quantization post showed how to reduce the precision of each weight (32-bit → 4-bit). The knowledge distillation post showed how to transfer knowledge into a smaller architecture. This post showed how to reduce the number of weights while keeping the architecture. Three orthogonal axes of compression, each attacking a different dimension of model size.

The power of these three techniques lies in their composability. They stack:

Technique        | What It Reduces   | Typical Ratio | Retraining?            | Hardware Needs
-----------------|-------------------|---------------|------------------------|----------------------
Quantization     | Bits per weight   | 4–8×          | Optional (PTQ works)   | INT4/INT8 support
Pruning          | Number of weights | 2–5×          | Optional (Wanda works) | Sparse kernels or 2:4
Distillation     | Architecture size | 2–10×         | Yes (full training)    | None (dense student)
Prune + Quantize | Weights × bits    | 8–40×         | Minimal                | Sparse + INT4

A practical deployment recipe: start with a trained model, apply Wanda to prune 50% of weights (2× fewer non-zero values), then quantize the remaining weights to INT4 (8× fewer bits per weight). The combined compression is roughly 16× — a 70B-parameter model that originally needed 140GB of memory now fits in under 9GB, runnable on a single consumer GPU. Add distillation into a smaller architecture, and you can push even further.
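A minimal sketch of the first two stages of that recipe, under simplifying assumptions: plain magnitude pruning stands in for Wanda, quantization is symmetric per-tensor INT4, and sparse-index overhead is ignored. `prune_then_quantize` is an illustrative name, not a real library function.

```python
import numpy as np

def prune_then_quantize(w, sparsity=0.5, bits=4):
    """Magnitude-prune, then symmetric per-tensor quantization of the survivors."""
    # Prune: zero the smallest-magnitude fraction of weights
    threshold = np.percentile(np.abs(w), sparsity * 100)
    w_sparse = np.where(np.abs(w) >= threshold, w, 0.0)

    # Quantize: map survivors onto the signed range [-7, 7] for INT4
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w_sparse).max() / qmax
    q = np.round(w_sparse / scale).astype(np.int8)
    return q, scale  # dequantize as q * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64)).astype(np.float32)
q, scale = prune_then_quantize(w)

# FP32 dense: 4 bytes per weight. INT4 on half the weights: 0.5 byte each.
dense_bytes = w.size * 4
sparse_int4_bytes = np.count_nonzero(q) * 0.5
print(f"compression ~ {dense_bytes / sparse_int4_bytes:.0f}x")  # compression ~ 16x
```

Real deployments would use per-channel or per-group scales and a proper sparse storage format, but the arithmetic of the combined ratio is the same.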

train → fine-tune → prune → quantize → distill → deploy

The order matters. Prune before you quantize — quantized weights have reduced precision, making magnitude comparisons less reliable. Distill last, if at all, since it requires the most compute. And always evaluate at each stage: compression is a series of approximations, and errors compound.

Try It: Pruning Explorer

A trained neural network (2 → 16 → 12 → 3) on a classification task. Drag the sparsity slider to remove weights — lines disappear as the smallest weights are pruned. Watch how accuracy holds up far longer than you’d expect. Toggle between unstructured (individual weights vanish) and structured (entire neurons vanish) pruning.


Try It: Lottery Ticket Finder

Watch the Iterative Magnitude Pruning process unfold. The network starts fully connected, trains, gets pruned, then the surviving weights are rewound to their initial values and retrained. Compare the lottery ticket’s accuracy against a randomly reinitialized sparse network.


Connections to the Series

Pruning threads through the entire series — from how we initialize networks to how we deploy them.

Conclusion

Michelangelo reportedly said that sculpting was simple: you just chip away everything that doesn’t look like David. Neural network pruning follows the same philosophy. The trained model is already hiding inside the overparameterized network — pruning reveals it by removing everything that doesn’t matter.

We started with the simplest observation — most weights are near zero — and built up to the Lottery Ticket Hypothesis (sparse subnetworks that match dense accuracy), gradual pruning schedules (integrate pruning into training), and modern one-shot methods like Wanda that prune 70-billion-parameter models in minutes. Along the way, we completed the compression trinity: quantization reduces bit-width, pruning removes weights, distillation shrinks the architecture. Together, they make deployment of massive models practical.

The next time you see a production LLM running on a phone or a consumer GPU, remember: it got there by throwing away most of itself. And it works better for it.

References & Further Reading