
Multi-Task Learning from Scratch

1. Why Train One Network on Many Jobs?

Every neural network we've built in this series solves one task. A classifier predicts labels. A regressor predicts numbers. Even fancy architectures — attention, KANs, mixture density networks — produce a single type of output from a single loss function.

But what if solving more problems made your model better at each one?

That's the counterintuitive promise of multi-task learning (MTL). In 1997, Rich Caruana showed that training a neural network on multiple related tasks simultaneously — using a shared hidden representation — produced better generalization than training on each task alone. The related tasks act as an implicit regularizer: they force the shared layers to learn features that are useful across many objectives, which tend to be the features that generalize best.

Caruana identified several mechanisms that make this work: statistical data amplification (related tasks pool evidence about shared features), attribute selection (extra tasks help identify which input features matter), eavesdropping (one task learns a feature that another task finds hard to extract on its own), and representation bias (shared layers are nudged toward representations that many tasks prefer).

There's even a theoretical result from Baxter (1997): the risk of overfitting shared parameters decreases as O(1/N) where N is the number of tasks. More tasks = stronger regularization of the backbone.

The basic MTL loss is a weighted sum over K task losses:

L_total = Σ_k w_k · L_k(θ_shared, θ_k)

where θ_shared are the shared parameters and each θ_k holds task-specific parameters. The big question — and the subject of most of this post — is how to set those weights w_k.
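In code, the combined objective is just a dot product between a weight vector and the per-task losses. A minimal sketch with made-up loss values:

```python
import numpy as np

def weighted_total_loss(task_losses, task_weights):
    """Weighted sum of per-task losses: L_total = sum_k w_k * L_k."""
    return float(np.dot(task_weights, task_losses))

# Hypothetical loss values for a regression and a classification task
losses = np.array([200.0, 0.7])
weights = np.array([1.0, 1.0])            # uniform weighting
total = weighted_total_loss(losses, weights)
shares = weights * losses / total          # fraction each task contributes
```

With uniform weights the regression task contributes 200/200.7 ≈ 99.7% of the total, which is exactly the imbalance the rest of the post wrestles with.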

2. Hard Parameter Sharing

The dominant MTL architecture is hard parameter sharing: all tasks share the same encoder (feature extractor), with task-specific "heads" branching off at the top. It's simple, parameter-efficient, and effective.

Here's a two-task network with a shared backbone and separate regression/classification heads:

import numpy as np

class MultiTaskNet:
    """Two-task network: shared hidden layers, separate output heads."""
    def __init__(self, d_in, d_hidden, seed=42):
        rng = np.random.RandomState(seed)
        scale = lambda fan_in: np.sqrt(2.0 / fan_in)
        # Shared layers
        self.W1 = rng.randn(d_in, d_hidden) * scale(d_in)
        self.b1 = np.zeros(d_hidden)
        self.W2 = rng.randn(d_hidden, d_hidden) * scale(d_hidden)
        self.b2 = np.zeros(d_hidden)
        # Regression head (1 output)
        self.W_reg = rng.randn(d_hidden, 1) * scale(d_hidden)
        self.b_reg = np.zeros(1)
        # Classification head (1 output, sigmoid)
        self.W_cls = rng.randn(d_hidden, 1) * scale(d_hidden)
        self.b_cls = np.zeros(1)

    def forward(self, X):
        # Shared encoder
        h1 = np.maximum(0, X @ self.W1 + self.b1)       # ReLU
        h2 = np.maximum(0, h1 @ self.W2 + self.b2)      # ReLU
        # Task-specific heads branch from shared representation
        y_reg = h2 @ self.W_reg + self.b_reg             # linear
        logits = h2 @ self.W_cls + self.b_cls
        y_cls = 1.0 / (1.0 + np.exp(-np.clip(logits, -500, 500)))
        return y_reg.ravel(), y_cls.ravel(), h2

The key insight is in forward: both heads read from the same h2 representation. Whatever features the shared layers learn must serve both the regression and classification objectives. This is the architectural bottleneck that creates Caruana's inductive bias.

The alternative is soft parameter sharing: each task gets its own full network, with learned connections between corresponding layers. Cross-stitch networks (Misra et al. 2016) insert learnable mixing matrices that let the model discover how much to share at each layer. This is more flexible but uses K times the parameters — in practice, hard sharing wins for most applications.
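As a rough sketch of the cross-stitch idea (a simplified version, not the exact formulation from Misra et al.), a learnable 2×2 mixing matrix blends the two tasks' activations at a given layer:

```python
import numpy as np

def cross_stitch(h_a, h_b, alpha):
    """Mix two tasks' activations at one layer with a 2x2 matrix.
    h_a, h_b: (batch, d) activations from task A's and task B's networks
    alpha: 2x2 mixing matrix (learnable in a real model);
           alpha near the identity means almost no sharing
    """
    mixed_a = alpha[0, 0] * h_a + alpha[0, 1] * h_b
    mixed_b = alpha[1, 0] * h_a + alpha[1, 1] * h_b
    return mixed_a, mixed_b

rng = np.random.RandomState(0)
h_a, h_b = rng.randn(4, 8), rng.randn(4, 8)
# Near-identity initialization: mostly task-private, a little sharing
alpha = np.array([[0.9, 0.1],
                  [0.1, 0.9]])
m_a, m_b = cross_stitch(h_a, h_b, alpha)
```

During training, gradient descent on alpha lets the model discover how much to share at this layer: alpha drifting toward the identity means "don't share here."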

3. The Weighting Problem

The naive approach is uniform weighting: set all wk = 1 and sum the losses. This works fine when tasks have similar loss scales. But when a regression loss sits in the thousands while a classification loss stays near 1, the regression gradient dominates. The shared layers learn features optimized almost entirely for regression, and the classification head starves.

def train_uniform(model, X, y_reg, y_cls, lr=0.001, steps=500):
    """Train with equal weighting — regression dominates."""
    losses_reg, losses_cls = [], []
    for step in range(steps):
        y_r, y_c, h2 = model.forward(X)
        # Regression: MSE (loss scale ~100s)
        loss_reg = np.mean((y_r - y_reg) ** 2)
        # Classification: binary cross-entropy (loss scale ~0.7)
        eps = 1e-7
        loss_cls = -np.mean(y_cls * np.log(y_c + eps)
                            + (1 - y_cls) * np.log(1 - y_c + eps))
        # Uniform weighting: L = L_reg + L_cls
        # Gradient of L_reg overwhelms gradient of L_cls
        losses_reg.append(loss_reg)
        losses_cls.append(loss_cls)
        # ... backprop and update (omitted for brevity)
    return losses_reg, losses_cls
# Typical result: regression loss drops fast, classification barely moves

The problem is clear: loss_reg might be 200 while loss_cls is 0.7. Their gradients flow back through the same shared layers, and the regression signal drowns out the classification signal. We need smarter weighting.

4. Uncertainty Weighting

Kendall, Gal, and Cipolla (2018) derived an elegant solution: weight each task by its learned homoscedastic uncertainty. The idea comes from the Gaussian likelihood. For a regression task with noise variance σ², the negative log-likelihood is:

L = (1 / 2σ²) · ||y − f(x)||² + log(σ)

The 1/σ² factor automatically downweights noisy tasks (high uncertainty = lower weight), while the log(σ) term prevents the model from cheating by pushing σ to infinity (which would zero out all losses). For two tasks combined:

L_total = ½·exp(−s_1)·L_1 + ½·s_1 + ½·exp(−s_2)·L_2 + ½·s_2

where s_k = log(σ_k²) is the learned log-variance — a single scalar per task, optimized alongside the model weights via gradient descent.

class UncertaintyWeighting:
    """Kendall et al. 2018: learn task weights from homoscedastic uncertainty."""
    def __init__(self, n_tasks=2):
        # log(sigma^2) for each task, initialized to 0 (sigma=1)
        self.log_vars = np.zeros(n_tasks)

    def weighted_loss(self, task_losses):
        """Compute uncertainty-weighted total loss.
        task_losses: list of scalar loss values [L_1, L_2, ...]
        """
        total = 0.0
        for k, L_k in enumerate(task_losses):
            precision = np.exp(-self.log_vars[k])   # 1/sigma^2
            total += 0.5 * precision * L_k + 0.5 * self.log_vars[k]
        return total

    def grad_log_vars(self, task_losses):
        """Gradient of total loss w.r.t. log_vars (for updating weights)."""
        grads = np.zeros_like(self.log_vars)
        for k, L_k in enumerate(task_losses):
            precision = np.exp(-self.log_vars[k])
            grads[k] = -0.5 * precision * L_k + 0.5   # dL/d(log_var_k)
        return grads

    def update(self, task_losses, lr=0.01):
        grads = self.grad_log_vars(task_losses)
        self.log_vars -= lr * grads

    def effective_weights(self):
        return 0.5 * np.exp(-self.log_vars)

# Usage:
# uw = UncertaintyWeighting(n_tasks=2)
# uw.update([loss_reg, loss_cls])
# print(uw.effective_weights())  # e.g., [0.003, 0.48] — auto-downweights regression

The beauty of this approach: the model discovers the right task weights during training. If the regression loss is 200× larger than the classification loss, the learned σ_reg will be large, and exp(−s_reg) will be small — automatically compensating for the scale mismatch. No hand-tuning required.
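A quick self-contained check of that claim: hold the task losses fixed and run gradient descent on the log-variances alone. Setting the gradient −0.5·exp(−s_k)·L_k + 0.5 to zero gives exp(−s_k) = 1/L_k, so each weighted loss should settle at 0.5 regardless of its raw scale. (The loss values here are illustrative.)

```python
import numpy as np

# Fixed (illustrative) task losses with a ~300x scale mismatch
task_losses = np.array([200.0, 0.7])
log_vars = np.zeros(2)                 # s_k = log(sigma_k^2), start at 0

for _ in range(2000):
    precision = np.exp(-log_vars)      # 1/sigma_k^2
    grads = -0.5 * precision * task_losses + 0.5
    log_vars -= 0.05 * grads

# At the optimum, exp(-s_k) = 1/L_k, so each weighted loss is ~0.5
weighted = 0.5 * np.exp(-log_vars) * task_losses
```

Both weighted losses land near 0.5: the big regression loss ends up with a tiny weight, the small classification loss with a large one.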

5. Gradient Conflicts

Even with perfectly weighted losses, MTL faces a deeper problem: gradient conflicts. Two task gradients g_A and g_B conflict when their cosine similarity is negative:

cos(g_A, g_B) = (g_A · g_B) / (||g_A|| · ||g_B||) < 0

This means the tasks want to push the shared parameters in opposite directions. Following the average gradient (g_A + g_B)/2 can actually increase the loss for one or both tasks.
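You can see this with two hand-picked vectors: to first order, a descent step along the averaged gradient changes task A's loss by −η·(g_avg · g_A), so a negative dot product means the step pushes task A uphill:

```python
import numpy as np

g_A = np.array([1.0, 0.0])
g_B = np.array([-3.0, 1.0])
g_avg = (g_A + g_B) / 2                # [-1.0, 0.5]

# First-order loss change for task A under the step theta -= lr * g_avg
# is -lr * (g_avg . g_A); a negative alignment means the step HURTS A.
alignment = np.dot(g_avg, g_A)         # negative here: the step hurts A
```

Here the averaged step helps task B (positive alignment with g_B) while increasing task A's loss, which is precisely the situation gradient surgery is designed to fix.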

Yu et al. (2020) proposed PCGrad ("gradient surgery"): when two task gradients conflict, project each one onto the normal plane of the other, removing the conflicting component while keeping the helpful part.

def pcgrad(task_gradients):
    """PCGrad: project conflicting gradients to remove harmful components.
    task_gradients: list of gradient vectors [g_1, g_2, ...]
    Returns: combined update direction (sum of modified gradients).
    """
    modified = [g.copy() for g in task_gradients]
    n_tasks = len(modified)
    for i in range(n_tasks):
        for j in range(n_tasks):
            if i == j:
                continue
            dot = np.dot(modified[i], task_gradients[j])
            if dot < 0:  # conflict detected
                # Project out the conflicting component
                norm_sq = np.dot(task_gradients[j], task_gradients[j])
                if norm_sq > 1e-12:
                    modified[i] -= (dot / norm_sq) * task_gradients[j]
    # Final update direction: sum of modified gradients
    return sum(modified)

# Example: two 4D gradients that conflict
g_reg = np.array([2.0, -1.0, 0.5, 3.0])    # regression wants to go "northeast"
g_cls = np.array([-1.5, 0.8, 0.3, -2.0])   # classification wants to go "southwest"
print(f"Cosine similarity: {np.dot(g_reg, g_cls) / (np.linalg.norm(g_reg) * np.linalg.norm(g_cls)):.3f}")
# -0.968 — strong conflict!
combined = pcgrad([g_reg, g_cls])
print(f"PCGrad direction: {combined}")
# Surgically removes the conflicting components, yielding a compromise direction

The geometric intuition is clean: projection removes exactly the component of g_A that points against g_B. After surgery, the modified g_A is orthogonal to g_B — following it cannot hurt task B (to first order). If there's no conflict (cosine similarity ≥ 0), the original gradient is kept unchanged.
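For two tasks, the orthogonality claim is easy to verify numerically: after the projection, the modified gradient has zero dot product with the other task's gradient.

```python
import numpy as np

g_A = np.array([2.0, -1.0, 0.5, 3.0])
g_B = np.array([-1.5, 0.8, 0.3, -2.0])

# PCGrad projection of g_A against a conflicting g_B
dot = np.dot(g_A, g_B)
assert dot < 0                                  # confirm the conflict
g_A_mod = g_A - (dot / np.dot(g_B, g_B)) * g_B

# The surgically modified g_A is orthogonal to g_B up to float error,
# so a step along it cannot increase task B's loss to first order.
residual = np.dot(g_A_mod, g_B)
```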

6. GradNorm

Chen et al. (2018) took a different approach: instead of fixing conflicts geometrically, dynamically adjust the task weights so that all tasks train at similar rates. If one task is learning much faster than another, its weight should decrease so the slower task gets more gradient signal.

GradNorm works by monitoring gradient norms at the last shared layer:

class GradNorm:
    """GradNorm: balance training rates via gradient norm matching."""
    def __init__(self, n_tasks=2, alpha=1.5):
        self.weights = np.ones(n_tasks)       # task weights w_k
        self.alpha = alpha                     # rate balancing strength
        self.initial_losses = None             # L_k(0) for relative rates

    def update_weights(self, task_losses, task_grad_norms, lr=0.025):
        """Adjust weights to balance gradient norms across tasks.
        task_losses: current loss per task [L_1, L_2]
        task_grad_norms: ||grad_W(w_k * L_k)||_2 per task
        """
        if self.initial_losses is None:
            self.initial_losses = np.array(task_losses, dtype=float)
            return
        # Relative inverse training rate: how fast is each task learning?
        loss_ratios = np.array(task_losses) / (self.initial_losses + 1e-8)
        avg_ratio = np.mean(loss_ratios)
        rel_rates = loss_ratios / (avg_ratio + 1e-8)
        # Target gradient norm per task
        avg_gnorm = np.mean(task_grad_norms)
        target_gnorms = avg_gnorm * (rel_rates ** self.alpha)
        # Nudge w_k by the signed gap between its gradient norm and its
        # target (a simplification: the paper differentiates
        # |G_k - target_k| w.r.t. w_k, treating the target as constant)
        for k in range(len(self.weights)):
            grad_w = task_grad_norms[k] - target_gnorms[k]
            self.weights[k] -= lr * grad_w
        # Renormalize so weights sum to n_tasks
        self.weights = len(self.weights) * self.weights / (np.sum(self.weights) + 1e-8)
        self.weights = np.clip(self.weights, 0.1, 10.0)

# Tasks training at different rates:
# gn.update_weights([200.0, 0.5], [15.3, 0.8])
# gn.weights → e.g., [0.4, 1.6] — boosts the slow classifier

The hyperparameter α controls balancing strength. At α = 0, GradNorm just equalizes gradient norms. As α increases, it pushes harder on tasks that are training slower (higher relative loss). The original paper found α = 1.5 works well across settings.
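Here is the target-norm computation from update_weights on hand-picked (illustrative) numbers: the task whose loss has fallen less since step 0 gets a larger target gradient norm, and therefore a weight boost:

```python
import numpy as np

alpha = 1.5
initial_losses = np.array([400.0, 0.6])    # L_k(0), illustrative
current_losses = np.array([200.0, 0.45])   # task 0 has learned faster
grad_norms = np.array([10.0, 2.0])         # ||grad of w_k * L_k|| per task

loss_ratios = current_losses / initial_losses   # [0.5, 0.75]
rel_rates = loss_ratios / loss_ratios.mean()    # [0.8, 1.2]
targets = grad_norms.mean() * rel_rates ** alpha
# Task 1 (the slower learner) gets the larger target norm,
# so GradNorm will push its weight up and task 0's weight down.
```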

7. Negative Transfer and Auxiliary Tasks

Negative transfer is the failure mode: sharing parameters with an unrelated task makes both tasks worse than if trained separately. It happens when tasks have fundamentally different optimal feature representations, or when one task's gradients consistently overwhelm the other's.

You can detect it by monitoring the gradient cosine similarity between tasks during training. If the cosine is persistently negative, the tasks are fighting over the shared representation. The fix might be separating them, or using PCGrad to manage the conflict.

Interestingly, negative transfer can be asymmetric: task A helps task B while B hurts A. This means task relatedness isn't a simple binary — it has direction.

Auxiliary tasks flip this asymmetry into a feature. An auxiliary task is a secondary objective you add purely to improve the primary task — you don't care about the auxiliary task's performance at test time. Classic examples include predicting depth or surface normals alongside semantic segmentation, adding a language-modeling objective on top of a downstream NLP task, and reconstructing the input as a self-supervised side objective.

The rule of thumb: a good auxiliary task relies on features that complement the primary task's but are easier to learn. The auxiliary task's easily learned features then reach the primary task through Caruana's eavesdropping mechanism.

8. MTL in the Wild

Multi-task learning isn't an academic curiosity — it's the backbone of modern AI systems.

LLMs are multi-task learners. GPT-2's paper was literally titled "Language Models are Unsupervised Multitask Learners." By training on next-token prediction over a massive corpus, the model implicitly learns translation, summarization, question answering, and dozens of other NLP tasks. Google's T5 made this explicit by converting every NLP task into a text-to-text format: one model, one loss function, 20+ tasks.

CLIP trains vision and language encoders jointly via a contrastive loss: match image-text pairs in a shared embedding space. This single objective implicitly teaches the model object recognition, scene understanding, attribute detection, spatial reasoning, and more — MTL in disguise.

Autonomous driving systems use a shared backbone with task-specific heads for lane detection, object detection, trajectory prediction, and depth estimation. These tasks are deeply related — object boundaries inform lane boundaries, depth informs scale — making them ideal for hard parameter sharing.

Recommendation systems jointly predict click probability, watch time, like probability, and share probability. These objectives often conflict (optimizing for clicks alone produces clickbait), so techniques like PCGrad and uncertainty weighting are actively used in production.

Try It: Shared vs Separate Networks

Two classification tasks train on the same 2D inputs. Adjust task relatedness to see when sharing a hidden layer helps vs hurts compared to separate networks.


9. The Full MTL Training Loop

Let's put it all together: a training loop that combines hard parameter sharing with uncertainty weighting and PCGrad.

def train_mtl(model, X, y_reg, y_cls, steps=1000, lr=0.001):
    """Full MTL training loop with uncertainty weighting + PCGrad."""
    uw = UncertaintyWeighting(n_tasks=2)
    history = {'reg': [], 'cls': [], 'w_reg': [], 'w_cls': []}

    for step in range(steps):
        y_r, y_c, h2 = model.forward(X)
        # Compute per-task losses
        loss_reg = np.mean((y_r - y_reg) ** 2)
        eps = 1e-7
        loss_cls = -np.mean(y_cls * np.log(y_c + eps)
                            + (1 - y_cls) * np.log(1 - y_c + eps))
        # Update uncertainty weights
        uw.update([loss_reg, loss_cls], lr=0.01)
        weights = uw.effective_weights()
        # Compute per-task gradients (via backprop, shown conceptually)
        g_reg = compute_grad(model, X, y_reg, task='reg')
        g_cls = compute_grad(model, X, y_cls, task='cls')
        # Apply PCGrad to resolve conflicts
        update_dir = pcgrad([weights[0] * g_reg, weights[1] * g_cls])
        # Update shared parameters
        apply_gradients(model, update_dir, lr=lr)
        # Log
        history['reg'].append(loss_reg)
        history['cls'].append(loss_cls)
        history['w_reg'].append(weights[0])
        history['w_cls'].append(weights[1])
        if step % 200 == 0:
            cos_sim = np.dot(g_reg, g_cls) / (
                np.linalg.norm(g_reg) * np.linalg.norm(g_cls) + eps)
            print(f"Step {step}: L_reg={loss_reg:.3f}, L_cls={loss_cls:.3f}, "
                  f"w=[{weights[0]:.3f}, {weights[1]:.3f}], cos={cos_sim:.3f}")
    return history

The loop follows a clean rhythm: forward → per-task losses → update weights → per-task gradients → PCGrad surgery → apply update. The uncertainty weighting adjusts the loss scales automatically, while PCGrad ensures the combined gradient doesn't harm either task.

Try It: Gradient Conflict Arena

Two tasks define loss surfaces over a 2D parameter space. Click anywhere to place the "current parameters" and see each task's gradient. Toggle PCGrad to see how gradient surgery resolves conflicts.


10. Conclusion

Multi-task learning is one of the oldest ideas in deep learning, and also one of the most relevant. Nearly every modern AI system — from GPT to CLIP to self-driving cars — is a multi-task learner in some form. The core idea is simple: sharing representations between related tasks creates an inductive bias that improves generalization.

The practical challenge is managing the conflicts that arise when tasks disagree. We've seen three families of solutions: uncertainty weighting (learn task weights from Bayesian uncertainty), GradNorm (equalize training rates), and PCGrad (surgically remove conflicting gradient components). Each addresses a different aspect of the problem, and they compose well together.

The next time you train a model, ask: is there a related task I could add? Even a simple auxiliary objective — predicting a cheap side label, reconstructing the input, enforcing a consistency constraint — might give your primary task a free regularization boost. In multi-task learning, more is often more.
