Domain Adaptation from Scratch
1. Why Models Break on New Data
Your model hits 97% accuracy on the test set. You deploy it. Accuracy drops to 61%. What happened?
Nothing changed about the model. The data changed. A product classifier trained on studio-lit catalog images meets blurry phone photos. A sentiment model trained on Amazon reviews encounters the sarcasm-heavy shorthand of Twitter. A chest X-ray detector trained at Hospital A fails at Hospital B because the scanners produce subtly different contrast profiles. In every case, the model was never broken — it was just never trained on the kind of data it now encounters.
This is domain shift: the gap between the distribution the model trained on (the source domain) and the distribution it encounters in the real world (the target domain). The source domain has labels; the target domain usually doesn't. You can't just retrain because labeling target data is expensive, slow, or sometimes impossible.
Domain adaptation is the family of techniques that bridges this gap. Instead of demanding new labels, we align source and target distributions in feature space so that what the model learned on the source generalizes to the target. In this post, we'll build the core techniques from first principles: the theory that explains why they work, the distance metrics that measure the gap, and three algorithms that close it (MMD, CORAL, and DANN).
2. Types of Domain Shift
Not all distribution shifts are created equal. Understanding which type you're facing determines which solution (if any) can help.
Covariate shift is the most common and most tractable case: the input distribution P(X) changes between domains, but the labeling rule P(Y|X) stays the same. A medical model trained on young patients gets deployed on elderly patients. The features (age, blood pressure, BMI) have different distributions, but the disease-symptom relationship hasn't changed. The same X still maps to the same Y — you just see different X values at deployment.
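A small simulation makes the definition concrete. This is a sketch under assumed parameters. P(X) is a unit-variance Gaussian whose mean depends on the domain, and the shared labeling rule P(Y=1 | x) = sigmoid(2x) is an arbitrary illustrative choice. Within the same narrow slice of x, the two domains show the same label rate, so the labeling rule has not changed. The input marginals differ, and so, as a side effect, does the overall label rate P(Y).

```python
import numpy as np

rng = np.random.default_rng(0)


def p_y_given_x(x):
    """Shared labeling rule P(Y=1 | x); a sigmoid chosen purely for illustration."""
    return 1.0 / (1.0 + np.exp(-2.0 * x))


def sample_domain(mean, n):
    """Draw x from a domain-specific P(X), then labels from the shared P(Y|X)."""
    x = rng.normal(mean, 1.0, size=n)
    y = (rng.random(n) < p_y_given_x(x)).astype(int)
    return x, y


x_src, y_src = sample_domain(-1.0, 20_000)  # e.g. the "young patients" cohort
x_tgt, y_tgt = sample_domain(+1.0, 20_000)  # e.g. the "elderly patients" cohort

# The marginals P(X) (and, as a side effect, P(Y)) differ between domains...
print(f"mean x : source={x_src.mean():+.2f}  target={x_tgt.mean():+.2f}")
print(f"P(Y=1) : source={y_src.mean():.2f}   target={y_tgt.mean():.2f}")

# ...but inside the same narrow slice of x, the label rate matches.
src_slice = (x_src > 0.0) & (x_src < 0.25)
tgt_slice = (x_tgt > 0.0) & (x_tgt < 0.25)
rate_src = y_src[src_slice].mean()
rate_tgt = y_tgt[tgt_slice].mean()
print(f"P(Y=1 | 0<x<0.25): source={rate_src:.2f}  target={rate_tgt:.2f}")
```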
Label shift (also called prior probability shift) flips the direction: P(Y) changes but P(X|Y) stays the same. A disease becomes seasonal — flu is rare in summer, common in winter — but flu patients still present with the same symptoms regardless of season. The class proportions shifted, not the class-conditional features.
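Because P(X|Y) is shared, label shift can be corrected by reweighting classes with P_T(y)/P_S(y). The sketch below makes three assumptions. The target class priors are known; in practice they must be estimated, which this sketch does not cover. The priors for the flu example are made up. The function names are hypothetical. The weights have two uses: reweighting source examples, and adjusting a source-trained classifier's posterior via Bayes' rule.

```python
import numpy as np

# Hypothetical class priors: class 0 = "no flu", class 1 = "flu".
p_source = np.array([0.5, 0.5])   # balanced training set (e.g. collected in winter)
p_target = np.array([0.9, 0.1])   # summer deployment: flu is rare

# Under label shift, P(X|Y) is shared, so each class is simply re-weighted.
class_weights = p_target / p_source   # [1.8, 0.2]


def reweight_examples(labels):
    """Per-example importance weights for training/evaluating on source data."""
    return class_weights[labels]


def adjust_posterior(p_source_posterior):
    """Correct a source-trained posterior P_S(y|x) to the target priors.

    With a shared P(X|Y), Bayes' rule gives P_T(y|x) proportional to
    P_S(y|x) * P_T(y) / P_S(y); normalizing over classes restores a distribution.
    """
    unnormalized = p_source_posterior * class_weights
    return unnormalized / unnormalized.sum(axis=-1, keepdims=True)


# A source model that says "60% flu" should say much less in summer.
print(adjust_posterior(np.array([0.4, 0.6])))   # flu probability drops to 1/7
```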
Concept drift is the hardest case: P(Y|X) itself changes. What counted as "spam" in 2010 no longer matches what's spam today. The input features might look similar, but the labeling function has moved. This usually requires new labels and retraining — no amount of feature alignment can fix a shifted decision boundary.
| Shift Type | What Changes | What Stays the Same | Adaptable? |
|---|---|---|---|
| Covariate | P(X) | P(Y\|X) | Yes: align features |
| Label | P(Y) | P(X\|Y) | Yes: reweight classes |
| Concept | P(Y\|X) | Possibly P(X) | Hard: needs new labels |
| Full | P(X) and P(Y\|X) | Nothing | Very hard |
Domain adaptation primarily targets covariate shift. The rest of this post focuses on that setting: same labeling rule, different input distributions.
3. The Theory — Ben-David's Bound
Before building any algorithm, we need the theory that tells us what to optimize. Ben-David et al. (2007, 2010) proved the foundational result: an upper bound on target-domain error that decomposes into three terms.
For any hypothesis h in hypothesis class H:
$$\epsilon_T(h) \le \epsilon_S(h) + \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T) + \lambda^*$$
Each term tells a different part of the story:
Term 1: Source error $\epsilon_S(h)$, how well the model performs on source data. This is the standard training objective. We can measure it and minimize it directly.
Term 2: Domain divergence $\tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S, \mathcal{D}_T)$, how different the two domains look through the lens of hypothesis class H. The HΔH divergence asks: "For some pair of hypotheses in H, how much can the rate at which they disagree differ between the source and the target?" If this is large, the domains are easy to tell apart using classifiers built from H, and source-trained models are unlikely to transfer. This term is reducible: we can learn features that make the domains look more similar.
Term 3: Adaptability constant $\lambda^* = \min_{h \in \mathcal{H}} \left[\epsilon_S(h) + \epsilon_T(h)\right]$, the error of the ideal joint hypothesis, i.e. the best any single hypothesis can achieve on both domains simultaneously. This term cannot be measured or directly minimized without target labels. If no hypothesis works on both domains (because the labeling functions truly differ), λ* is large and no feature alignment trick can fix it.
This three-term decomposition is the organizing principle for every method that follows. MMD, CORAL, and DANN all reduce Term 2 (domain divergence) while minimizing Term 1 (source error) and hoping that Term 3 (λ*) is small.
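The bound can be checked numerically in a toy setting. This sketch illustrates the bound under simplifying assumptions and is not a general estimator:

- inputs are 1-D;
- H is the class of threshold classifiers 1[x > c];
- both domains share the labeling rule y = 1[x > 0];
- every quantity is computed on the samples themselves (the empirical distributions), where the three-term inequality holds exactly without finite-sample correction terms.

Two thresholds disagree exactly on the interval between them, which lets the HΔH divergence be read off the difference of the two empirical CDFs.

```python
import numpy as np

rng = np.random.default_rng(1)

# Covariate shift in 1-D: different P(X), shared labeling rule y = 1[x > 0].
x_s = rng.normal(-0.5, 1.0, size=500)
x_t = rng.normal(+0.5, 1.0, size=500)


def true_label(x):
    return (x > 0).astype(int)


y_s, y_t = true_label(x_s), true_label(x_t)

# Hypothesis class H: thresholds h_c(x) = 1[x > c] for c in a finite set of cut points.
cuts = np.concatenate([[-np.inf], np.sort(np.concatenate([x_s, x_t]))])


def errors(x, y):
    """Error of every hypothesis in H on the labeled sample (x, y)."""
    preds = (x[None, :] > cuts[:, None]).astype(int)   # shape (|H|, n)
    return (preds != y[None, :]).mean(axis=1)


eps_s = errors(x_s, y_s)   # Term 1, for every h in H
eps_t = errors(x_t, y_t)   # the target error the bound controls

# Term 2: thresholds a < b disagree exactly on (a, b], so
#   d_HdH = 2 * sup_{a<b} |(F_S(b) - F_S(a)) - (F_T(b) - F_T(a))|
#         = 2 * (max D - min D),  where D(c) = F_S(c) - F_T(c).
F_s = (x_s[None, :] <= cuts[:, None]).mean(axis=1)
F_t = (x_t[None, :] <= cuts[:, None]).mean(axis=1)
D = F_s - F_t
d_hdh = 2.0 * (D.max() - D.min())

# Term 3: the best joint error any single h in H achieves.
lambda_star = np.min(eps_s + eps_t)

bound = eps_s + 0.5 * d_hdh + lambda_star
print(f"d_HdH = {d_hdh:.3f}, lambda* = {lambda_star:.3f}")
print("bound holds for every h:", bool(np.all(eps_t <= bound + 1e-12)))
```

Here λ* is zero because the shared rule lies in H. The divergence term therefore carries the entire gap between source and target error.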
4. Measuring Domain Distance — Proxy A-Distance
The HΔH divergence is intractable to compute exactly. But Ben-David et al. gave us a practical proxy: train a binary classifier to distinguish source samples from target samples, then derive the divergence from its error rate.
The Proxy A-distance is defined as:
$$\hat{d}_{\mathcal{A}} = 2\,(1 - 2\,\mathrm{err})$$
where err is the held-out classification error of the domain discriminator. If the classifier achieves 50% error (random guessing), $\hat{d}_{\mathcal{A}} = 0$: the domains are indistinguishable in the current feature space. If it achieves 0% error (perfect separation), $\hat{d}_{\mathcal{A}} = 2$: the domains are maximally different. The beauty of this metric is its simplicity: train a logistic regression, read off its held-out error, and compute a single number.
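Because the formula is linear in the error, a few sample values show its scale. Every 10-point drop in discriminator error below 50% adds 0.4 to the distance. The helper name below is hypothetical.

```python
def proxy_a_from_error(err):
    """Proxy A-distance from a domain classifier's held-out error rate."""
    return 2.0 * (1.0 - 2.0 * err)


for err in (0.5, 0.4, 0.25, 0.1, 0.0):
    print(f"err={err:.2f}  ->  d_A={proxy_a_from_error(err):.2f}")
```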
```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def proxy_a_distance(source_features, target_features):
    """Estimate domain divergence via Proxy A-distance.

    Returns d_hat_A in [0, 2]. Higher means more domain shift.
    """
    X = np.vstack([source_features, target_features])
    y = np.concatenate([
        np.zeros(len(source_features)),  # source = 0
        np.ones(len(target_features))    # target = 1
    ])
    clf = LogisticRegression(max_iter=1000)
    cv_acc = cross_val_score(clf, X, y, cv=5, scoring='accuracy')
    err = 1.0 - cv_acc.mean()
    d_hat_a = 2.0 * (1.0 - 2.0 * err)
    return max(0.0, d_hat_a)  # clamp to [0, 2]


# Example: source and target are 50-dim features
source = np.random.randn(200, 50)
target = np.random.randn(200, 50) + 0.5  # shifted by 0.5
print(f"Proxy A-distance: {proxy_a_distance(source, target):.3f}")
# Higher values mean easier to distinguish = more domain shift
```
This single number is a practical diagnostic. Before investing in complex adaptation methods, compute the Proxy A-distance. If it's near zero, the discriminator (here a linear one) cannot tell your domains apart, and standard transfer will likely work. If it's near 2, adaptation is critical but may not be sufficient: a large λ* can still defeat it, and λ* cannot be measured without target labels.
5. Maximum Mean Discrepancy (MMD)
The Proxy A-distance tells us whether a gap exists. Maximum Mean Discrepancy (MMD), developed into a kernel two-sample test by Gretton et al. (2012), gives us a differentiable distance we can minimize directly.
The idea: map both distributions into a rich feature space via a kernel function, then compare their means. If the means match in a sufficiently expressive kernel space, the distributions are identical. The squared MMD between distributions P and Q is:
$$\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x, x' \sim P}[k(x, x')] - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}[k(x, y)] + \mathbb{E}_{y, y' \sim Q}[k(y, y')]$$
Three terms with clear intuition: (1) average similarity within P, (2) average cross-similarity between P and Q (subtracted twice), and (3) average similarity within Q. When P = Q, terms (1) and (3) equal term (2), and MMD² = 0.
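One way to check the three-term formula, and to see why the kernel matters, is to use a linear kernel k(x, y) = x · y. Each average then collapses to a dot product of sample means. The biased estimate, which keeps the diagonal terms, therefore equals the squared distance between the means. A minimal numpy check on arbitrary synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(400, 3))                 # samples from P
Y = rng.normal([0.5, 0.0, -0.5], 1.0, size=(300, 3))    # samples from Q


def linear_mmd2_three_terms(X, Y):
    """Biased MMD^2 with k(x, y) = x . y, written as the three averages."""
    within_p = (X @ X.T).mean()   # E[k(x, x')]
    cross = (X @ Y.T).mean()      # E[k(x, y)]
    within_q = (Y @ Y.T).mean()   # E[k(y, y')]
    return within_p - 2.0 * cross + within_q


mean_gap = np.sum((X.mean(axis=0) - Y.mean(axis=0)) ** 2)
print(linear_mmd2_three_terms(X, Y), mean_gap)   # equal up to float rounding
```

A linear kernel can only see the first moment, which is why a richer kernel is needed.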
The standard kernel choice is the Gaussian RBF: k(x, y) = exp(−‖x − y‖² / (2σ²)). With this kernel, MMD = 0 if and only if P = Q — it captures all moments, not just means.
```python
import numpy as np


def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian RBF kernel between all pairs of rows in x and y."""
    xx = np.sum(x ** 2, axis=1, keepdims=True)
    yy = np.sum(y ** 2, axis=1, keepdims=True)
    dists_sq = xx - 2 * x @ y.T + yy.T
    # Clip tiny negative values caused by floating-point cancellation
    dists_sq = np.maximum(dists_sq, 0.0)
    return np.exp(-dists_sq / (2 * sigma ** 2))


def mmd_squared(source, target, sigma=1.0):
    """Unbiased estimate of squared MMD with Gaussian RBF kernel."""
    m, n = len(source), len(target)
    K_ss = gaussian_kernel(source, source, sigma)
    K_tt = gaussian_kernel(target, target, sigma)
    K_st = gaussian_kernel(source, target, sigma)
    # Unbiased: exclude diagonal terms for within-domain sums
    np.fill_diagonal(K_ss, 0)
    np.fill_diagonal(K_tt, 0)
    term1 = K_ss.sum() / (m * (m - 1))  # E[k(x, x')]
    term2 = K_tt.sum() / (n * (n - 1))  # E[k(y, y')]
    term3 = K_st.sum() / (m * n)        # E[k(x, y)]
    return term1 + term2 - 2 * term3


# Example: identical distributions should give MMD near 0
source = np.random.randn(300, 10)
target = np.random.randn(300, 10)
print(f"Same dist MMD^2: {mmd_squared(source, target):.4f}")
# Shifted distribution should give higher MMD
target_shifted = target + 1.0
print(f"Shifted MMD^2: {mmd_squared(source, target_shifted):.4f}")
```
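A quick test of the claim that the RBF kernel looks beyond means is to compare two zero-mean Gaussians with very different spreads. With a linear kernel, the biased MMD² reduces to the squared distance between sample means, so it barely registers the difference; the RBF estimate does. The helpers below re-implement the estimators so the snippet runs on its own. The bandwidth sigma=1.0 is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)


def rbf_kernel(A, B, sigma=1.0):
    """Gaussian RBF kernel matrix between rows of A and rows of B."""
    sq = np.sum(A**2, axis=1)[:, None] - 2 * A @ B.T + np.sum(B**2, axis=1)[None, :]
    return np.exp(-np.maximum(sq, 0.0) / (2 * sigma**2))


def rbf_mmd2(X, Y, sigma=1.0):
    """Unbiased MMD^2 with a Gaussian RBF kernel, mirroring mmd_squared above."""
    m, n = len(X), len(Y)
    K_xx, K_yy = rbf_kernel(X, X, sigma), rbf_kernel(Y, Y, sigma)
    np.fill_diagonal(K_xx, 0.0)
    np.fill_diagonal(K_yy, 0.0)
    return (K_xx.sum() / (m * (m - 1)) + K_yy.sum() / (n * (n - 1))
            - 2 * rbf_kernel(X, Y, sigma).mean())


def linear_mmd2(X, Y):
    """Biased MMD^2 with k(x, y) = x . y, i.e. squared distance between means."""
    return float(np.sum((X.mean(axis=0) - Y.mean(axis=0)) ** 2))


narrow = rng.normal(0.0, 0.5, size=(1000, 2))   # same mean...
wide = rng.normal(0.0, 2.0, size=(1000, 2))     # ...4x the standard deviation
print(f"linear-kernel MMD^2: {linear_mmd2(narrow, wide):.4f}")   # small: means match
print(f"RBF-kernel MMD^2:    {rbf_mmd2(narrow, wide):.4f}")      # much larger
```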
In practice, Long et al. (2015) used a multi-kernel variant of MMD (MK-MMD) as a loss term in deep networks (Deep Adaptation Networks). The idea: compute MMD between source and target activations at several intermediate layers, and add it to the classification loss. The network learns features that are both discriminative (for the task) and domain-invariant (low MMD).
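To make the combined objective concrete, here is a minimal numpy sketch that evaluates the loss without training anything. It rests on several assumptions:

- a single Gaussian RBF kernel instead of DAN's multi-kernel MMD;
- a logistic classifier with hand-set weights;
- two hand-picked "representations" of a 2-D input. One keeps a dimension that is shifted only in the target domain; the other drops it.

Both representations give the same classification loss, so the MMD term alone decides which one the objective prefers. All function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def rbf_mmd2(A, B, sigma=1.0):
    """Unbiased MMD^2 with a single Gaussian RBF kernel (DAN uses several kernels)."""
    def k(U, V):
        sq = np.sum(U**2, 1)[:, None] - 2 * U @ V.T + np.sum(V**2, 1)[None, :]
        return np.exp(-np.maximum(sq, 0.0) / (2 * sigma**2))
    m, n = len(A), len(B)
    K_aa, K_bb = k(A, A), k(B, B)
    np.fill_diagonal(K_aa, 0.0)
    np.fill_diagonal(K_bb, 0.0)
    return K_aa.sum() / (m * (m - 1)) + K_bb.sum() / (n * (n - 1)) - 2 * k(A, B).mean()


def logistic_loss(logits, labels):
    """Mean binary cross-entropy computed stably from logits."""
    return np.mean(np.logaddexp(0.0, logits) - labels * logits)


def dan_style_objective(feat_s, feat_t, w, y_s, gamma=1.0):
    """Source classification loss + gamma * MMD^2(source features, target features)."""
    return logistic_loss(feat_s @ w, y_s) + gamma * rbf_mmd2(feat_s, feat_t)


n = 300
x_s = rng.normal(size=(n, 2))
x_t = rng.normal(size=(n, 2)) + np.array([0.0, 2.0])   # target shifted in dim 1 only
y_s = (x_s[:, 0] > 0).astype(float)                    # labels depend on dim 0 only

# Representation "keep" uses both dims; representation "drop" discards shifted dim 1.
loss_keep = dan_style_objective(x_s, x_t, np.array([3.0, 0.0]), y_s)
loss_drop = dan_style_objective(x_s[:, :1], x_t[:, :1], np.array([3.0]), y_s)
print(f"keep shifted dim: {loss_keep:.3f}   drop it: {loss_drop:.3f}")
```

In a real DAN, the network would learn such a representation by gradient descent on this objective rather than having it picked by hand.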
6. CORAL — Aligning Covariance
MMD aligns mean embeddings in kernel space. CORAL (Correlation Alignment) takes a different approach: align the second-order statistics (the covariance matrices) of source and target features. It was introduced by Sun, Feng & Saenko (2016) and extended to deep networks as Deep CORAL by Sun & Saenko (2016).
Why covariance? Consider a classifier that relies on the correlation between "texture sharpness" and "color saturation" features. In studio photos (source), these might be strongly correlated. In phone photos (target), they might be nearly independent. The classifier's learned decision boundary, which exploits that correlation, breaks on target data. CORAL removes this mismatch by making the covariance structures identical.
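The original, non-deep CORAL of Sun, Feng & Saenko (2016) matches the covariances with a closed-form transform. It whitens the source features with $C_S^{-1/2}$, then re-colors them with $C_T^{1/2}$, adding an identity matrix to each covariance for regularization. The sketch below follows that recipe. The `eps` parameter, which scales the identity, is an addition for illustration. The toy data mimics the correlated-studio versus independent-phone example. This transform aligns covariances only, not means.

```python
import numpy as np


def _sym_power(C, power):
    """Matrix power of a symmetric positive-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(C)
    return (vecs * vals ** power) @ vecs.T


def coral_recolor(source, target, eps=1.0):
    """Whiten source features, then re-color them with the target covariance.

    eps * I is added to both covariances for stability; eps=1.0 mirrors the
    identity regularizer of the original recipe, smaller eps matches more closely.
    """
    dim = source.shape[1]
    c_s = np.cov(source, rowvar=False) + eps * np.eye(dim)
    c_t = np.cov(target, rowvar=False) + eps * np.eye(dim)
    return source @ _sym_power(c_s, -0.5) @ _sym_power(c_t, 0.5)


rng = np.random.default_rng(0)
# "Studio" features: strongly correlated. "Phone" features: independent, new scales.
source = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.8], [0.8, 1.0]], size=1000)
target = rng.normal(0.0, [2.0, 0.5], size=(1000, 2))

aligned = coral_recolor(source, target, eps=1e-9)
print(np.round(np.cov(aligned, rowvar=False), 3))
print(np.round(np.cov(target, rowvar=False), 3))   # the two should match closely
```

A classifier trained on the re-colored source features then sees the same second-order structure it will meet on the target.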
The Deep CORAL loss is elegantly simple:
$$\mathcal{L}_{\mathrm{CORAL}} = \frac{1}{4d^2}\, \lVert C_S - C_T \rVert_F^2$$
where $C_S$ and $C_T$ are the covariance matrices of source and target activations at a chosen layer, d is the feature dimension, and $\lVert \cdot \rVert_F$ is the Frobenius norm. The $1/(4d^2)$ factor turns the sum over all $d^2$ covariance entries into a (quartered) average, so the loss stays on a comparable scale across layer widths.
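A worked instance of the formula uses made-up 2×2 covariances in the spirit of the studio/phone example. The source features have correlation 0.8, the target features are uncorrelated, and both have unit variance.

```python
import numpy as np

C_S = np.array([[1.0, 0.8],
                [0.8, 1.0]])   # strongly correlated source features
C_T = np.eye(2)                # independent target features, same variances
d = C_S.shape[0]

frob_sq = np.sum((C_S - C_T) ** 2)     # 0.8^2 + 0.8^2 = 1.28
loss = frob_sq / (4 * d * d)           # 1.28 / 16 = 0.08
print(frob_sq, loss)
```

Only the off-diagonal (correlation) entries contribute here, because the variances already match.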
```python
import numpy as np


def coral_loss(source_features, target_features):
    """Compute the CORAL loss between source and target features.

    Measures squared Frobenius norm between covariance matrices,
    normalized by feature dimension.
    """
    d = source_features.shape[1]
    # Center both sets of features
    source_centered = source_features - source_features.mean(axis=0)
    target_centered = target_features - target_features.mean(axis=0)
    # Covariance matrices (using n-1 normalization)
    n_s = len(source_features)
    n_t = len(target_features)
    cov_s = (source_centered.T @ source_centered) / (n_s - 1)
    cov_t = (target_centered.T @ target_centered) / (n_t - 1)
    # Squared Frobenius norm of the difference
    diff = cov_s - cov_t
    loss = np.sum(diff ** 2) / (4 * d * d)
    return loss


# Example: features with different covariance structures
np.random.seed(42)
source = np.random.randn(200, 5) @ np.diag([2, 1, 1, 1, 0.5])
target = np.random.randn(200, 5) @ np.diag([0.5, 1, 1, 1, 2])
print(f"CORAL loss: {coral_loss(source, target):.4f}")
# Same covariance should give near-zero loss
target_same = np.random.randn(200, 5) @ np.diag([2, 1, 1, 1, 0.5])
print(f"Same cov: {coral_loss(source, target_same):.4f}")
```
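The total training objective combines classification loss on labeled source data with CORAL loss on both domains: $\mathcal{L} = \mathcal{L}_{\mathrm{class}} + \lambda\, \mathcal{L}_{\mathrm{CORAL}}$. Here λ is a weighting hyperparameter that trades task accuracy against alignment; it is unrelated to λ* above. No target labels are needed. CORAL is popular because it's trivial to implement, adds minimal computation, and often works surprisingly well, sometimes rivaling far more complex methods.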
7. DANN — Adversarial Domain Adaptation
MMD and CORAL align distributions by explicitly measuring and minimizing a distance. Domain-Adversarial Neural Networks (DANN), proposed by Ganin et al. (2016), take an entirely different approach: use an adversarial game to force domain invariance.
The architecture has three components:
- Feature extractor $G_f(x; \theta_f)$: maps input x to features z
- Label predictor $G_y(z; \theta_y)$: classifies source features into task labels
- Domain classifier $G_d(z; \theta_d)$: predicts whether features came from source or target
The label predictor and domain classifier share the same feature extractor. The adversarial game: the domain classifier tries to tell source from target, while the feature extractor tries to fool it. The domain classifier may drop to 50% accuracy (random guessing, with equal numbers of source and target samples). At that point it can no longer separate the domains from the features, which is the sense in which the features are domain-invariant.
The key ingredient is the Gradient Reversal Layer (GRL). On the forward pass, it acts as the identity function, so features pass through unchanged. On the backward pass, it multiplies the incoming gradient by −λ:
$$\text{Forward: } \mathrm{GRL}(z) = z$$
$$\text{Backward: } \frac{\partial L}{\partial z_{\text{in}}} = -\lambda\, \frac{\partial L}{\partial z_{\text{out}}}$$
This single trick makes the entire minimax optimization trainable with standard backpropagation. The feature extractor receives reversed gradients from the domain classifier, pushing it to produce features that maximize domain confusion — while simultaneously receiving normal gradients from the label predictor, keeping the features useful for classification.
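The sign flip is easiest to see without autograd. The sketch below is a hypothetical one-parameter example:

- the feature extractor is z = a·x;
- the domain classifier is sigmoid(c·z);
- the gradients of the domain loss are written out by hand.

The domain head takes an ordinary descent step, while the extractor's gradient arrives multiplied by −λ, as the GRL would deliver it.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical 1-D data: source centered at -1, target at +1.
x = np.concatenate([rng.normal(-1.0, 1.0, 200), rng.normal(1.0, 1.0, 200)])
dom = np.concatenate([np.zeros(200), np.ones(200)])   # 0 = source, 1 = target


def domain_loss(a, c):
    """Binary cross-entropy of the domain head sigmoid(c * z) on features z = a * x."""
    logits = c * a * x
    return np.mean(np.logaddexp(0.0, logits) - dom * logits)


def domain_grads(a, c):
    """Hand-derived gradients (dL/da, dL/dc) of domain_loss."""
    p = 1.0 / (1.0 + np.exp(-c * a * x))
    g_logit = (p - dom) / len(x)             # dL/dlogit for each sample
    return np.sum(g_logit * c * x), np.sum(g_logit * a * x)


a, c = 1.0, 1.0          # extractor weight, domain-head weight
lambd, lr = 1.0, 0.05
g_a, g_c = domain_grads(a, c)

c_new = c - lr * g_c                  # domain head: ordinary gradient descent
a_new = a - lr * (-lambd * g_a)       # extractor: gradient arrives negated by the GRL

before = domain_loss(a, c)
after_head = domain_loss(a, c_new)
after_extractor = domain_loss(a_new, c)
print(f"head step:      {before:.4f} -> {after_head:.4f}")       # domain loss falls
print(f"extractor step: {before:.4f} -> {after_extractor:.4f}")  # domain loss rises
```

In this toy, the extractor raises the domain loss by shrinking a toward 0, which would also erase the class signal. In DANN, the label predictor's ordinary gradients keep the features from collapsing.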
Ganin et al. found that gradually increasing λ from 0 to 1 with a sigmoid schedule stabilizes training: $\lambda_p = \frac{2}{1 + \exp(-10\, p)} - 1$, where p is the training progress from 0 to 1. Early on, the model focuses on learning good task features; gradually, domain adversarial pressure kicks in.
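The schedule is algebraically equal to tanh(5p), since 2/(1 + e^{-x}) − 1 = tanh(x/2). Evaluating it at a few points shows how quickly the adversarial pressure ramps up. λ is about 0.46 at 10% of training, about 0.85 at 25%, and about 0.99 by the halfway point. The function below is a plain-Python sketch of the formula (name hypothetical).

```python
import math


def grl_schedule(progress):
    """lambda_p = 2 / (1 + exp(-10 p)) - 1 for training progress p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-10.0 * progress)) - 1.0


for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    print(f"p={p:.2f}  lambda={grl_schedule(p):.3f}")
```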
```python
import math

import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversal(Function):
    """Gradient Reversal Layer: identity forward, negate backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Multiply by -lambda; lambd itself receives no gradient
        return -ctx.lambd * grad_output, None


class DANN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # Shared feature extractor
        self.features = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Task classifier (source labels)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        # Domain discriminator (source vs target)
        self.domain_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, lambd=1.0):
        feats = self.features(x)
        class_logits = self.classifier(feats)
        # Reverse gradients before domain head
        reversed_feats = GradientReversal.apply(feats, lambd)
        domain_logits = self.domain_head(reversed_feats)
        return class_logits, domain_logits


# Lambda schedule: sigmoid warmup from 0 to 1
def dann_lambda(progress):
    """progress in [0, 1]: fraction of training completed. Returns a float."""
    return 2.0 / (1.0 + math.exp(-10.0 * progress)) - 1.0
```
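The connection to Ben-David's theory is direct. The feature extractor and label predictor minimize $\epsilon_S$ (Term 1). Meanwhile, the reversed gradients push the feature extractor toward features the domain classifier cannot separate, an adversarial proxy for reducing the divergence in Term 2. DANN operationalizes the bound as a trainable neural network.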
8. When Adaptation Fails
Domain adaptation isn't magic. Zhao et al. (2019) proved a striking result: forcing domain-invariant features can actually increase error when label distributions differ across domains.
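Their lower bound states that for any feature transformation g and classifier h, provided $d_{JS}(Y_S, Y_T) \ge d_{JS}(Z_S, Z_T)$:
$$\epsilon_S(h \circ g) + \epsilon_T(h \circ g) \ge \tfrac{1}{2}\left(d_{JS}(Y_S, Y_T) - d_{JS}(Z_S, Z_T)\right)^2$$
Here $d_{JS}$ is the Jensen-Shannon distance (the square root of the JS divergence). $Y_S$ and $Y_T$ are the source and target label distributions, and $Z_S$ and $Z_T$ are the distributions of the learned features Z = g(X).
If you force perfect feature alignment ($d_{JS}(Z_S, Z_T) = 0$), the joint error of any classifier on those features is at least $\tfrac{1}{2}\, d_{JS}(Y_S, Y_T)^2$. When label distributions differ significantly, making features indistinguishable destroys information the classifier needs.
The intuition: imagine the source is 50% cats and 50% dogs, but the target is 100% cats. Forcing domain-invariant features means the target cats must be mapped onto the same feature distribution as the source mix of cats and dogs. Some target cats then land in the region the classifier has learned to call "dog", dragging accuracy down.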
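Plugging the cats-and-dogs example into the bound gives a concrete number. The sketch assumes perfectly aligned features ($d_{JS}(Z_S, Z_T) = 0$) and uses natural logarithms. With base-2 logarithms, every JS distance grows by a factor of $1/\sqrt{\ln 2}$, so the bound would only get larger. Function names are hypothetical.

```python
import numpy as np


def js_distance(p, q):
    """Jensen-Shannon distance sqrt(JS divergence), natural log, with 0 * log 0 = 0."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(np.sqrt(0.5 * kl(p, m) + 0.5 * kl(q, m)))


def zhao_lower_bound(label_src, label_tgt, feature_js_distance=0.0):
    """Lower bound on eps_S + eps_T; the theorem needs d_JS(Y) >= d_JS(Z)."""
    d_y = js_distance(label_src, label_tgt)
    if d_y < feature_js_distance:
        return 0.0   # condition fails: only the trivial bound (error >= 0) remains
    return 0.5 * (d_y - feature_js_distance) ** 2


# Source: 50% cats / 50% dogs. Target: 100% cats. Perfectly aligned features.
print(zhao_lower_bound([0.5, 0.5], [1.0, 0.0]))   # 0.375 * ln(4/3), about 0.108
```

So in this scenario, any classifier on perfectly aligned features incurs at least about 0.108 combined source-plus-target error, before any other source of error.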
Before applying any adaptation method, a simple diagnostic can save you from negative transfer:
```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def should_adapt(source_features, target_features, source_labels,
                 target_labels_estimate=None):
    """Heuristic check: is domain adaptation likely to help?

    Returns a dict with proxy A-distance and recommendation.
    The thresholds (0.3 for the distance, 0.5 for the KL) are rules of thumb.
    """
    # 1. Measure domain gap via Proxy A-distance
    X = np.vstack([source_features, target_features])
    y_domain = np.concatenate([
        np.zeros(len(source_features)),
        np.ones(len(target_features))
    ])
    clf = LogisticRegression(max_iter=1000)
    err = 1.0 - cross_val_score(clf, X, y_domain, cv=5).mean()
    pad = max(0.0, 2.0 * (1.0 - 2.0 * err))

    # 2. Check label distribution similarity (if estimates available)
    label_warning = False
    if target_labels_estimate is not None:
        source_labels = np.asarray(source_labels, dtype=int)
        target_labels_estimate = np.asarray(target_labels_estimate, dtype=int)
        # Same number of bins for both, so class indices line up
        n_classes = max(source_labels.max(), target_labels_estimate.max()) + 1
        source_dist = np.bincount(source_labels, minlength=n_classes) / len(source_labels)
        target_dist = (np.bincount(target_labels_estimate, minlength=n_classes)
                       / len(target_labels_estimate))
        # KL(source || target), smoothed to avoid log(0) and division by zero
        kl_approx = np.sum(source_dist * np.log((source_dist + 1e-10) /
                                                (target_dist + 1e-10)))
        label_warning = kl_approx > 0.5

    # 3. Recommendation
    if pad < 0.3:
        rec = "Low shift. Standard transfer likely sufficient."
    elif label_warning:
        rec = "Shift detected but label distributions differ. Adapt cautiously."
    elif target_labels_estimate is None:
        rec = "Significant shift; label distributions unchecked. Adapt with care."
    else:
        rec = "Significant shift with similar labels. Adaptation recommended."

    return {"proxy_a_distance": pad, "recommendation": rec}
```
The rule of thumb: adapt when domains differ in input distribution but share similar label distributions and task structure. When concept drift is present or label distributions diverge, adaptation can hurt more than it helps.
9. The Full Adaptation Pipeline
Let's put everything together into a complete training loop that combines classification loss on source data with CORAL alignment and DANN-style domain adversarial training.
```python
import math

import torch
import torch.nn as nn

# Assumes GradientReversal and a DANN-style model (with .features, .classifier,
# and .domain_head) from the previous section are already defined.


def train_adapted_model(model, source_loader, target_loader,
                        num_epochs=50, lr=0.001):
    """Full domain adaptation training loop with CORAL + DANN."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    class_loss_fn = nn.CrossEntropyLoss()
    domain_loss_fn = nn.BCEWithLogitsLoss()
    model.train()

    for epoch in range(num_epochs):
        progress = epoch / num_epochs
        lambd = 2.0 / (1.0 + math.exp(-10.0 * progress)) - 1.0

        # zip stops at the shorter loader; batches need >= 2 samples for the
        # covariance below (e.g. build the loaders with drop_last=True)
        for (src_x, src_y), (tgt_x, _) in zip(source_loader, target_loader):
            # Forward pass through shared feature extractor
            src_feats = model.features(src_x)
            tgt_feats = model.features(tgt_x)

            # Task loss: classification on source only
            class_logits = model.classifier(src_feats)
            loss_class = class_loss_fn(class_logits, src_y)

            # CORAL loss: align covariance matrices
            src_centered = src_feats - src_feats.mean(dim=0)
            tgt_centered = tgt_feats - tgt_feats.mean(dim=0)
            cov_s = (src_centered.T @ src_centered) / (len(src_x) - 1)
            cov_t = (tgt_centered.T @ tgt_centered) / (len(tgt_x) - 1)
            d = src_feats.shape[1]
            loss_coral = torch.sum((cov_s - cov_t) ** 2) / (4 * d * d)

            # Domain loss: adversarial with gradient reversal
            src_domain = model.domain_head(
                GradientReversal.apply(src_feats, lambd))
            tgt_domain = model.domain_head(
                GradientReversal.apply(tgt_feats, lambd))
            labels_src = torch.zeros(len(src_x), 1, device=src_feats.device)
            labels_tgt = torch.ones(len(tgt_x), 1, device=tgt_feats.device)
            loss_domain = (domain_loss_fn(src_domain, labels_src) +
                           domain_loss_fn(tgt_domain, labels_tgt))

            # Combined loss
            loss = loss_class + 0.1 * loss_coral + loss_domain

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: class={loss_class.item():.3f} "
                  f"coral={loss_coral.item():.3f} domain={loss_domain.item():.3f} "
                  f"lambda={lambd:.2f}")
```
The three losses work in concert: loss_class keeps the model accurate on the task, loss_coral aligns covariance structures, and loss_domain with gradient reversal forces domain-invariant features. The λ schedule gradually increases adversarial pressure as the model stabilizes.
10. Conclusion
Domain adaptation rests on a beautiful theoretical foundation. Ben-David's bound decomposes target error into three intuitive terms — source error, domain divergence, and adaptability — and every method we built maps directly to reducing one of those terms. MMD matches distributions by comparing kernel mean embeddings. CORAL aligns covariance structures with a single matrix norm. DANN uses a gradient reversal trick to force domain-invariant features through adversarial training.
But Zhao et al.'s impossibility result reminds us that adaptation isn't free. When label distributions differ across domains, forcing invariance can destroy the very information the classifier needs. The practical lesson: always check whether the domains are compatible before blindly applying adaptation.
Modern foundation models have changed the landscape. By pre-training on internet-scale data spanning many implicit domains, models like CLIP and GPT-4 often arrive with features that are more robust to distribution shift than those of a model trained on a single dataset. But the Ben-David bound still applies. Deploy a medical imaging model trained on one hospital's scanner at another hospital, and even large pre-trained models can stumble. Domain adaptation isn't a historical curiosity; it's the bridge between laboratory accuracy and real-world performance.
References & Further Reading
- Ben-David, Blitzer, Crammer, Kulesza, Pereira & Vaughan (2010). A Theory of Learning from Different Domains. The foundational bound that motivates all domain adaptation methods.
- Gretton, Borgwardt, Rasch, Schölkopf & Smola (2012). A Kernel Two-Sample Test. The theoretical foundation for Maximum Mean Discrepancy.
- Sun & Saenko (2016). Deep CORAL: Correlation Alignment for Deep Domain Adaptation. Aligning second-order statistics for simple, effective adaptation.
- Ganin, Ustinova, Ajakan, Germain, Larochelle, Laviolette, Marchand & Lempitsky (2016). Domain-Adversarial Training of Neural Networks. The gradient reversal layer and adversarial domain adaptation.
- Long, Cao, Wang & Jordan (2015). Learning Transferable Features with Deep Adaptation Networks. Applying multi-kernel MMD to intermediate network layers.
- Zhao, Tachet des Combes, Zhang & Gordon (2019). On Learning Invariant Representations for Domain Adaptation. The impossibility result showing when invariance hurts.
- Shimodaira (2000). Improving Predictive Inference under Covariate Shift by Weighting the Log-Likelihood Function. Early formalization of covariate shift and importance weighting.
- Wilson & Cook (2020). A Survey of Unsupervised Deep Domain Adaptation. Comprehensive overview of deep DA methods and taxonomy.