Knowledge Distillation from Scratch: Teaching Small Models Everything a Big Model Knows
The Compression Problem
You’ve built a 70-billion-parameter transformer that aces every benchmark. One problem: it costs $0.50 per API call, needs 140 GB of GPU memory, and takes 2 seconds to generate a single token. Your users want it on their phone.
We’ve tackled model compression twice already in this series. LoRA showed us how to freeze most weights and train a tiny low-rank adapter, reducing trainable parameters by 100–1000×. Quantization taught us to shrink numbers from 32-bit floats to 4-bit integers, cutting memory by 4–8×. But there’s a third pillar of compression that we haven’t explored — and it’s the oldest, most intuitive, and arguably most elegant of all three.
Knowledge distillation: train a small model to think like a big one.
The idea goes back to Geoffrey Hinton’s 2015 paper “Distilling the Knowledge in a Neural Network” and it rests on one deceptively simple insight. When a teacher model classifies an image of the digit “2”, it doesn’t just output “this is a 2.” Its full output distribution says something much richer: “It’s 73% a 2, but also 15% a 3, 8% a 7, and 3% a 5.” That information — the probabilities on the “wrong” answers — encodes deep structural knowledge about how digits relate to each other. A 2 looks more like a 3 than it looks like a 0. A 7 and a 2 share that same top-left stroke. This is dark knowledge, and it’s invisible in the hard label “2”.
Distillation is how DistilBERT retained 97% of BERT’s quality at 40% smaller. It’s how TinyBERT compressed BERT by 7.5× while staying within 4% of the original. And it’s how DeepSeek distilled chain-of-thought reasoning from a 671-billion-parameter MoE model into 7B–32B dense models you can run on a single GPU.
Let’s build it from scratch.
Hard Labels Throw Away Knowledge
Standard supervised learning works like this: you give the model an input, it produces a prediction, and you compare that prediction to a hard label — a one-hot vector that’s 1 in one position and 0 everywhere else. The cross-entropy loss pushes the model’s predicted probability toward 1.0 on the correct class and 0.0 on everything else.
Here’s the problem. Hard labels are binary verdicts: “this is a 2” and nothing more. But a well-trained teacher knows so much more than that. Let’s look at what a teacher model actually outputs when it sees a handwritten digit:
import numpy as np

def softmax(logits, T=1.0):
    """Softmax with temperature scaling."""
    scaled = logits / T
    shifted = scaled - np.max(scaled)  # numerical stability
    exps = np.exp(shifted)
    return exps / np.sum(exps)
# Teacher's raw logits for an image of handwritten "2"
# These are the values BEFORE softmax — the teacher's "raw opinion"
teacher_logits = np.array([
    -1.5,  # digit 0 — round, no sharp angles like a 2
     0.2,  # digit 1 — has a vertical stroke, some similarity
     5.0,  # digit 2 — correct class, high logit
     2.1,  # digit 3 — similar curves! dark knowledge here
    -1.0,  # digit 4 — angular, not similar
     0.8,  # digit 5 — top half has some resemblance
    -0.5,  # digit 6 — round bottom, some similarity
     1.6,  # digit 7 — shares top-left stroke! more dark knowledge
    -0.7,  # digit 8 — roundish, not very similar
     0.1,  # digit 9 — top loop vaguely like a 2
])
# Hard label: "this is a 2"
hard_label = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
# Teacher's soft predictions at standard temperature (T=1)
soft_preds = softmax(teacher_logits, T=1.0)
print("Hard label: ", hard_label)
print("Teacher (T=1): ", [f"{p:.4f}" for p in soft_preds])
# Hard label: [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# Teacher (T=1): ['0.0013', '0.0073', '0.8846', '0.0487', '0.0022',
# '0.0133', '0.0036', '0.0295', '0.0030', '0.0066']
See the difference? The hard label is a blunt instrument — “it’s a 2, full stop.” But the teacher’s probability distribution is dripping with information. The 4.9% probability on digit 3 says “a 2 looks a bit like a 3.” The 3.0% on digit 7 says “2 and 7 share a structural feature.” Even the tiny 0.13% on digit 0 carries a signal — “a 2 looks nothing like a 0.”
When you train a student on hard labels, all that structural knowledge evaporates. The student learns “this is a 2” but never discovers that 2’s are related to 3’s and 7’s. It has to rediscover those relationships independently from the raw data.
The knowledge in a trained model isn’t just the correct answers. It’s the wrong answers and their magnitudes — the full probability landscape over all classes. This is Geoffrey Hinton’s “dark knowledge”: the information hiding in the probabilities that aren’t the argmax.
But there’s a catch. At standard temperature (T=1), the teacher’s distribution is still quite peaked. Digit 2 gets 88.5%, and the dark knowledge is spread across the remaining 11.5%. You can see hints of structure — digit 3 gets 4.9%, digit 7 gets 3.0% — but the differences between non-dominant classes are subtle and hard to distinguish from noise. To make the dark knowledge vivid, we need to soften the teacher’s distribution.
Temperature Reveals the Dark Knowledge
In our softmax & temperature post, we built the temperature-scaled softmax and showed how higher temperature makes distributions more uniform while lower temperature makes them spikier. The same formula now becomes the key to distillation.
Recall: softmax(z/T) divides every logit by temperature T before applying the exponential. When T = 1, you get the standard softmax. When T > 1, the distribution flattens. The dark knowledge hiding in those near-zero probabilities gets amplified into visible signal.
# Same teacher logits from above — let's see what happens
# as we raise the temperature
for T in [1, 2, 5, 10, 20]:
    probs = softmax(teacher_logits, T=T)
    # Show only the interesting digits: 2 (correct), 3, 7 (dark knowledge)
    print(f"T={T:2d}: P(2)={probs[2]:.4f} P(3)={probs[3]:.4f} "
          f"P(7)={probs[7]:.4f} P(0)={probs[0]:.4f}")
# T= 1: P(2)=0.8846 P(3)=0.0487 P(7)=0.0295 P(0)=0.0013
# T= 2: P(2)=0.5189 P(3)=0.1217 P(7)=0.0948 P(0)=0.0201
# T= 5: P(2)=0.2231 P(3)=0.1249 P(7)=0.1130 P(0)=0.0608
# T=10: P(2)=0.1524 P(3)=0.1140 P(7)=0.1085 P(0)=0.0796
# T=20: P(2)=0.1240 P(3)=0.1073 P(7)=0.1046 P(0)=0.0896
Watch what happens as temperature rises:
- T = 1: Digit 2 dominates with 88.5%. You can see hints of structure (4.9% on 3, 3.0% on 7) but the rest is noise.
- T = 2: Digit 2 drops to 52%. Digits 3 (12%) and 7 (9.5%) are clearly distinct from the rest.
- T = 5: Digit 2 is down to 22%. Now you can read the teacher’s full ranking: 2 > 3 > 7 > 5 > 1 ≈ 9. The inter-class structure is vivid.
- T = 10: Nearly flat, but the ranking is preserved. Digit 2 is still highest.
- T = 20: Approaching uniform — too much entropy, too little signal.
The sweet spot is typically T = 3 to T = 10. High enough to reveal the dark knowledge, low enough to preserve the meaningful signal. In practice, T = 4 to T = 6 works well for most classification tasks.
There’s a beautiful connection to physics here. The temperature-scaled softmax is literally the Boltzmann distribution from statistical mechanics: P(state) ∝ exp(-E/kT). At low temperature, a physical system freezes into its lowest-energy state. At high temperature, it explores all states equally. Hinton chose the name “temperature” deliberately — we’re “heating up” the teacher’s output to free the knowledge trapped in the frozen peaks.
The Distillation Loss from First Principles
Now we can define the complete distillation objective. We’re training a student model with two teachers simultaneously: the big model (which provides soft targets) and the ground truth labels (which keep the student honest).
The loss has two components:
- Soft-target loss: KL divergence between the teacher’s and student’s probability distributions, both computed at temperature T
- Hard-label loss: Standard cross-entropy between the student’s predictions (at T = 1) and the true one-hot labels
The combined loss:

L_KD = α · T² · KL(p_teacher^T ∥ p_student^T) + (1 − α) · CE(y, p_student)

where α controls the balance between the teacher’s guidance and ground-truth supervision. Hinton found that values close to 1.0 (heavily weighting the teacher) work best — the dark knowledge is a richer training signal than a one-hot label.
Notice the T² factor multiplying the KL term. We’ll derive exactly where it comes from in the next section. For now, here’s the full implementation:
def kl_divergence(p, q):
    """KL(p || q) — how much information is lost using q to approximate p."""
    # Avoid log(0) with a small epsilon
    eps = 1e-8
    return np.sum(p * np.log((p + eps) / (q + eps)))

def cross_entropy(targets_onehot, predictions):
    """Standard cross-entropy loss."""
    eps = 1e-8
    return -np.sum(targets_onehot * np.log(predictions + eps))

def distillation_loss(teacher_logits, student_logits, true_labels, T=4.0, alpha=0.9):
    """
    Complete knowledge distillation loss.

    teacher_logits: raw logits from the (frozen) teacher — (num_classes,)
    student_logits: raw logits from the student — (num_classes,)
    true_labels:    one-hot ground truth — (num_classes,)
    T:              temperature for softening distributions
    alpha:          weight on the soft-target loss (0 to 1)
    """
    # Soft targets: both teacher and student at temperature T
    p_teacher_soft = softmax(teacher_logits, T=T)
    p_student_soft = softmax(student_logits, T=T)
    # Hard predictions: student at T=1
    p_student_hard = softmax(student_logits, T=1.0)
    # Component 1: soft-target KL divergence, scaled by T²
    soft_loss = T * T * kl_divergence(p_teacher_soft, p_student_soft)
    # Component 2: standard cross-entropy with true labels
    hard_loss = cross_entropy(true_labels, p_student_hard)
    # Weighted combination
    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss
    return total_loss
# Example: teacher and student see the same digit "2"
# Student gets the right class but has wrong structural assumptions
# (thinks 5 is similar to 2, doesn't recognize 3's or 7's similarity)
student_logits = np.array([0.5, -0.3, 3.5, 0.2, 0.8, 2.0, 0.1, -0.5, 0.3, 1.2])
true_label = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
loss = distillation_loss(teacher_logits, student_logits, true_label, T=4.0, alpha=0.9)
print(f"Distillation loss: {loss:.4f}")
# Distillation loss: 1.0229
# For comparison, standard CE (no teacher):
hard_only = cross_entropy(true_label, softmax(student_logits, T=1.0))
print(f"Hard-label CE only: {hard_only:.4f}")
# Hard-label CE only: 0.4650
The distillation loss is higher because the student is being judged on a harder criterion: it’s not enough to get the right answer; the student has to match the teacher’s entire probability landscape — the right answer, the near misses, and the clear non-matches, all in the correct proportions.
Why both components? The soft targets provide rich inter-class information, but they can be slightly miscalibrated (the teacher isn’t perfect). The hard labels keep the student grounded in the actual truth. In practice, the soft-target term does most of the heavy lifting, which is why α ≈ 0.9 is typical.
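To make the balance concrete, here's an α sweep on the same example logits (a self-contained sketch; it re-defines the helpers from above so it runs on its own):

```python
import numpy as np

def softmax(logits, T=1.0):
    shifted = logits / T - np.max(logits / T)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

def kl_divergence(p, q, eps=1e-8):
    return np.sum(p * np.log((p + eps) / (q + eps)))

teacher_logits = np.array([-1.5, 0.2, 5.0, 2.1, -1.0, 0.8, -0.5, 1.6, -0.7, 0.1])
student_logits = np.array([0.5, -0.3, 3.5, 0.2, 0.8, 2.0, 0.1, -0.5, 0.3, 1.2])
true_label = np.eye(10)[2]  # "this is a 2"
T = 4.0

# The two components, computed once; only the mixing weight changes
soft = T * T * kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T))
hard = -np.sum(true_label * np.log(softmax(student_logits) + 1e-8))

for alpha in [0.0, 0.5, 0.9, 1.0]:
    total = alpha * soft + (1 - alpha) * hard
    print(f"alpha={alpha:.1f}: soft={alpha * soft:.4f}  hard={(1 - alpha) * hard:.4f}  total={total:.4f}")
```

At α = 0 the teacher is ignored entirely; at α = 1 the ground truth is. At α = 0.9 the soft term dominates the total, which is exactly the intent.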
The T² Factor: Why Gradients Need Rescaling
That T² multiplier on the soft-target loss isn’t a fudge factor — it falls directly out of the gradient mathematics. Most tutorials mention it in passing or skip it entirely. Let’s derive it properly.
When we compute the gradient of the KL divergence with respect to the student’s logits z_j, we need to differentiate through the temperature-scaled softmax. The chain rule gives us:

∂KL/∂z_j = (1/T) · (p_student_j − p_teacher_j)

Two things shrink at high temperature. First, the explicit 1/T factor from differentiating z/T. Second, the difference (p_student − p_teacher) itself gets smaller because both distributions are flatter. The net effect: KL gradients are proportional to 1/T².
Without the T² correction, raising the temperature would simultaneously reveal the dark knowledge and kill the gradient signal. The T² factor restores the gradient magnitude, so the student learns at the same rate regardless of temperature.
Let’s verify this numerically:
def kl_gradient_wrt_student(teacher_logits, student_logits, T):
    """
    Gradient of KL(p_teacher^T || p_student^T) w.r.t. student_logits.
    Uses the analytical result: dKL/dz_j = (1/T)(p_student_j - p_teacher_j)
    """
    p_teacher = softmax(teacher_logits, T=T)
    p_student = softmax(student_logits, T=T)
    return (1.0 / T) * (p_student - p_teacher)

# Same logits as before
student_logits = np.array([0.5, -0.3, 3.5, 0.2, 0.8, 2.0, 0.1, -0.5, 0.3, 1.2])

# Measure gradient norms at different temperatures
print("T | grad norm (KL) | grad norm (T²·KL) | ratio to T=1")
print("----|-----------------|--------------------|--------------")
norm_at_T1 = np.linalg.norm(kl_gradient_wrt_student(teacher_logits, student_logits, T=1))
for T in [1, 2, 4, 8, 16]:
    grad = kl_gradient_wrt_student(teacher_logits, student_logits, T=T)
    raw_norm = np.linalg.norm(grad)
    scaled_norm = np.linalg.norm(T * T * grad)  # multiply by T²
    print(f"T={T:2d} | {raw_norm:.6f} | {scaled_norm:.6f} | "
          f"{raw_norm / norm_at_T1:.4f}")
# T | grad norm (KL) | grad norm (T²·KL) | ratio to T=1
# ----|-----------------|--------------------|--------------
# T= 1 | 0.298980 | 0.298980 | 1.0000
# T= 2 | 0.127134 | 0.508538 | 0.4252
# T= 4 | 0.031069 | 0.497098 | 0.1039
# T= 8 | 0.007476 | 0.478438 | 0.0250
# T=16 | 0.001837 | 0.470308 | 0.0061
Look at the “grad norm (KL)” column: without the T² factor, the gradient norm drops by over 160× from T=1 to T=16. But the “grad norm (T²·KL)” column stays within a factor of 2 across temperatures. The T² factor restores gradient magnitudes to the same order of magnitude regardless of temperature.
Without T², raising the temperature to T = 16 would reduce your soft-target gradients to 0.6% of their T = 1 magnitude. The student would barely learn from the teacher. The T² correction is not optional — it’s the difference between distillation that works and distillation that silently fails.
The High-Temperature Limit: When KD Becomes Logit Matching
There’s a beautiful mathematical result hiding in the distillation formula. When the temperature gets very large, the KL divergence between softened distributions converges to something remarkably simple.
Here’s the key approximation. When T is large, z/T is small, and exp(z/T) ≈ 1 + z/T (the first-order Taylor expansion). Plug this into the softmax:

softmax(z/T)_i ≈ (1 + z_i/T) / (C + Σ_k z_k/T)

where C is the number of classes. In this regime, the softmax becomes approximately linear in the logits. And the KL divergence between two approximately-linear distributions (assuming zero-mean logits) reduces to:

KL(p_teacher^T ∥ p_student^T) ≈ (1 / (2C·T²)) · Σ_j (z_j^student − z_j^teacher)²

That’s mean squared error on the logits. At very high temperature, distillation is just “make the student’s raw logits match the teacher’s raw logits.” No softmax, no probabilities, no KL divergence — just MSE on the numbers before the final activation.
This isn’t just a theoretical curiosity. Some practitioners skip the temperature dance entirely and train directly with a logit-matching objective: L = MSE(z_teacher, z_student). It’s faster (no softmax computation), simpler (no T² scaling to worry about), and in many settings nearly as effective as proper KD.
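You can check the convergence numerically. Below, the gradient of the T²-scaled KL term is compared against the direction plain logit MSE would push (mean-centered logits, per the zero-mean assumption); as T grows, the two directions align (a sketch on the same example logits as above):

```python
import numpy as np

def softmax(z, T=1.0):
    e = np.exp(z / T - np.max(z / T))
    return e / e.sum()

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

teacher_logits = np.array([-1.5, 0.2, 5.0, 2.1, -1.0, 0.8, -0.5, 1.6, -0.7, 0.1])
student_logits = np.array([0.5, -0.3, 3.5, 0.2, 0.8, 2.0, 0.1, -0.5, 0.3, 1.2])

# Direction logit-MSE pushes: the mean-centered logit difference
mse_dir = ((student_logits - student_logits.mean())
           - (teacher_logits - teacher_logits.mean()))

cosines = {}
for T in [1, 4, 16, 64]:
    # Gradient of the T²-scaled KL term: T²·(1/T)·(p_s − p_t) = T·(p_s − p_t)
    kd_grad = T * (softmax(student_logits, T) - softmax(teacher_logits, T))
    cosines[T] = cosine(kd_grad, mse_dir)
    print(f"T={T:3d}: cosine(KD gradient, MSE direction) = {cosines[T]:.4f}")
```

At T = 1 the two objectives pull in noticeably different directions; by T = 64 they are essentially parallel.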
There’s a deeper connection too. Yuan et al. (2020) showed that knowledge distillation with soft targets is mathematically equivalent to a learned, non-uniform form of label smoothing. Standard label smoothing mixes the hard label with a uniform distribution. KD replaces that uniform distribution with the teacher’s distribution — a data-dependent smoothing pattern that carries genuine structural information about inter-class relationships.
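To see the correspondence concretely, compare the two kinds of training target side by side (a toy illustration, not Yuan et al.'s exact formulation; ε here plays the role of the soft-target weight):

```python
import numpy as np

def softmax(z, T=1.0):
    e = np.exp(z / T - np.max(z / T))
    return e / e.sum()

num_classes, correct = 10, 2
hard = np.eye(num_classes)[correct]
teacher_logits = np.array([-1.5, 0.2, 5.0, 2.1, -1.0, 0.8, -0.5, 1.6, -0.7, 0.1])
eps = 0.1  # smoothing strength

# Standard label smoothing: mix the hard label with a UNIFORM distribution
ls_target = (1 - eps) * hard + eps * np.full(num_classes, 1.0 / num_classes)

# KD: mix the hard label with the TEACHER's softened distribution instead
kd_target = (1 - eps) * hard + eps * softmax(teacher_logits, T=4.0)

print("label smoothing:", np.round(ls_target, 3))
print("KD soft target: ", np.round(kd_target, 3))
```

Both targets are valid distributions, but the uniform smoother treats digit 3 and digit 0 identically, while the teacher-derived target keeps the inter-class structure (more mass on 3 and 7 than on 0).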
Training a Distilled Student
Theory is beautiful, but does it actually work? Let’s build a teacher and student from scratch and run a controlled experiment. We’ll use a toy 3-class classification problem so we can see everything — every logit, every gradient, every probability.
import numpy as np
# ─── Reproducible setup ───
np.random.seed(42)
# ─── Generate a toy 3-class dataset ───
# Three overlapping clusters — deliberately hard so the teacher's
# soft targets carry meaningful inter-class information
def make_data(n_per_class=200):
    X, y = [], []
    centers = [(-1.5, -1.0), (1.5, -1.0), (0.0, 1.5)]
    for cls, (cx, cy) in enumerate(centers):
        pts = np.random.randn(n_per_class, 2) * 0.9 + [cx, cy]
        X.append(pts)
        y.append(np.full(n_per_class, cls))
    return np.vstack(X), np.concatenate(y).astype(int)
X_train, y_train = make_data(200) # 600 samples
X_test, y_test = make_data(100) # 300 samples
def one_hot(labels, num_classes=3):
    oh = np.zeros((len(labels), num_classes))
    oh[np.arange(len(labels)), labels] = 1.0
    return oh
y_train_oh = one_hot(y_train)
y_test_oh = one_hot(y_test)
# ─── Simple MLP helper ───
def relu(x):
    return np.maximum(0, x)

def softmax_batch(logits, T=1.0):
    scaled = logits / T
    shifted = scaled - np.max(scaled, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)
class MLP:
    """Minimal 2-layer MLP: input → hidden (ReLU) → output (logits)."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        scale1 = np.sqrt(2.0 / input_dim)
        scale2 = np.sqrt(2.0 / hidden_dim)
        self.W1 = np.random.randn(input_dim, hidden_dim) * scale1
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(hidden_dim, output_dim) * scale2
        self.b2 = np.zeros(output_dim)

    def forward(self, X):
        self.h_pre = X @ self.W1 + self.b1        # (batch, hidden)
        self.h = relu(self.h_pre)                 # (batch, hidden)
        self.logits = self.h @ self.W2 + self.b2  # (batch, output)
        return self.logits

    def backward(self, X, dlogits, lr=0.01):
        batch = X.shape[0]
        # Output layer gradients
        dW2 = self.h.T @ dlogits / batch
        db2 = np.mean(dlogits, axis=0)
        # Hidden layer gradients
        dh = dlogits @ self.W2.T
        dh_pre = dh * (self.h_pre > 0).astype(float)  # ReLU backward
        dW1 = X.T @ dh_pre / batch
        db1 = np.mean(dh_pre, axis=0)
        # SGD update
        self.W2 -= lr * dW2
        self.b2 -= lr * db2
        self.W1 -= lr * dW1
        self.b1 -= lr * db1

def accuracy(model, X, y):
    logits = model.forward(X)
    preds = np.argmax(logits, axis=1)
    return np.mean(preds == y)
# ─── Train the teacher (big model: 128 hidden units) ───
teacher = MLP(2, 128, 3)
for epoch in range(300):
    logits = teacher.forward(X_train)
    probs = softmax_batch(logits, T=1.0)
    dlogits = probs - y_train_oh  # per-sample CE gradient
    teacher.backward(X_train, dlogits, lr=0.1)
teacher_acc = accuracy(teacher, X_test, y_test)
print(f"Teacher accuracy: {teacher_acc:.4f}")
# Teacher accuracy: 0.8833 (results vary by random seed)
# ─── Approach 1: Student trained from scratch (hard labels only) ───
student_scratch = MLP(2, 32, 3)
scratch_history = []
for epoch in range(300):
    logits = student_scratch.forward(X_train)
    probs = softmax_batch(logits, T=1.0)
    dlogits = probs - y_train_oh  # per-sample CE gradient
    student_scratch.backward(X_train, dlogits, lr=0.1)
    if epoch % 10 == 0:
        scratch_history.append(accuracy(student_scratch, X_test, y_test))
print(f"Student (scratch): {accuracy(student_scratch, X_test, y_test):.4f}")
# ─── Approach 2: Student trained on teacher's hard outputs (SFT-style) ───
# Use teacher's argmax as labels — no soft targets
teacher_hard = np.argmax(teacher.forward(X_train), axis=1)
teacher_hard_oh = one_hot(teacher_hard)
student_hard = MLP(2, 32, 3)
hard_history = []
for epoch in range(300):
    logits = student_hard.forward(X_train)
    probs = softmax_batch(logits, T=1.0)
    dlogits = probs - teacher_hard_oh  # per-sample CE gradient
    student_hard.backward(X_train, dlogits, lr=0.1)
    if epoch % 10 == 0:
        hard_history.append(accuracy(student_hard, X_test, y_test))
print(f"Student (hard distill): {accuracy(student_hard, X_test, y_test):.4f}")
# ─── Approach 3: Full knowledge distillation (soft targets, T=4, α=0.9) ───
T = 4.0
alpha = 0.9
teacher_logits_all = teacher.forward(X_train) # (600, 3)
student_kd = MLP(2, 32, 3)
kd_history = []
for epoch in range(300):
    student_logits = student_kd.forward(X_train)
    # Soft-target gradient: d/dz [T²·KL] = T²·(1/T)·(p_s − p_t) = T·(p_s − p_t)
    p_teacher_soft = softmax_batch(teacher_logits_all, T=T)
    p_student_soft = softmax_batch(student_logits, T=T)
    soft_grad = T * (p_student_soft - p_teacher_soft)
    # Hard-label gradient: d/dz CE(y, p_student)
    p_student_hard = softmax_batch(student_logits, T=1.0)
    hard_grad = p_student_hard - y_train_oh
    # Combined gradient (per-sample, same format as approaches 1 and 2)
    dlogits = alpha * soft_grad + (1 - alpha) * hard_grad
    student_kd.backward(X_train, dlogits, lr=0.1)
    if epoch % 10 == 0:
        kd_history.append(accuracy(student_kd, X_test, y_test))
print(f"Student (KD, T=4): {accuracy(student_kd, X_test, y_test):.4f}")
# Teacher accuracy: 0.8833 (results vary by random seed)
# Student (scratch): 0.8433
# Student (hard distill): 0.8500
# Student (KD, T=4): 0.8600
The results tell a clear story. The student network has 4× fewer hidden units than the teacher (32 vs 128), so it can never perfectly match the teacher. But how we train it matters:
- Student from scratch (hard labels only): 84.3% — decent, but limited by the information content of one-hot labels
- Hard distillation (teacher’s argmax as labels): 85.0% — slightly better, because the teacher’s labels are sometimes “better” than the noisy ground truth (the teacher corrects mislabeled examples near decision boundaries)
- Full KD (soft targets at T=4, α=0.9): 86.0% — the best student, closing roughly 40% of the gap to the teacher
The KD student gets 1.7 percentage points more accuracy from the same 32-unit network, purely by learning from the teacher’s soft probability distribution instead of hard labels. On a toy 3-class problem, 1.7 points is modest — with only 3 classes, there are just 2 “dark knowledge” probabilities per sample. But on ImageNet with 1000 classes, the dark knowledge is vastly richer (999 “wrong” probabilities per sample), and the gains are proportionally larger. This is exactly why DistilBERT and TinyBERT (with vocabulary sizes in the tens of thousands) see such dramatic improvements from distillation.
Beyond Output Matching: Feature Distillation
Everything we’ve built so far — the temperature scaling, the KD loss, the T² factor — only uses the teacher’s final output. But a neural network learns useful representations at every layer. The teacher’s intermediate features contain structural knowledge about how it processes inputs, not just what it outputs.
Romero et al. (2015) introduced FitNets, a technique that matches the student’s intermediate representations to the teacher’s. The core idea: pick a “hint layer” in the teacher and a “guided layer” in the student, then add a loss that pushes the student’s hidden activations to match the teacher’s.
There’s a dimension mismatch problem: the teacher’s hidden layer might be 128-dimensional while the student’s is 32-dimensional. The solution is a learned regressor — a small linear projection that maps the student’s features into the teacher’s space before computing the loss.
class FitNetRegressor:
    """
    Projects student features into teacher feature space.
    student_dim → teacher_dim via a learned linear map.
    """
    def __init__(self, student_dim, teacher_dim):
        scale = np.sqrt(2.0 / student_dim)
        self.W = np.random.randn(student_dim, teacher_dim) * scale
        self.b = np.zeros(teacher_dim)

    def forward(self, student_features):
        """Project student features into teacher space."""
        return student_features @ self.W + self.b

    def backward(self, student_features, d_projected, lr=0.01):
        """Update regressor weights and return gradient for student."""
        batch = student_features.shape[0]
        dW = student_features.T @ d_projected / batch
        db = np.mean(d_projected, axis=0)
        self.W -= lr * dW
        self.b -= lr * db
        # Gradient flowing back to the student's hidden layer
        return d_projected @ self.W.T

def hint_loss(teacher_features, student_features, regressor):
    """
    FitNets hint loss: MSE between teacher features and
    projected student features.

    teacher_features: hidden activations from teacher — (batch, teacher_dim)
    student_features: hidden activations from student — (batch, student_dim)
    regressor: FitNetRegressor mapping student → teacher space
    """
    projected = regressor.forward(student_features)  # (batch, teacher_dim)
    diff = projected - teacher_features              # (batch, teacher_dim)
    loss = np.mean(diff ** 2)                        # scalar
    # Gradient of MSE w.r.t. projected features
    d_projected = 2.0 * diff / (diff.shape[0] * diff.shape[1])
    return loss, d_projected
# Example: match features from our teacher and student
teacher.forward(X_train[:5])
teacher_feats = teacher.h # (5, 128)
student_kd.forward(X_train[:5])
student_feats = student_kd.h # (5, 32)
reg = FitNetRegressor(32, 128)
hloss, d_proj = hint_loss(teacher_feats, student_feats, reg)
print(f"Hint loss: {hloss:.4f}")
# Hint loss: 2.3451 (high initially — student and teacher features don't match yet)
In practice, feature distillation is used as a pretraining stage: first train the student to match the teacher’s intermediate features (hint training), then fine-tune with the standard KD loss on outputs. This two-stage approach lets the student build good internal representations before worrying about the final predictions.
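Sketching the two-stage recipe end to end on a deliberately tiny toy (everything here is illustrative: the "teacher" is a frozen random feature map, the student and regressor are single matrices, and stage 2 would reuse the output-level KD loss from earlier):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: frozen 8-dim teacher features, 4-dim student features,
# and a 4→8 linear regressor bridging the dimension gap.
X = rng.normal(size=(256, 2))
W_teacher = rng.normal(size=(2, 8))        # frozen teacher hidden layer
teacher_feats = np.tanh(X @ W_teacher)

W_student = rng.normal(size=(2, 4)) * 0.1  # trained student hidden layer
W_reg = rng.normal(size=(4, 8)) * 0.1      # trained regressor

# Stage 1: hint training: minimize MSE(regressor(student_feats), teacher_feats)
lr = 0.05
losses = []
for step in range(500):
    s = np.tanh(X @ W_student)
    diff = s @ W_reg - teacher_feats
    losses.append(np.mean(diff ** 2))
    d_proj = 2 * diff / X.shape[0]           # MSE gradient (mean over batch)
    dW_reg = s.T @ d_proj
    ds = d_proj @ W_reg.T
    dW_student = X.T @ (ds * (1 - s ** 2))   # backprop through tanh
    W_reg -= lr * dW_reg
    W_student -= lr * dW_student

print(f"hint loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
# Stage 2 would now fine-tune the student with the output-level KD loss.
```

The student's internal representation is pulled toward the teacher's before any output-level supervision happens, which is the whole point of hint training.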
There are several variants of feature-level distillation:
- Attention transfer: Instead of matching raw hidden states, match the attention maps. The spatial “where the model looks” pattern is often more transferable than the raw activation values.
- Relation-based distillation (RKD): Don’t match individual samples. Instead, match the geometry of the representation space — the pairwise distances and angles between samples. If the teacher’s representation puts “cat” and “lynx” close together, the student should too.
- Contrastive distillation (CRD): Use contrastive learning to transfer the teacher’s structural knowledge. Representations that are similar in the teacher’s space should be similar in the student’s space, regardless of dimensionality.
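The relation-based idea fits in a few lines: compare normalized pairwise-distance matrices instead of individual feature vectors (a minimal sketch of RKD's distance term; the original paper uses a Huber loss where this uses plain MSE):

```python
import numpy as np

def rkd_distance_loss(feats_a, feats_b):
    """Match the GEOMETRY of two feature spaces: pairwise distances,
    normalized by their mean so different scales/dimensions are comparable."""
    def normalized_pdist(f):
        d = np.linalg.norm(f[:, None, :] - f[None, :, :], axis=-1)
        return d / d[d > 0].mean()
    da, db = normalized_pdist(feats_a), normalized_pdist(feats_b)
    return np.mean((da - db) ** 2)

# Three tight pairs of points: a distinctive geometric structure
teacher_feats = np.array([[0.0, 0.0], [0.1, 0.0],
                          [5.0, 5.0], [5.1, 5.0],
                          [10.0, 0.0], [10.0, 0.2]])

student_same_geometry = teacher_feats * 0.3  # scaled copy: same relations
student_broken_pairs = teacher_feats[[2, 4, 0, 5, 1, 3]] * 0.3  # reshuffled rows

loss_same = rkd_distance_loss(teacher_feats, student_same_geometry)
loss_broken = rkd_distance_loss(teacher_feats, student_broken_pairs)
print(f"same geometry:  {loss_same:.6f}")
print(f"broken pairing: {loss_broken:.6f}")
```

A uniformly scaled copy has zero loss (the mean-normalization cancels the scale), while reassigning features to the wrong samples breaks the pairwise relations and the loss jumps.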
The taxonomy of distillation methods is organized by what gets transferred:
- Response-based: Match the teacher’s output distribution (what we built in sections 2–7)
- Feature-based: Match intermediate hidden representations (FitNets, attention transfer)
- Relation-based: Match the geometric structure of the representation space (RKD, CRD)
Distillation in the Wild: From DistilBERT to DeepSeek-R1
The theory is elegant. But the real proof is in the models that shipped. Here’s how distillation transformed the landscape of practical NLP and LLM deployment.
DistilBERT (2019)
The paper that proved distillation works at scale. Sanh et al. took BERT-base (110M parameters, 12 layers) and distilled it into DistilBERT (66M parameters, 6 layers) — 40% smaller, 60% faster, retaining 97% of BERT’s performance across GLUE benchmarks.
Their recipe used a triple loss: (1) standard KD on softmax outputs, (2) masked language modeling loss (the original BERT objective), and (3) a cosine embedding loss aligning student and teacher hidden states. A clever initialization trick seeded the student by copying every other layer from the teacher, giving it a massive head start.
TinyBERT (2020)
Jiao et al. pushed the compression further — 7.5× smaller, 9.4× faster, retaining 96.8% of BERT. The key was a two-stage distillation process: first, general distillation on a large unlabeled corpus (teaching the student general language understanding), then task-specific distillation on the downstream task data. TinyBERT distilled everything: attention maps, hidden states, embedding layer, and final predictions. Six layers of the teacher supervised six layers of the student, with learned transformations bridging dimension gaps.
Born-Again Networks (2018)
Furlanello et al. discovered something counterintuitive: distill a model into a student of the same architecture and size, and the student often outperforms the teacher. This is “self-distillation,” and it works because the soft targets act as an implicit regularizer — the student learns a smoother decision boundary than the teacher did from hard labels alone. Multiple generations of self-distillation yield diminishing but consistent gains. The teacher teaches itself to be better.
DeepSeek-R1 Distillation (2025)
The most current and perhaps most striking example. DeepSeek-R1 is a 671-billion-parameter Mixture-of-Experts model trained with RLHF to do chain-of-thought reasoning. DeepSeek distilled its reasoning capability into dense models ranging from 1.5B to 70B parameters.
Intriguingly, they didn’t use traditional KD with temperature-scaled logits. Instead, they used SFT-style distillation: generate thousands of reasoning traces from the R1 teacher (complete chain-of-thought outputs), then fine-tune the student on these traces using standard cross-entropy. The student never sees the teacher’s probability distribution — it just learns to imitate the style of reasoning. This is distillation in its broadest sense: transferring the teacher’s behavior (not its logits) into a smaller model.
The 32B distilled model retained most of the 671B teacher’s reasoning ability while running at a fraction of the cost. DeepSeek-R1-Distill-Qwen-7B runs on a single consumer GPU and still shows chain-of-thought reasoning on math problems.
MiniLLM (2023)
When distilling generative language models (not classifiers), the direction of the KL divergence matters. Standard KD minimizes the forward KL: KL(p_teacher ∥ p_student). This encourages the student to be “mean-seeking” — spreading probability mass to cover everything the teacher considers possible, even if that produces incoherent text.
Gu et al. showed that using the reverse KL — KL(p_student ∥ p_teacher) — produces a “mode-seeking” student that generates more focused, coherent text. The student picks one plausible continuation and commits to it, rather than hedging across all possibilities. This matters enormously for text generation, where a student that assigns 20% probability to five different next-words produces gibberish.
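The asymmetry is easy to see on a toy three-token vocabulary (the numbers are invented to make the effect visible):

```python
import numpy as np

def kl(p, q, eps=1e-8):
    """KL(p || q) with an epsilon for numerical safety."""
    return np.sum(p * np.log((p + eps) / (q + eps)))

# Teacher: two good continuations, one bad one
teacher = np.array([0.495, 0.495, 0.01])

# Mode-seeking student: commits to one good continuation
student_mode = np.array([0.90, 0.09, 0.01])
# Mean-seeking student: hedges across everything, including the bad token
student_mean = np.array([0.34, 0.33, 0.33])

for name, s in [("mode-seeking", student_mode), ("mean-seeking", student_mean)]:
    print(f"{name}: forward KL(t||s) = {kl(teacher, s):.3f}   "
          f"reverse KL(s||t) = {kl(s, teacher):.3f}")
```

Forward KL prefers the hedging student (it covers both teacher modes), while reverse KL punishes the 33% wasted on the bad token and prefers the committed student. That preference is exactly what MiniLLM exploits.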
Connections Across the Pipeline
Knowledge distillation doesn’t exist in isolation. It weaves through almost every other technique we’ve built in this series.
Quantization + Distillation: In quantization-aware training (QAT), a full-precision teacher supervises a quantized student. The student’s forward pass simulates quantized arithmetic (rounding weights to int8 or int4), but the teacher provides smooth probability targets that help the student learn despite the noisy, discretized weights. QLoRA combines this with LoRA: a 4-bit quantized base model supervised by a full-precision teacher while training tiny adapters. Two compression techniques composing.
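A minimal picture of how the pieces compose (fake quantization only; a real QAT loop would also need a straight-through estimator for the rounding step, and the single linear layer here is illustrative):

```python
import numpy as np

def fake_quant(w, bits=4):
    """Round weights to a symmetric low-bit grid, then dequantize.
    The forward pass sees quantization noise while w stays in float."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.max(np.abs(w)) / qmax
    return np.round(w / scale) * scale

def softmax(z, T=1.0):
    e = np.exp(z / T - np.max(z / T))
    return e / e.sum()

rng = np.random.default_rng(0)
W = rng.normal(size=(2, 10))  # one linear layer, 10 classes
x = np.array([0.7, -0.3])

logits_teacher = x @ W                      # full-precision teacher forward
logits_student = x @ fake_quant(W, bits=4)  # quantized student forward

# The teacher's softened distribution is the smooth target the noisy,
# discretized student trains against
p_teacher = softmax(logits_teacher, T=4.0)
p_student = softmax(logits_student, T=4.0)
err = np.max(np.abs(logits_teacher - logits_student))
print(f"max logit error from 4-bit weights: {err:.4f}")
```

The rounding error per weight is at most half a quantization step, and the KD loss gives the student a gradient signal that tolerates exactly this kind of noise.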
MoE → Dense: Sparse MoE models like Mixtral 8×7B have enormous total parameter counts but only activate a fraction per token. Distilling an MoE teacher into a dense student gives you a model with fewer parameters, simpler serving infrastructure (no expert routing, no all-to-all communication), and predictable compute per token. You trade total knowledge for deployment simplicity.
RLHF as Distillation: The reward model training step in RLHF is itself a form of distillation. Human preferences (the “teacher”) are compressed into a scalar reward function (the “student”). The reward model doesn’t reproduce the full richness of human judgment — it distills the judgment into a single number. The entire RLHF pipeline is a chain of distillation: human knowledge → reward model → policy model.
Speculative Decoding: In speculative decoding, the draft model is often a distilled version of the target model. A small, fast student generates candidate tokens that the large teacher verifies in parallel. Better distillation means a higher acceptance rate, which means faster generation. Distillation quality directly translates to inference speed.
Softmax Temperature: The temperature parameter in distillation is the exact same operation as in our softmax post and our decoding strategies post. In decoding, high temperature makes generation more creative. In distillation, high temperature reveals the teacher’s hidden knowledge. Same formula, different purpose.
Try It: The Dark Knowledge Explorer
Panel 1: Temperature Reveals Dark Knowledge
A teacher model classifying a digit "2". Watch how raising the temperature softens the distribution and reveals inter-class structure.
Panel 2: Loss Decomposition
Adjust α to balance soft-target (teacher knowledge) vs. hard-label (ground truth) loss. Higher α means more teacher guidance.
Panel 3: Student Training Curves
Three students with the same 32-unit architecture trained with different objectives. Watch the distilled student converge faster and higher.
References & Further Reading
- Hinton, Vinyals & Dean — “Distilling the Knowledge in a Neural Network” (2015) — The foundational paper that introduced temperature-scaled soft targets and the term “dark knowledge”
- Romero et al. — “FitNets: Hints for Thin Deep Nets” (2015) — Feature-level distillation through intermediate layer matching
- Furlanello et al. — “Born Again Neural Networks” (2018) — Self-distillation: same-size students outperforming their teachers
- Sanh et al. — “DistilBERT, a distilled version of BERT” (2019) — 40% smaller, 60% faster, 97% of BERT’s quality
- Jiao et al. — “TinyBERT: Distilling BERT for Natural Language Understanding” (2020) — Multi-layer distillation compressing BERT by 7.5×
- Yuan et al. — “Revisiting Knowledge Distillation via Label Smoothing Regularization” (2020) — Proves that KD is equivalent to learned, non-uniform label smoothing
- Gu et al. — “MiniLLM: Knowledge Distillation of Large Language Models” (2023) — Reverse KL divergence for distilling generative language models
- Gou et al. — “Knowledge Distillation: A Survey” (2021) — Comprehensive taxonomy of response-based, feature-based, and relation-based methods
- DadOps cross-references: Softmax & Temperature, Loss Functions, LoRA, Quantization, Attention, RLHF, Mixture of Experts, Speculative Decoding, Decoding Strategies