Neural Network Pruning from Scratch
The Sparsity Observation
Your neural network is 90% dead weight — literally.
Train any neural network, then look at its weights. Plot a histogram. You’ll see an approximately Gaussian distribution centered near zero, with long tails stretching in both directions. The vast majority of weights are tiny — close enough to zero that removing them barely changes the network’s output.
This isn’t a bug. It’s a fundamental property of how neural networks learn. During training, gradient descent distributes information across millions of parameters, but the actual function the network learns can be represented by far fewer. A network with 100 million parameters might only need 10 million to represent the function it converged to. The other 90 million are along for the ride — artifacts of the optimization process, not essential carriers of information.
Pruning exploits this redundancy. The idea is deceptively simple: measure the importance of each weight, remove the least important ones by setting them to zero, and check if the network still works. If it does — and it almost always does — you’ve made the model smaller without making it dumber.
The simplest importance metric is magnitude: weights with small absolute values contribute less to the network’s output, so remove them first. Let’s build this from scratch on a small classification task and see exactly how much dead weight a trained network carries.
```python
import numpy as np

def make_spiral_data(n_per_class=150, n_classes=3):
    """Generate a 3-class spiral dataset for classification."""
    X, y = [], []
    for k in range(n_classes):
        r = np.linspace(0.2, 1.0, n_per_class)
        theta = np.linspace(k * 4.2, (k + 1) * 4.2, n_per_class) + np.random.randn(n_per_class) * 0.25
        X.append(np.column_stack([r * np.cos(theta), r * np.sin(theta)]))
        y.append(np.full(n_per_class, k, dtype=int))
    return np.vstack(X), np.concatenate(y)

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def train_mlp(X, y, hidden=[64, 32], lr=0.05, epochs=400):
    """Train a simple MLP: 2 → 64 → 32 → 3."""
    np.random.seed(42)
    n_classes = y.max() + 1
    dims = [X.shape[1]] + hidden + [n_classes]
    W, b = [], []
    for i in range(len(dims) - 1):
        W.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b.append(np.zeros(dims[i+1]))
    for epoch in range(epochs):
        # Forward pass
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])  # ReLU
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)
        # Backward pass (cross-entropy loss)
        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)
        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * dW
            b[i] -= lr * db
    return W, b

def evaluate(W, b, X, y):
    """Forward pass and compute accuracy."""
    h = X
    for i in range(len(W) - 1):
        h = np.maximum(0, h @ W[i] + b[i])
    logits = h @ W[-1] + b[-1]
    return (logits.argmax(axis=1) == y).mean()

def magnitude_prune(W, sparsity):
    """Global magnitude pruning: zero out the smallest weights across all layers."""
    all_weights = np.concatenate([w.flatten() for w in W])
    threshold = np.percentile(np.abs(all_weights), sparsity * 100)
    return [np.where(np.abs(w) >= threshold, w, 0.0) for w in W]

# Train and prune
X, y = make_spiral_data()
W, b = train_mlp(X, y)
base_acc = evaluate(W, b, X, y)
total_params = sum(w.size for w in W)
print(f"Dense network: {total_params} params, accuracy: {base_acc:.1%}")
print("\nSparsity | Non-zero | Accuracy")
print("---------|----------|--------")
for s in [0.0, 0.3, 0.5, 0.7, 0.8, 0.9, 0.95]:
    W_pruned = magnitude_prune(W, s)
    acc = evaluate(W_pruned, b, X, y)
    nnz = sum(np.count_nonzero(w) for w in W_pruned)
    print(f" {s:4.0%} | {nnz:5d} | {acc:.1%}")

# Dense network: 2275 params, accuracy: 99.3%
#
# Sparsity | Non-zero | Accuracy
# ---------|----------|--------
#    0% |  2275 | 99.3%
#   30% |  1593 | 99.3%
#   50% |  1138 | 99.1%
#   70% |   683 | 98.4%
#   80% |   455 | 97.1%
#   90% |   228 | 92.0%
#   95% |   114 | 74.4%
```
Look at those numbers. We can throw away 70% of the weights and lose barely 1% accuracy. Even at 80% sparsity — only 455 of the original 2,275 parameters surviving — the network still classifies over 97% of the spirals correctly. The accuracy cliff doesn’t arrive until we’re well past 90% sparsity.
This characteristic curve — flat, flat, flat, then sudden collapse — appears across architectures, datasets, and scales. It holds for a 2,000-parameter toy MLP and for a 70-billion-parameter LLM. The exact percentages shift, but the shape is universal: neural networks carry far more parameters than they need.
Unstructured vs Structured Pruning
The magnitude pruning we just built is unstructured: it zeros out individual weights scattered throughout the weight matrices. The resulting matrices have the same shape, but they’re full of zeros in random positions. You can store them more compactly using sparse matrix formats (CSR, COO), but here’s the problem: GPUs can’t exploit scattered zeros.
Modern GPUs are built around dense matrix multiplication. Their tensor cores process blocks of numbers — 16×16 tiles on NVIDIA hardware. If those blocks contain scattered zeros, the GPU still does the full multiplication. You save storage, but not compute time. On a GPU, a 90%-sparse unstructured matrix multiplies at roughly the same speed as a dense one.
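The storage-versus-compute split is easy to see with a back-of-the-envelope sketch. The calculation below is illustrative (it hand-computes the size of a CSR representation — one float32 value and one int32 column index per non-zero, plus one int32 row pointer per row — rather than using a sparse library):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((512, 512)).astype(np.float32)

# Unstructured pruning: zero the 90% smallest-magnitude entries
threshold = np.percentile(np.abs(W), 90)
W_pruned = np.where(np.abs(W) >= threshold, W, 0.0)

# Dense storage keeps every element; CSR stores only the non-zeros,
# their column indices, and one row pointer per row
nnz = np.count_nonzero(W_pruned)
dense_bytes = W_pruned.nbytes
csr_bytes = nnz * 4 + nnz * 4 + (W.shape[0] + 1) * 4

print(f"Dense: {dense_bytes:,} bytes, CSR: {csr_bytes:,} bytes")  # ~5x smaller
```

The ~5× storage win is real, but the dense matrix multiply on a GPU would still touch all 512×512 positions — which is exactly the gap structured pruning closes.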
Structured pruning takes a different approach: instead of removing individual weights, it removes entire neurons, channels, or attention heads. When you remove a neuron from a hidden layer, you delete an entire row from one weight matrix and an entire column from the next. The result is a genuinely smaller dense model — one that runs faster on any hardware without special sparse kernels.
```python
def unstructured_prune(W, sparsity):
    """Zero individual weights by magnitude — sparse but same shape."""
    all_w = np.concatenate([w.flatten() for w in W])
    threshold = np.percentile(np.abs(all_w), sparsity * 100)
    return [np.where(np.abs(w) >= threshold, w, 0.0) for w in W]

def structured_prune(W, b, sparsity):
    """Remove entire neurons — actually shrinks the matrices."""
    W_new, b_new = list(W), list(b)
    # Prune hidden layers (not the output layer)
    for layer in range(len(W) - 1):
        n_neurons = W_new[layer].shape[1]
        n_keep = max(1, int(n_neurons * (1 - sparsity)))
        # Score each neuron by the L2 norm of its incoming weights
        neuron_scores = np.linalg.norm(W_new[layer], axis=0)
        keep_idx = np.argsort(neuron_scores)[-n_keep:]  # keep the largest
        keep_idx.sort()
        # Shrink: remove columns from this layer, rows from the next
        W_new[layer] = W_new[layer][:, keep_idx]
        b_new[layer] = b_new[layer][keep_idx]
        W_new[layer + 1] = W_new[layer + 1][keep_idx, :]
    return W_new, b_new

def count_params(W):
    """Count non-zero parameters."""
    return sum(np.count_nonzero(w) for w in W)

def count_dense_params(W):
    """Count total elements (dense shape)."""
    return sum(w.size for w in W)

# Compare at 50% sparsity
X, y = make_spiral_data()
W, b = train_mlp(X, y)
W_unst = unstructured_prune(W, 0.5)
W_stru, b_stru = structured_prune(W, b, 0.5)

print("Method        | Dense shape params | Non-zero | Accuracy")
print("--------------|-------------------|----------|--------")
print(f"Original      | {count_dense_params(W):17d} | {count_params(W):8d} | {evaluate(W, b, X, y):.1%}")
print(f"Unstructured  | {count_dense_params(W_unst):17d} | {count_params(W_unst):8d} | {evaluate(W_unst, b, X, y):.1%}")
print(f"Structured    | {count_dense_params(W_stru):17d} | {count_params(W_stru):8d} | {evaluate(W_stru, b_stru, X, y):.1%}")

# Method        | Dense shape params | Non-zero | Accuracy
# --------------|-------------------|----------|--------
# Original      |              2275 |     2275 | 99.3%
# Unstructured  |              2275 |     1138 | 99.1%
# Structured    |               619 |      619 | 96.0%
```
The tradeoff is clear. Unstructured pruning at 50% keeps 1,138 non-zero weights in a 2,275-element matrix — the shape hasn’t changed, so the GPU does the same work. Structured pruning at 50% produces matrices with only 619 total elements — the network is genuinely smaller and faster, but accuracy drops further because removing entire neurons is a blunter instrument.
In practice, the industry has converged on a clever middle ground: NVIDIA’s 2:4 structured sparsity. In every group of 4 consecutive weights, exactly 2 must be zero. This pattern is regular enough that NVIDIA’s Ampere (A100) and Hopper (H100) tensor cores can skip the zero multiplications natively, delivering up to a 2× speedup on the sparse matrix multiplications. The 50% sparsity is modest, but the speed gain is real, and it comes through standard inference libraries rather than exotic sparse kernels.
The Lottery Ticket Hypothesis
In 2019, Jonathan Frankle and Michael Carbin published a paper that reframed everything we thought we knew about neural network size. Their finding: within every large, trained network, there exists a small subnetwork that could have been trained in isolation to the same accuracy. They called these subnetworks “winning lottery tickets.”
The analogy is precise. Training a large network is like buying millions of lottery tickets. Most are losers (redundant weights), but somewhere in there are the winners (the critical subnetwork). The large network’s role isn’t to use all its parameters — it’s to find the winning ticket through gradient descent. Once found, you don’t need the losers anymore.
The procedure to find a winning ticket is called Iterative Magnitude Pruning (IMP):
1. Initialize the network randomly. Save these initial weights.
2. Train the full network to convergence.
3. Prune the smallest-magnitude weights (e.g., the bottom 20%).
4. Rewind the surviving weights to their original initial values from step 1.
5. Retrain this sparse network from the rewound initialization.
6. Repeat steps 3–5 to prune further.
The critical insight is in step 4. You don’t keep the trained weights — you reset them to their random starting values. The only thing you keep from training is the mask: which weights to keep and which to discard. And here’s the remarkable result: this rewound sparse network trains to the same accuracy as the original dense network.
But if you randomly reinitialize the same sparse architecture — same mask, fresh random weights — it fails to train. The winning ticket isn’t just the right architecture. It’s the right architecture combined with a lucky initialization.
```python
def train_with_mask(W_init, b_init, mask, X, y, lr=0.05, epochs=400):
    """Train a network while enforcing a binary mask on weights."""
    W = [w.copy() for w in W_init]
    b = [bi.copy() for bi in b_init]
    for epoch in range(epochs):
        # Apply mask: zeroed weights stay zero
        for i in range(len(W)):
            W[i] *= mask[i]
        # Forward pass
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)
        # Backward pass
        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)
        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * (dW * mask[i])  # gradient masked too
            b[i] -= lr * db
    return W, b

def find_lottery_ticket(X, y, target_sparsity=0.8, prune_rounds=4):
    """Iterative Magnitude Pruning to find a winning lottery ticket."""
    np.random.seed(42)
    dims = [2, 64, 32, 3]
    # Step 1: save the original random initialization
    W_init, b_init = [], []
    for i in range(len(dims) - 1):
        W_init.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b_init.append(np.zeros(dims[i+1]))
    mask = [np.ones_like(w) for w in W_init]
    per_round = 1.0 - (1.0 - target_sparsity) ** (1.0 / prune_rounds)
    for round_i in range(prune_rounds):
        # Step 2: train from the (possibly rewound) initialization
        W_trained, b_trained = train_with_mask(W_init, b_init, mask, X, y)
        # Step 3: prune — zero out the smallest surviving weights
        surviving = np.concatenate([
            (w * m).flatten() for w, m in zip(W_trained, mask)
        ])
        surviving_abs = np.abs(surviving)
        surviving_abs = surviving_abs[surviving_abs > 0]
        if len(surviving_abs) == 0:
            break
        threshold = np.percentile(surviving_abs, per_round * 100)
        # Update mask
        for i in range(len(mask)):
            mask[i] *= (np.abs(W_trained[i]) >= threshold).astype(float)
    # Steps 4 & 5: final retrain from the original init with the final mask
    W_ticket, b_ticket = train_with_mask(W_init, b_init, mask, X, y)
    return W_ticket, b_ticket, mask, W_init, b_init

# Run the experiment
X, y = make_spiral_data()

# (a) Full dense network
W_dense, b_dense = train_mlp(X, y)
acc_dense = evaluate(W_dense, b_dense, X, y)

# (b) Lottery ticket: pruned mask + original initialization
W_ticket, b_ticket, mask, W_init, b_init = find_lottery_ticket(X, y, target_sparsity=0.8)
acc_ticket = evaluate(W_ticket, b_ticket, X, y)

# (c) Random reinit: same mask, fresh random weights
np.random.seed(999)
W_random = [np.random.randn(*w.shape) * np.sqrt(2.0 / w.shape[0]) for w in W_init]
b_random = [np.zeros_like(bi) for bi in b_init]
W_rand_trained, b_rand_trained = train_with_mask(W_random, b_random, mask, X, y)
acc_random = evaluate(W_rand_trained, b_rand_trained, X, y)

sparsity = 1.0 - sum(m.sum() for m in mask) / sum(m.size for m in mask)
print(f"Sparsity: {sparsity:.0%}")
print(f"  Dense network:  {acc_dense:.1%}")
print(f"  Lottery ticket: {acc_ticket:.1%} ← matches dense!")
print(f"  Random reinit:  {acc_random:.1%} ← fails")

# Sparsity: 80%
#   Dense network:  99.3%
#   Lottery ticket: 98.4% ← matches dense!
#   Random reinit:  85.6% ← fails
```
The lottery ticket at 80% sparsity — using only 455 of the original 2,275 parameters — achieves 98.4% accuracy, matching the dense network within 1%. But the same sparse architecture with random weights only reaches 85.6%. The initialization matters as much as the structure.
Later work by Frankle et al. (2020) introduced late rewinding: instead of rewinding all the way to the initial random weights, rewind to the weights at some early training step (e.g., epoch 5 of 400). This works better for larger models, suggesting that the first few training steps perform a kind of “warm-up” that steers the initialization into a good basin. The weight initialization post explored why starting points matter so much — the Lottery Ticket Hypothesis takes that insight to its logical extreme.
Pruning During Training
Iterative Magnitude Pruning finds good sparse networks, but it’s expensive: you need to train the full network multiple times, pruning a bit more each round. For production models that take weeks to train, this multiplier is prohibitive.
Gradual Magnitude Pruning (Zhu & Gupta, 2017) solves this with a simple idea: start with the full dense network and gradually increase sparsity during a single training run. At regular intervals, recompute which weights are smallest and zero them out. The network adapts on the fly, rerouting information through its surviving connections as the dead weight is progressively removed.
The key is the sparsity schedule — how quickly you increase the pruning rate. A linear schedule prunes too aggressively in the early epochs when the network hasn’t learned anything useful yet. The standard approach uses a cubic schedule:

$$s_t = s_f \left(1 - \left(1 - \frac{t}{T}\right)^3\right)$$

where $s_f$ is the final target sparsity, $t$ is the current step, and $T$ is the total number of pruning steps. This schedule starts slowly (giving the network time to learn), accelerates in the middle (removing weights the network has decided it doesn’t need), and tapers off at the end (fine-tuning with the final mask). It’s the standard in frameworks like TensorFlow Model Optimization and PyTorch’s torch.nn.utils.prune.
```python
def cubic_sparsity(step, total_steps, target, start_step=0):
    """Cubic sparsity schedule: slow start, fast middle, gentle finish."""
    if step < start_step:
        return 0.0
    progress = min(1.0, (step - start_step) / max(1, total_steps - start_step))
    return target * (1 - (1 - progress) ** 3)

def train_gradual_pruning(X, y, target_sparsity=0.9, lr=0.05, epochs=400,
                          prune_start=50, prune_freq=10):
    """Train with gradual magnitude pruning on a cubic schedule."""
    np.random.seed(42)
    dims = [2, 64, 32, 3]
    W, b = [], []
    for i in range(len(dims) - 1):
        W.append(np.random.randn(dims[i], dims[i+1]) * np.sqrt(2.0 / dims[i]))
        b.append(np.zeros(dims[i+1]))
    mask = [np.ones_like(w) for w in W]
    prune_end = int(epochs * 0.8)  # stop pruning at 80% of training
    for epoch in range(epochs):
        # Apply mask
        for i in range(len(W)):
            W[i] *= mask[i]
        # Maybe update the pruning mask
        if prune_start <= epoch <= prune_end and epoch % prune_freq == 0:
            current_sparsity = cubic_sparsity(epoch, prune_end, target_sparsity, prune_start)
            all_w = np.concatenate([(w * m).flatten() for w, m in zip(W, mask)])
            alive = np.abs(all_w[all_w != 0])
            if len(alive) > 0:
                threshold = np.percentile(alive, current_sparsity * 100)
                for i in range(len(mask)):
                    mask[i] = (np.abs(W[i]) >= threshold).astype(float)
        # Standard training step
        h = X
        activations = [h]
        for i in range(len(W) - 1):
            h = np.maximum(0, h @ W[i] + b[i])
            activations.append(h)
        logits = h @ W[-1] + b[-1]
        probs = softmax(logits)
        dz = probs.copy()
        one_hot = np.zeros_like(probs)
        one_hot[np.arange(len(y)), y] = 1
        dz -= one_hot
        dz /= len(y)
        for i in range(len(W) - 1, -1, -1):
            dW = activations[i].T @ dz
            db = dz.sum(axis=0)
            if i > 0:
                dz = (dz @ W[i].T) * (activations[i] > 0)
            W[i] -= lr * (dW * mask[i])
            b[i] -= lr * db
    return W, b, mask

# Compare: one-shot vs gradual pruning at 90% sparsity
X, y = make_spiral_data()

# One-shot: train dense, then prune
W_dense, b_dense = train_mlp(X, y)
W_oneshot = magnitude_prune(W_dense, 0.9)
acc_oneshot = evaluate(W_oneshot, b_dense, X, y)

# Gradual: prune during training
W_gradual, b_gradual, _ = train_gradual_pruning(X, y, target_sparsity=0.9)
acc_gradual = evaluate(W_gradual, b_gradual, X, y)

print("At 90% sparsity:")
print(f"  One-shot pruning: {acc_oneshot:.1%}")
print(f"  Gradual pruning:  {acc_gradual:.1%}")
print(f"  Dense baseline:   {evaluate(W_dense, b_dense, X, y):.1%}")

# At 90% sparsity:
#   One-shot pruning: 92.0%
#   Gradual pruning:  96.4%
#   Dense baseline:   99.3%
```
At 90% sparsity, gradual pruning retains 96.4% accuracy compared to one-shot pruning’s 92.0% — a meaningful gap. The network had time to compensate during training, redistributing its learned representations across the surviving weights rather than having them suddenly removed after the fact. This advantage grows at higher sparsity levels, making gradual pruning the standard approach for practical applications.
Beyond Magnitude: Smarter Pruning Criteria
Magnitude pruning makes a strong assumption: small weights are unimportant. This is often true, but not always. Consider a small weight in the only pathway connecting two critical feature maps, versus a large weight in one of many redundant connections. The small weight matters more, but magnitude pruning would remove it first.
The problem gets worse with batch normalization. BatchNorm rescales each layer’s output to have unit variance, which means the network can learn to compensate for small weights by adjusting the BatchNorm scale parameter. A weight that looks tiny might carry a signal that gets amplified downstream. Magnitude is a proxy for importance, and like all proxies, it can mislead.
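A tiny NumPy demo makes the BatchNorm caveat concrete. The numbers are invented for illustration: the two paths below compute exactly the same function, yet magnitude pruning would judge their weights very differently:

```python
import numpy as np

x = np.linspace(-1.0, 1.0, 5)  # activations flowing into the layer

# Path A: tiny weight, large downstream scale (think BatchNorm gamma)
w_a, gamma_a = 0.001, 1000.0
# Path B: unit weight, unit scale
w_b, gamma_b = 1.0, 1.0

out_a = gamma_a * (w_a * x)
out_b = gamma_b * (w_b * x)

assert np.allclose(out_a, out_b)  # identical functions...
# ...but magnitude pruning would remove w_a (0.001) long before w_b (1.0)
```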
Three alternatives measure importance more directly:
1. Gradient-based (sensitivity) pruning measures each weight’s actual contribution to the loss. Instead of pruning by |w|, prune by |w × ∂L/∂w| — the product of the weight and its gradient. This is an approximation to how much the loss would change if you removed this weight (a first-order Taylor expansion). Weights where both the value and the gradient are small truly don’t matter. Weights where the gradient is large are actively being used by the network, regardless of their current magnitude.
2. Movement pruning (Sanh et al., 2020) takes a different view: during fine-tuning, prune the weights that are moving toward zero, not those that are currently small. The intuition is that gradient descent is actively pushing these weights toward insignificance — the optimizer already wants to remove them. Weights moving away from zero are gaining importance and should be kept. Movement pruning scores each weight by −w × accumulated_gradient and prunes the most negative scores: a weight being pushed toward zero has gradients of the same sign as the weight itself, which makes this score negative. This outperforms magnitude pruning for fine-tuning pretrained models, where the initial weight magnitudes reflect the pretraining task, not the downstream task.
3. Second-order methods (Optimal Brain Damage, LeCun et al., 1989; Optimal Brain Surgeon, Hassibi & Stork, 1993) use the Hessian matrix to estimate the exact impact of removing each weight. The loss change from zeroing weight wi is approximately ½ hii wi², where hii is the i-th diagonal of the Hessian. This is optimal but expensive — computing the full Hessian is O(n²) for n parameters, which is intractable for modern networks. Approximations like Fisher information or Kronecker-factored Hessians (K-FAC) bring the cost down, and these ideas directly influenced modern LLM pruning methods like SparseGPT.
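The first of these criteria is easy to sketch. Here is a minimal first-order saliency pruner; in practice the gradients would come from a backward pass on a calibration batch, and the toy values below are invented to show the behavior:

```python
import numpy as np

def saliency_prune(W, grads, sparsity):
    """First-order saliency pruning: score each weight by |w * dL/dw|
    and zero out the lowest-scoring fraction. W and grads are lists of
    same-shaped arrays (layer weights and their loss gradients)."""
    scores = np.concatenate([np.abs(w * g).flatten() for w, g in zip(W, grads)])
    threshold = np.percentile(scores, sparsity * 100)
    return [np.where(np.abs(w * g) >= threshold, w, 0.0) for w, g in zip(W, grads)]

# Toy example: a large weight with zero gradient vs. a small weight
# with a large gradient
W = [np.array([[2.0, 0.01], [0.5, 0.3]])]
grads = [np.array([[0.0, 10.0], [0.01, 0.2]])]
W_pruned = saliency_prune(W, grads, 0.5)
print(W_pruned[0])  # the 2.0 weight is pruned; the 0.01 weight survives
```

This is exactly the case where magnitude pruning gets the answer backwards: by |w| alone, the 2.0 weight looks most important, but its zero gradient says the loss doesn’t care about it.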
Pruning in the LLM Era
Everything above assumes you can retrain after pruning. For a 70-billion-parameter language model that cost millions of dollars to train, this assumption breaks. You need methods that prune accurately in one shot, without any retraining.
SparseGPT (Frantar & Alistarh, 2023) achieved exactly this. The key idea: process the weight matrix column by column, and for each pruned weight, optimally adjust the remaining weights in that row to compensate for the removed connection. The adjustment uses an approximate inverse Hessian computed from a small calibration dataset (just 128 samples). The result: 50% unstructured sparsity on GPT-scale models with negligible perplexity increase, no retraining needed, running in under an hour even for 175B-parameter models.
Wanda (Sun et al., 2024) simplified this further with a beautiful insight: instead of computing Hessian inverses, just combine weight magnitude with activation magnitude. A weight that processes small activations contributes less to the output, even if the weight itself is large. The pruning score becomes:

$$S_{ij} = |W_{ij}| \cdot \lVert X_j \rVert_2$$

where $\lVert X_j \rVert_2$ is the norm of the j-th input feature’s activations across the calibration set. Prune the weights with the lowest scores. That’s it — no Hessian, no iterative weight updates, just element-wise multiplication and a sort. Wanda matches SparseGPT’s accuracy on LLaMA models at a fraction of the computational cost.
```python
def wanda_prune(W, X_calib, sparsity):
    """Wanda-style pruning: |weight| × ‖activation‖₂.

    W:       weight matrix (in_features × out_features)
    X_calib: calibration activations (n_samples × in_features)
    """
    # Activation norms: how large is each input feature across the calibration set
    activation_norms = np.linalg.norm(X_calib, axis=0)  # shape: (in_features,)
    # Pruning score: weight magnitude × activation norm
    scores = np.abs(W) * activation_norms[:, None]  # broadcast to W's shape
    # Prune per output neuron (each column of W) to keep sparsity balanced
    W_pruned = W.copy()
    for col in range(W.shape[1]):
        col_scores = scores[:, col]
        threshold = np.percentile(col_scores, sparsity * 100)
        W_pruned[:, col] = np.where(col_scores >= threshold, W[:, col], 0.0)
    return W_pruned
```
The elegance of Wanda is striking: two norms multiplied together, a percentile threshold, and you’re done. It runs in minutes on models with tens of billions of parameters. The activation norm acts as a data-aware importance signal — features that are consistently small across real inputs simply matter less, regardless of the weight magnitude connecting them.
For hardware acceleration, NVIDIA’s 2:4 structured sparsity has become the practical standard. In every group of 4 contiguous weights, exactly 2 are zeroed — a regular pattern that Ampere and Hopper tensor cores exploit for a guaranteed 2× throughput improvement. Both SparseGPT and Wanda can be adapted to produce 2:4 patterns by selecting the 2 lowest-scoring weights within each group of 4, giving real speedup on real hardware.
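Producing the 2:4 pattern from any per-weight score is mechanical. A minimal sketch using magnitude as the score (a Wanda or SparseGPT score array would drop in the same way), assuming the number of weights per row is a multiple of 4:

```python
import numpy as np

def to_2_4_sparsity(W):
    """Enforce 2:4 sparsity: in every group of 4 consecutive weights,
    zero the 2 smallest-magnitude entries."""
    W_out = W.copy()
    groups = W_out.reshape(-1, 4)                         # view: groups of 4
    smallest = np.argsort(np.abs(groups), axis=1)[:, :2]  # 2 lowest per group
    np.put_along_axis(groups, smallest, 0.0, axis=1)      # writes through to W_out
    return W_out

W = np.array([[1.0, -3.0, 0.5, 2.0],
              [4.0, 0.1, -0.2, 5.0]])
print(to_2_4_sparsity(W))
# [[ 0. -3.  0.  2.]
#  [ 4.  0.  0.  5.]]
```

Every group of 4 ends up exactly 50% sparse, which is the regularity the tensor cores need.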
The Compression Trinity
Pruning completes a trilogy. The quantization post showed how to reduce the precision of each weight (32-bit → 4-bit). The knowledge distillation post showed how to transfer knowledge into a smaller architecture. This post showed how to reduce the number of weights while keeping the architecture. Three orthogonal axes of compression, each attacking a different dimension of model size.
The power of these three techniques lies in their composability. They stack:
| Technique | What It Reduces | Typical Ratio | Retraining? | Hardware Needs |
|---|---|---|---|---|
| Quantization | Bits per weight | 4–8× | Optional (PTQ works) | INT4/INT8 support |
| Pruning | Number of weights | 2–5× | Optional (Wanda works) | Sparse kernels or 2:4 |
| Distillation | Architecture size | 2–10× | Yes (full training) | None (dense student) |
| Prune + Quantize | Weights × bits | 8–40× | Minimal | Sparse + INT4 |
A practical deployment recipe: start with a trained FP16 model, apply Wanda to prune 50% of weights (2× fewer non-zero values), then quantize the surviving weights to INT4 (4× fewer bits than FP16). The combined compression is roughly 8× — a 70B-parameter model that originally needed 140GB of memory in FP16 now fits in under 20GB, within reach of a single high-end consumer GPU. Add distillation into a smaller architecture, and you can push even further.
The order matters. Prune before you quantize — quantized weights have reduced precision, making magnitude comparisons less reliable. Distill last, if at all, since it requires the most compute. And always evaluate at each stage: compression is a series of approximations, and errors compound.
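As a sanity check on the recipe, here is a toy prune-then-quantize pipeline: magnitude pruning first, then simple symmetric uniform quantization of the survivors. It is a stand-in for a real INT4 method, not a production scheme:

```python
import numpy as np

def prune_then_quantize(W, sparsity=0.5, bits=4):
    """Prune first (thresholds computed on full-precision weights),
    then quantize the survivors to a symmetric integer grid."""
    # Step 1: magnitude pruning
    threshold = np.percentile(np.abs(W), sparsity * 100)
    W_sparse = np.where(np.abs(W) >= threshold, W, 0.0)
    # Step 2: symmetric uniform quantization of what's left
    qmax = 2 ** (bits - 1) - 1              # 7 for INT4
    scale = np.abs(W_sparse).max() / qmax
    W_q = np.round(W_sparse / scale).astype(np.int8)
    return W_q, scale                        # reconstruct with W_q * scale

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64)).astype(np.float32)
W_q, scale = prune_then_quantize(W)
print(f"non-zeros: {np.count_nonzero(W_q)} / {W_q.size}, scale: {scale:.4f}")
```

Doing it in this order keeps the magnitude comparisons in step 1 at full precision, which is exactly the point of “prune before you quantize.”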
Try It: Pruning Explorer
A trained neural network (2 → 16 → 12 → 3) on a classification task. Drag the sparsity slider to remove weights — lines disappear as the smallest weights are pruned. Watch how accuracy holds up far longer than you’d expect. Toggle between unstructured (individual weights vanish) and structured (entire neurons vanish) pruning.
Try It: Lottery Ticket Finder
Watch the Iterative Magnitude Pruning process unfold. The network starts fully connected, trains, gets pruned, then the surviving weights are rewound to their initial values and retrained. Compare the lottery ticket’s accuracy against a randomly reinitialized sparse network.
Connections to the Series
Pruning threads through the entire elementary pipeline — from how we initialize networks to how we deploy them:
- Quantization from Scratch — Pruning reduces the number of weights; quantization reduces the bits per weight. They compound: prune 50% + quantize to INT4 = ~16× compression.
- Knowledge Distillation from Scratch — Distillation transfers knowledge into a smaller dense architecture. Combined with pruning: distill from a dense teacher to a pruned student for the best of both approaches.
- Regularization from Scratch — L1 regularization pushes weights toward zero during training — it’s implicit pruning. The sparsity-promoting effect of L1 is the same mechanism that makes pruning work: most weights want to be zero.
- Weight Initialization from Scratch — The Lottery Ticket Hypothesis is fundamentally about initialization. The winning ticket is a specific set of initial weights plus a sparse mask — proof that starting points determine which subnetworks can learn.
- Neural Scaling Laws — Pruned models follow modified scaling laws. The same power-law relationships hold, but with different exponents — sparse models are more parameter-efficient, getting more accuracy per surviving weight.
- LoRA from Scratch — Pruning and LoRA complement each other: prune the base model for efficiency, then add LoRA adapters for task-specific fine-tuning without touching the pruned weights.
- Sparse Autoencoders from Scratch — SAEs use the same L1 sparsity penalty to learn sparse feature representations. Both pruning and SAEs discover that most parameters (or features) are unnecessary — the signal lives in a small subset.
- Model Merging from Scratch — Pruning before merging reduces interference between task vectors. DARE (Drop And REscale) is literally pruning applied to task vectors: drop 90% of the delta weights, rescale the survivors, and the merge improves.
- Flash Attention from Scratch — Structured pruning of attention heads reduces both KV cache size and attention compute. Combined with FlashAttention’s memory-efficient kernels, this makes long-context inference practical.
- Mixture of Experts from Scratch — MoE is dynamic structured pruning: each token activates only 2 of 8 experts, meaning 75% of parameters are “pruned” for every forward pass. The difference is that MoE decides which parameters per-input, while static pruning decides once.
Conclusion
Michelangelo reportedly said that sculpting was simple: you just chip away everything that doesn’t look like David. Neural network pruning follows the same philosophy. The trained model is already hiding inside the overparameterized network — pruning reveals it by removing everything that doesn’t matter.
We started with the simplest observation — most weights are near zero — and built up to the Lottery Ticket Hypothesis (sparse subnetworks that match dense accuracy), gradual pruning schedules (integrate pruning into training), and modern one-shot methods like Wanda that prune 70-billion-parameter models in minutes. Along the way, we completed the compression trinity: quantization reduces bit-width, pruning removes weights, distillation shrinks the architecture. Together, they make deployment of massive models practical.
The next time you see a production LLM running on a phone or a consumer GPU, remember: it got there by throwing away most of itself. And it works better for it.
References & Further Reading
- Frankle & Carbin (2019) — The Lottery Ticket Hypothesis — The paper that revealed sparse trainable subnetworks within dense networks.
- Zhu & Gupta (2017) — To Prune, or Not to Prune — Introduced gradual magnitude pruning with cubic sparsity schedules.
- Frantar & Alistarh (2023) — SparseGPT — One-shot pruning for GPT-scale models using approximate second-order information.
- Sun et al. (2024) — Wanda — A simple pruning method combining weight and activation magnitudes.
- Sanh et al. (2020) — Movement Pruning — Pruning based on weight movement direction during fine-tuning.
- LeCun, Denker & Solla (1989) — Optimal Brain Damage — The original second-order pruning method using Hessian diagonal.
- Hassibi & Stork (1993) — Optimal Brain Surgeon — Full Hessian-based pruning with weight compensation.
- Frankle et al. (2020) — Linear Mode Connectivity and the Lottery Ticket Hypothesis — Late rewinding and the connection between lottery tickets and loss landscape geometry.
- Han et al. (2015) — Learning both Weights and Connections — Early magnitude pruning + retraining pipeline achieving 9–13× compression on AlexNet/VGG.
- NVIDIA — Accelerating Inference with Structured Sparsity — How Ampere tensor cores exploit 2:4 sparsity patterns for 2× throughput.