
Federated Learning from Scratch

1. Why Centralized Training Fails

You want to build a medical AI that detects tumors from MRI scans. The ideal approach: collect scans from 50 hospitals, pool them on a GPU cluster, and train one massive model. One problem: you can't. HIPAA tightly restricts sharing identifiable patient data across institutions. GDPR restricts transfers of personal data outside the EU. And even if the law allowed it, hospitals don't trust each other or you. Your 50-hospital super-dataset doesn't exist and never will.

But federated learning lets you train on it anyway, without any hospital sending a single scan to anyone.

Three barriers make centralized training impossible in practice:

  1. Privacy and regulation. GDPR, HIPAA, CCPA, and China's PIPL all restrict data movement. Even within a single company, data silos exist between divisions.
  2. Communication bandwidth. A fleet of 10 million phones generates terabytes of data daily. Uploading it all to the cloud over metered mobile connections is impractical.
  3. Data sovereignty and trust. Organizations won't share competitive data. Users don't want their personal data on someone else's servers.

The federated learning paradigm flips the script: keep data where it is, bring the model to the data, and only share model updates (gradients or weights). Two main settings exist: cross-device (millions of mobile phones, each with tiny datasets and unreliable connections) and cross-silo (a handful of hospitals or banks, each with large datasets and reliable connections).

The key promise: the final model should come close to one trained on the pooled data, without ever creating the pool. Let's see why that promise matters:

import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-np.clip(z, -500, 500)))

# Five hospitals, each with 40 patients in different feature regions
np.random.seed(42)
true_w = np.array([0.8, -1.2])  # true population-level relationship

hospitals = []
for i in range(5):
    rng = np.random.RandomState(i)
    shift = rng.randn(2) * 1.5   # each hospital's patient population differs
    X = rng.randn(40, 2) + shift
    y = (sigmoid(X @ true_w) > rng.random(40)).astype(float)
    hospitals.append((X, y))

# Pool all data — the dream scenario, but illegal
X_all = np.vstack([h[0] for h in hospitals])
y_all = np.concatenate([h[1] for h in hospitals])

def train(X, y, lr=0.5, epochs=100):
    w = np.zeros(2)
    for _ in range(epochs):
        p = sigmoid(X @ w)
        grad = X.T @ (p - y) / len(y)
        w -= lr * grad
    return w

def acc(X, y, w):
    return ((sigmoid(X @ w) > 0.5) == y).mean()

w_pooled = train(X_all, y_all)
print(f"Pooled (200 patients): {acc(X_all, y_all, w_pooled):.0%}")

for i, (X, y) in enumerate(hospitals):
    w_local = train(X, y)
    print(f"Hospital {i} (40 patients): {acc(X_all, y_all, w_local):.0%}")
# Pooled (200 patients): 85%   — sees the full picture
# Hospital 0 (40 patients): 67%  — limited, biased view
# Hospital 1 (40 patients): 63%
# Hospital 2 (40 patients): 70%
# Hospital 3 (40 patients): 61%
# Hospital 4 (40 patients): 66%

Every hospital's local model underperforms the pooled model by 15–24 percentage points. Each one only sees its own slice of the patient population. Can we close that gap without any hospital sharing a single record?

2. The FedAvg Algorithm

In 2017, McMahan et al. at Google proposed Federated Averaging (FedAvg), the algorithm that launched federated learning as a field. The idea is elegant: instead of moving data to the model, move the model to the data.

The FedAvg loop works as follows:

  1. Server initializes a global model $w_0$.
  2. Each round $t$: server selects a random fraction $C$ of the $K$ clients and sends them the current global model $w_t$.
  3. Each selected client $k$ trains the model on its local data for $E$ local epochs using SGD, producing a local model $w_t^k$.
  4. Clients send their updated models (or the deltas $\Delta w^k = w_t^k - w_t$) back to the server.
  5. Server aggregates: $w_{t+1} = \sum_k \frac{n_k}{n} w_t^k$, a weighted average where $n_k$ is client $k$'s dataset size and $n = \sum_k n_k$ over the participating clients.
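The five steps above can be sketched as a single round function. This is a minimal illustration under stated assumptions: `fedavg_round` and its `local_update` argument are hypothetical names, local training is abstracted into a callable, and the weights $n_k / n$ are computed over the sampled clients only. Unlike the full-participation code later in this post, it samples a fraction C of the clients each round.

```python
import numpy as np

def fedavg_round(w_global, clients, local_update, C=0.1, rng=None):
    """One FedAvg round with random client sampling (illustrative sketch).

    clients: list of (X_k, y_k) tuples, one per client (K clients total).
    local_update: hypothetical callable (w, X, y) -> locally trained weights.
    C: fraction of the K clients sampled this round (at least one).
    """
    rng = np.random.default_rng() if rng is None else rng
    K = len(clients)
    m = max(1, int(round(C * K)))
    selected = rng.choice(K, size=m, replace=False)   # step 2: sample clients

    local_models, sizes = [], []
    for k in selected:
        X_k, y_k = clients[k]
        # Step 3: each sampled client trains from the current global model
        local_models.append(local_update(w_global.copy(), X_k, y_k))
        sizes.append(len(y_k))

    # Step 5: weights n_k / n, with n summed over the sampled clients only
    weights = np.asarray(sizes, dtype=float) / sum(sizes)
    w_next = sum(wt * w_k for wt, w_k in zip(weights, local_models))
    return w_next, selected

# Usage with 10 toy clients of different sizes (client k holds k + 1 labels equal to k)
clients = [(np.ones((k + 1, 2)), np.full(k + 1, float(k))) for k in range(10)]

def toy_update(w, X, y):
    """Hypothetical stand-in for local SGD: shift weights by the label mean."""
    return w + y.mean()

w_next, chosen = fedavg_round(np.zeros(2), clients, toy_update, C=0.3,
                              rng=np.random.default_rng(0))
print(sorted(chosen.tolist()), w_next)
```

Passing a real local-SGD routine as `local_update` recovers the usual loop; the toy stand-in only makes the weighted average easy to check by hand.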

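Why does averaging models work? With a single local gradient step per round, the weighted average of the client models is exactly one gradient step on the pooled data, because weighting each client's mean gradient by $n_k / n$ reproduces the pooled mean gradient. With several local steps that identity no longer holds exactly. Under the IID assumption (each client's data is drawn from the same distribution), however, each client's gradient is an unbiased estimator of the global gradient, so the local models stay close to one another, and averaging them reduces variance and moves toward the same optimum as centralized training.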
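The one-step case can be checked numerically. The sketch below uses made-up data for three clients of different sizes (all names are illustrative) and compares averaging one-step client models against a single full-batch gradient step on the pooled data. The identity holds for any data split; heterogeneity only starts to matter once clients take several local steps.

```python
import numpy as np

def logistic_grad(w, X, y):
    """Mean gradient of the logistic loss on one client's data."""
    p = 1.0 / (1.0 + np.exp(-(X @ w)))
    return X.T @ (p - y) / len(y)

rng = np.random.default_rng(0)
# Three toy clients with 30, 50, and 20 examples
clients = [(rng.normal(size=(n, 2)), rng.integers(0, 2, size=n).astype(float))
           for n in (30, 50, 20)]
w = np.array([0.3, -0.2])
lr = 0.5
n_total = sum(len(y) for _, y in clients)

# (a) Each client takes ONE gradient step; server takes the n_k/n-weighted average
avg_model = sum(len(y) / n_total * (w - lr * logistic_grad(w, X, y))
                for X, y in clients)

# (b) Centralized: one full-batch gradient step on the pooled data
X_all = np.vstack([X for X, _ in clients])
y_all = np.concatenate([y for _, y in clients])
central = w - lr * logistic_grad(w, X_all, y_all)

print(np.allclose(avg_model, central))  # True
```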

The key tradeoff is between local epochs E and communication cost. More local epochs mean fewer communication rounds (cheaper) but greater client drift: each client wanders further from the global objective before corrections are applied. FedSGD is the special case where each client takes a single full-batch gradient step per round (E=1 with the whole local dataset as one batch): minimal drift but maximum communication. FedAvg with E>1 is the practical choice because communication, not computation, is usually the bottleneck.

# Continuing with hospitals, X_all, y_all, sigmoid, acc from above

def fedavg(hospitals, rounds=30, local_epochs=5, lr=0.5):
    """Federated Averaging — McMahan et al. 2017."""
    w_global = np.zeros(2)
    history = []

    for t in range(rounds):
        updates = []
        sizes = []
        for X_k, y_k in hospitals:
            w_local = w_global.copy()

            # Local training: E epochs of gradient descent
            for _ in range(local_epochs):
                p = sigmoid(X_k @ w_local)
                grad = X_k.T @ (p - y_k) / len(y_k)
                w_local -= lr * grad

            updates.append(w_local)
            sizes.append(len(y_k))

        # Weighted aggregation: each client weighted by dataset size
        sizes = np.array(sizes, dtype=float)
        sizes /= sizes.sum()
        w_global = sum(s * u for s, u in zip(sizes, updates))
        history.append(acc(X_all, y_all, w_global))

    return w_global, history

w_fed5, hist5 = fedavg(hospitals, rounds=30, local_epochs=5)
w_fed1, hist1 = fedavg(hospitals, rounds=30, local_epochs=1)  # FedSGD

print(f"FedAvg  (E=5, 30 rounds): {acc(X_all, y_all, w_fed5):.0%}")
print(f"FedSGD  (E=1, 30 rounds): {acc(X_all, y_all, w_fed1):.0%}")
print(f"Centralized (pooled):     {acc(X_all, y_all, w_pooled):.0%}")
# FedAvg  (E=5, 30 rounds): 84% (close to centralized)
# FedSGD  (E=1, 30 rounds): 82% (same communication, but only 30 gradient steps vs. 150)
# Centralized (pooled):     85% (the reference point)
# No hospital shared a single patient record.

FedAvg reaches 84% accuracy, within one point of the pooled baseline, while each hospital's data never leaves its walls. With the same 30 rounds (and therefore the same communication), FedSGD takes only one gradient step per round instead of five and ends slightly lower; to take as many gradient steps as FedAvg, it would need five times as many rounds, and five times the communication. This is the core insight: local computation substitutes for communication.

Try It: Federated Training Simulator

Watch FedAvg train a model across distributed clients. Each small plot shows a client's local data with the global decision boundary. The chart below tracks accuracy over rounds.

3. Communication Efficiency

In cross-device federated learning, millions of phones communicate over slow, metered mobile connections. A modern neural network has millions of parameters — sending full model updates every round is prohibitively expensive. If each of 10 million devices sends a 100MB model update, that's a petabyte per round.

Top-k gradient sparsification (Aji & Heafield 2017) offers a simple fix: instead of sending the full gradient (or model delta) vector, each client sends only the k largest components by magnitude, along with their indices. The rest are accumulated in a local error feedback buffer and added to the next round's update. In practice this keeps training on track while cutting bandwidth by 90–99%.

Quantization takes a different angle: instead of reducing the number of components sent, reduce their precision. SignSGD (Bernstein et al. 2018) sends just 1 bit per component: the sign of the gradient. QSGD (Alistarh et al. 2017) uses unbiased stochastic rounding to reduce 32-bit floats to 8 or fewer bits.
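Here is a simplified sketch of both quantizers, assuming NumPy vectors as updates. `qsgd_quantize` follows QSGD's idea of unbiased stochastic rounding to s levels of the vector's norm but skips the integer encoding the paper uses, and `sign_compress` maps exact zeros to 0 rather than picking a sign; both are illustrative, not reference implementations.

```python
import numpy as np

def qsgd_quantize(v, s, rng):
    """QSGD-style stochastic quantization of v to s levels (simplified sketch).

    Each |v_i| / ||v|| in [0, 1] is rounded to a neighbouring multiple of 1/s,
    rounding up with probability equal to its distance above the lower level,
    so the quantizer is unbiased: E[q(v)] = v.
    """
    norm = np.linalg.norm(v)
    if norm == 0:
        return np.zeros_like(v)
    scaled = np.abs(v) / norm * s            # position in [0, s]
    lower = np.floor(scaled)
    prob_up = scaled - lower                 # chance of rounding up
    levels = lower + (rng.random(v.shape) < prob_up)
    # Only the norm, the signs, and small integer levels would need sending
    return norm * np.sign(v) * levels / s

def sign_compress(v):
    """SignSGD-style 1-bit compression: keep only the sign of each entry."""
    return np.sign(v)

rng = np.random.default_rng(0)
v = np.array([0.5, -0.2, 0.05, 0.0])
samples = np.array([qsgd_quantize(v, s=4, rng=rng) for _ in range(20000)])
print(samples.mean(axis=0))   # close to v, because the quantizer is unbiased
print(sign_compress(v))       # signs only: 1, -1, 1, 0
```

Averaging many quantized copies recovers the original vector, so QSGD's rounding noise tends to average out across clients and rounds. Sign compression, by contrast, is biased: it throws magnitudes away entirely.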

Federated dropout (Caldas et al. 2018) reduces both computation and communication: each client trains a random subnetwork (a subset of neurons per layer) rather than the full model. Different clients train different submodels, and the server reconstructs the full model from partial updates.
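A minimal sketch of the server side of federated dropout for a two-layer MLP, assuming dropout is applied to hidden units: the server slices out a random half of the hidden units for each client, then averages each unit's update over the clients that trained it. `extract_submodel`, `aggregate_submodels`, and the constant-shift "training" are hypothetical stand-ins, not Caldas et al.'s code.

```python
import numpy as np

def extract_submodel(W1, W2, keep):
    """Slice a 2-layer MLP down to the hidden units listed in `keep`.

    W1: (d_in, d_hidden) input-to-hidden weights; W2: (d_hidden, d_out).
    """
    return W1[:, keep].copy(), W2[keep, :].copy()

def aggregate_submodels(W1, W2, client_results):
    """Average each hidden unit's update over the clients that trained it.

    client_results: list of (keep, W1_sub, W2_sub) returned by clients.
    Units that no client trained this round keep their old weights.
    """
    d1, d2 = np.zeros_like(W1), np.zeros_like(W2)
    counts = np.zeros(W1.shape[1])
    for keep, W1_sub, W2_sub in client_results:
        d1[:, keep] += W1_sub - W1[:, keep]
        d2[keep, :] += W2_sub - W2[keep, :]
        counts[keep] += 1
    safe = np.maximum(counts, 1)             # avoid dividing by zero
    return W1 + d1 / safe, W2 + d2 / safe[:, None]

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8))                 # 4 inputs -> 8 hidden units
W2 = rng.normal(size=(8, 2))                 # 8 hidden units -> 2 outputs

results = []
for shift in (1.0, 3.0):                     # two clients
    keep = np.sort(rng.choice(8, size=4, replace=False))  # 50% of hidden units
    W1_sub, W2_sub = extract_submodel(W1, W2, keep)        # 4x4 and 4x2 slices
    # Hypothetical stand-in for local training on the smaller submodel
    results.append((keep, W1_sub + shift, W2_sub + shift))

W1_new, W2_new = aggregate_submodels(W1, W2, results)
print(W1_new.shape, W2_new.shape)
```

Each client only ever downloads, trains, and uploads the sliced matrices, which is where the computation and communication savings come from.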

def fedavg_sparse(hospitals, rounds=30, local_epochs=5, lr=0.5, top_k=1):
    """FedAvg with top-k sparsification and error feedback."""
    w_global = np.zeros(2)
    residuals = [np.zeros(2) for _ in hospitals]  # error feedback buffers
    bytes_sent = 0

    for t in range(rounds):
        updates, sizes = [], []
        for i, (X_k, y_k) in enumerate(hospitals):
            w_local = w_global.copy()
            for _ in range(local_epochs):
                p = sigmoid(X_k @ w_local)
                grad = X_k.T @ (p - y_k) / len(y_k)
                w_local -= lr * grad

            # Delta + accumulated residual from previous rounds
            delta = (w_local - w_global) + residuals[i]

            # Keep only top-k components by magnitude
            sparse = np.zeros_like(delta)
            top_idx = np.argsort(np.abs(delta))[-top_k:]
            sparse[top_idx] = delta[top_idx]

            residuals[i] = delta - sparse  # save what wasn't sent
            updates.append(sparse)
            sizes.append(len(y_k))
            bytes_sent += top_k * 8  # k float64 values sent (index overhead ignored)

        sizes = np.array(sizes, dtype=float)
        sizes /= sizes.sum()
        w_global += sum(s * u for s, u in zip(sizes, updates))

    final_acc = acc(X_all, y_all, w_global)
    return final_acc, bytes_sent

full_acc, full_bytes = fedavg_sparse(hospitals, top_k=2)  # full (2 params)
sp_acc, sp_bytes = fedavg_sparse(hospitals, top_k=1)      # 50% sparse

print(f"Full updates:  {full_acc:.0%} accuracy, {full_bytes:,} bytes")
print(f"50% sparse:    {sp_acc:.0%} accuracy, {sp_bytes:,} bytes")
# Full updates:  84% accuracy, 2,400 bytes
# 50% sparse:    82% accuracy, 1,200 bytes — half the bandwidth!
# (With a real 100M-param model, this saves gigabytes per round)

The error feedback mechanism is crucial: without it, the unsent components are simply discarded, which biases the updates and can leave the model at a worse solution. With error feedback, the residuals accumulate until they're large enough to be selected, so nothing is permanently lost: convergence is preserved, just spread over more rounds.
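A toy illustration of why the residual buffer matters, assuming a client that produces the same two-component update every round and may send only one component (top-1). Without error feedback, the small component is never transmitted; with it, what has been sent plus the current residual always equals the true running total. The names here are illustrative.

```python
import numpy as np

def top_k(v, k):
    """Keep the k largest-magnitude entries of v and zero out the rest."""
    out = np.zeros_like(v)
    idx = np.argsort(np.abs(v))[-k:]
    out[idx] = v[idx]
    return out

update = np.array([1.0, 0.1])   # toy client update, identical every round
rounds = 100

sent_plain = np.zeros(2)                      # top-1, no error feedback
sent_ef, residual = np.zeros(2), np.zeros(2)  # top-1 with error feedback
for _ in range(rounds):
    sent_plain += top_k(update, 1)
    corrected = update + residual             # add back what was held over
    sent = top_k(corrected, 1)
    residual = corrected - sent               # hold over what was not sent
    sent_ef += sent

true_total = rounds * update
print("no feedback:  ", sent_plain)           # second entry stays at 0
print("with feedback:", sent_ef, "residual:", residual)
```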

4. The Non-IID Data Problem

Everything above assumed IID data — each client's data is drawn from the same distribution. In practice, this assumption almost never holds. Phone users type different words. Hospitals see different diseases. Banks serve different customer demographics. This data heterogeneity is the silent killer of federated learning.

Kairouz et al. (2021) survey the many ways federated data departs from IID; five common forms are:

  1. Label distribution skew — clients have different class proportions (one hospital sees 90% lung cancer, another 90% breast cancer)
  2. Feature distribution skew — same labels, different features (hospitals use different imaging equipment)
  3. Quantity skew — clients have vastly different dataset sizes (power users vs. casual users)
  4. Concept shift — the same features map to different labels across clients
  5. Temporal skew — data arrives at different times on different clients
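Label distribution skew is often simulated by giving each class's examples to clients in Dirichlet-distributed proportions, a common benchmarking convention rather than something from this post. In the sketch below, `dirichlet_label_split` is a hypothetical helper and the labels are random toy data: small alpha concentrates each client on a few classes, and large alpha approaches an IID split.

```python
import numpy as np

def dirichlet_label_split(labels, num_clients, alpha, rng):
    """Split example indices across clients with Dirichlet label skew.

    For each class, the clients' shares are drawn from Dirichlet(alpha), so
    small alpha gives each client a few dominant classes.
    """
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        shares = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(shares)[:-1] * len(idx)).astype(int)
        for k, part in enumerate(np.split(idx, cuts)):
            client_idx[k].extend(part.tolist())
    return [np.array(ix, dtype=int) for ix in client_idx]

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=600)        # toy labels for 3 classes
for alpha in (100.0, 0.1):
    parts = dirichlet_label_split(labels, num_clients=4, alpha=alpha, rng=rng)
    # Per-client class proportions (rows: clients, columns: classes)
    props = [np.bincount(labels[p], minlength=3) / max(len(p), 1) for p in parts]
    print(f"alpha={alpha}:", np.round(props, 2))
```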

Non-IID data breaks FedAvg because local gradients point in different directions. Each client optimizes for its own data distribution, drifting away from the global objective. When the server averages these divergent models, the result is a poor compromise for everyone. This phenomenon is called client drift. In extreme cases, such as each client holding data from only one class, FedAvg's accuracy can collapse or training can fail to converge.
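Client drift shows up even in one dimension. The sketch below assumes two hypothetical clients with quadratic losses of different curvature, runs full-participation FedAvg with equal weights, and reports where it settles for different numbers of local steps. With one step it finds the true optimum of the average loss (0.8); with more local steps, the fixed point slides toward the midpoint of the two local optima.

```python
import numpy as np

# Two toy clients with losses f_i(w) = h_i / 2 * (w - a_i)^2
h = np.array([1.0, 4.0])            # curvature of each client's loss
a = np.array([0.0, 1.0])            # each client's local optimum
w_star = (h * a).sum() / h.sum()    # optimum of the average loss: 0.8

def fedavg_fixed_point(local_steps, lr=0.1, rounds=500):
    """Run FedAvg with equal client weights and return where it settles."""
    w = 0.0
    for _ in range(rounds):
        local = np.full(2, w)                   # broadcast global model
        for _ in range(local_steps):
            local -= lr * h * (local - a)       # each client's own gradient step
        w = local.mean()                        # server averages
    return w

for E in (1, 5, 20):
    print(f"E={E:2d}: settles at {fedavg_fixed_point(E):.3f} (optimum {w_star})")
```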

FedProx (Li et al. 2020) fixes this by adding a proximal term to each client's local loss: $\frac{\mu}{2} \lVert w - w_{\text{global}} \rVert^2$. This penalizes the local model for drifting too far from the global model, essentially a leash that keeps clients anchored. If you've read our continual learning post, this should look familiar: it's the same idea as Elastic Weight Consolidation (EWC), where changes to important weights are penalized. The difference is that EWC weights each parameter by its estimated importance, while FedProx penalizes every weight equally for deviating from the global model.

SCAFFOLD (Karimireddy et al. 2020) takes a more principled approach using control variates. Each client maintains a correction term that estimates the difference between its local gradient and the global gradient. The corrected update provably removes client drift, giving convergence rates that do not degrade with data heterogeneity, at the cost of 2x communication per round (clients send both a model update and a control-variate update).
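A compact sketch of SCAFFOLD on a toy problem, assuming two hypothetical clients with quadratic losses, full participation, and a server step size of 1. It follows the paper's corrected local step $y \leftarrow y - \eta\,(g_i(y) - c_i + c)$ and its "option II" control-variate update; it is an illustration, not a reference implementation.

```python
import numpy as np

# Toy heterogeneous clients: f_i(w) = h_i / 2 * (w - a_i)^2; optimum of the mean is 0.8
h = np.array([1.0, 4.0])
a = np.array([0.0, 1.0])

def grad(y):
    """Each client's gradient, evaluated at that client's own point y[i]."""
    return h * (y - a)

def run(rounds=100, local_steps=10, lr=0.1, scaffold=True):
    x = 0.0                    # server model
    c = 0.0                    # server control variate (mean of the c_i)
    c_i = np.zeros(2)          # client control variates
    for _ in range(rounds):
        y = np.full(2, x)
        corr = (c - c_i) if scaffold else np.zeros(2)
        for _ in range(local_steps):
            y -= lr * (grad(y) + corr)          # drift-corrected local step
        if scaffold:
            # Option II: average corrected gradient along the local path
            c_new = c_i - c + (x - y) / (local_steps * lr)
            delta_c = c_new - c_i
            c_i = c_new
            c += delta_c.mean()                 # full participation
        x += (y - x).mean()                     # server step size 1
    return x

print(round(run(scaffold=False), 3))   # 0.604: plain FedAvg, biased by drift
print(round(run(scaffold=True), 3))    # 0.8: SCAFFOLD reaches the true optimum
```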

# Create non-IID hospitals: skew each hospital's label distribution
noniid_hospitals = []
for i in range(5):
    r = np.random.RandomState(i + 100)
    X = r.randn(40, 2)
    p_true = sigmoid(X @ true_w)
    # Skew: hospitals 0-1 mostly class 0, hospitals 3-4 mostly class 1
    bias = (i - 2) * 0.7
    y = (p_true + bias > r.random(40)).astype(float)
    noniid_hospitals.append((X, y))

def fedprox(hospitals, rounds=30, local_epochs=5, lr=0.5, mu=0.0):
    """FedAvg with optional FedProx proximal term (mu=0 is plain FedAvg)."""
    w_global = np.zeros(2)
    history = []

    for t in range(rounds):
        updates, sizes = [], []
        for X_k, y_k in hospitals:
            w_local = w_global.copy()
            for _ in range(local_epochs):
                p = sigmoid(X_k @ w_local)
                grad = X_k.T @ (p - y_k) / len(y_k)
                grad += mu * (w_local - w_global)  # proximal term
                w_local -= lr * grad
            updates.append(w_local)
            sizes.append(len(y_k))

        sizes = np.array(sizes, dtype=float)
        sizes /= sizes.sum()
        w_global = sum(s * u for s, u in zip(sizes, updates))
        history.append(acc(X_all, y_all, w_global))

    return history

hist_avg = fedprox(noniid_hospitals, mu=0.0)   # plain FedAvg
hist_prox = fedprox(noniid_hospitals, mu=0.1)  # FedProx

print(f"Non-IID FedAvg:  final acc = {hist_avg[-1]:.0%}")
print(f"Non-IID FedProx: final acc = {hist_prox[-1]:.0%}")
# Non-IID FedAvg:  final acc = 71% — client drift hurts
# Non-IID FedProx: final acc = 79% — proximal term anchors clients

The proximal term costs almost nothing computationally — one extra vector subtraction per gradient step — but recovers 8 percentage points of accuracy. The intuition: without the leash, each client sprints toward its own local optimum, and the average of five different optima is nobody's optimum. With FedProx, clients take constrained steps that balance local fit with global consistency.

Try It: Non-IID Data Explorer

Explore how label skew distributes classes unevenly across clients. The left panel shows each client's class balance; the right panel compares convergence. With this simple linear model, the convergence gap is subtle — in practice with deep networks, the gap between Centralized (green), FedAvg (blue), and FedProx (orange) grows dramatically.

5. Secure Aggregation

Federated learning promises privacy because raw data never leaves the client. But here's the uncomfortable truth: model updates can leak private information. Zhu et al. (2019) demonstrated gradient inversion attacks that can reconstruct training images nearly pixel-perfectly from a client's gradient update alone. A malicious or curious server that receives your gradient can, in principle, reconstruct close copies of the images you trained on, especially when the update was computed on a small batch.

This means a federated learning server that sees individual client updates has the theoretical ability to violate the very privacy it was supposed to protect. We need a mechanism that lets the server compute the aggregate of all updates without ever seeing any individual update.

Secure aggregation (Bonawitz et al. 2017) achieves exactly this through pairwise masking. The protocol works as follows: each pair of clients $(i, j)$ with $i < j$ agrees on a random mask $s_{ij}$ (using a shared secret from Diffie-Hellman key exchange). Client $i$ adds $s_{ij}$ to its update; client $j$ subtracts $s_{ij}$. When the server sums all masked updates, the masks cancel perfectly:

$$\sum_i \Big( w_i + \sum_{j > i} s_{ij} - \sum_{j < i} s_{ji} \Big) = \sum_i w_i$$

The server gets the correct aggregate but learns nothing about any individual update. To handle client dropout (phones go offline unpredictably), each client splits the secret key material behind its masks into shares using Shamir's secret sharing and distributes them to the other clients. If a client drops, the surviving clients send enough shares for the server to reconstruct the missing client's masks and cancel them from the aggregate.
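Shamir's scheme hides a secret as the constant term of a random polynomial: any `threshold` points on the polynomial recover it, while fewer points reveal nothing about it. A minimal sketch, assuming integer secrets in a prime field (the names `make_shares` and `reconstruct` are hypothetical). Python's `random` module is used here only for reproducibility; a real deployment would draw coefficients from a cryptographically secure source such as `secrets`.

```python
import random

PRIME = 2**127 - 1   # a Mersenne prime; all arithmetic is in the field mod PRIME

def make_shares(secret, threshold, num_shares, rng):
    """Split `secret` into points on a random degree-(threshold - 1) polynomial."""
    coeffs = [secret] + [rng.randrange(PRIME) for _ in range(threshold - 1)]
    shares = []
    for x in range(1, num_shares + 1):
        y = 0
        for coef in reversed(coeffs):        # Horner's rule, mod PRIME
            y = (y * x + coef) % PRIME
        shares.append((x, y))
    return shares

def reconstruct(shares):
    """Recover the polynomial's value at x = 0 by Lagrange interpolation."""
    secret = 0
    for j, (x_j, y_j) in enumerate(shares):
        num, den = 1, 1
        for m, (x_m, _) in enumerate(shares):
            if m != j:
                num = num * (-x_m) % PRIME
                den = den * (x_j - x_m) % PRIME
        secret = (secret + y_j * num * pow(den, -1, PRIME)) % PRIME
    return secret

rng = random.Random(0)
seed = 123456789   # stands in for the key material behind one client's masks
shares = make_shares(seed, threshold=3, num_shares=5, rng=rng)
print(reconstruct(shares[:3]) == seed)                          # True
print(reconstruct([shares[0], shares[2], shares[4]]) == seed)   # True
```

With a threshold of 3 out of 5, the server can recover a dropped client's key as long as any three peers stay online, while no two colluding peers learn it.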

def secure_aggregate(client_updates, seed=42):
    """Simplified secure aggregation with pairwise masking."""
    K = len(client_updates)
    dim = len(client_updates[0])

    # Each client's masked update starts as its true update
    masked = [u.copy() for u in client_updates]
    rng = np.random.RandomState(seed)

    for i in range(K):
        for j in range(i + 1, K):
            # Clients i,j agree on a random mask (via shared PRNG seed)
            mask = rng.randn(dim) * 10  # large random noise
            masked[i] += mask   # client i adds mask
            masked[j] -= mask   # client j subtracts mask

    print("Individual masked updates (appear random):")
    for i, m in enumerate(masked):
        print(f"  Client {i}: [{m[0]:+7.2f}, {m[1]:+7.2f}]")

    true_agg = sum(client_updates)
    masked_agg = sum(masked)
    print(f"\nTrue aggregate:   [{true_agg[0]:.4f}, {true_agg[1]:.4f}]")
    print(f"Masked aggregate: [{masked_agg[0]:.4f}, {masked_agg[1]:.4f}]")
    print(f"Exact match: {np.allclose(true_agg, masked_agg)}")

# 4 clients with small, sensitive model updates
updates = [np.array([0.12, -0.05]), np.array([-0.08, 0.15]),
           np.array([0.05, 0.03]),  np.array([-0.03, -0.10])]
secure_aggregate(updates)
# Individual masked updates (appear random):
#   Client 0: [ +18.34,  -12.71]  — looks nothing like [0.12, -0.05]
#   Client 1: [  -5.21,  +22.38]
#   Client 2: [  -8.09,   -3.88]
#   Client 3: [  -4.98,   -5.76]
# True aggregate:   [0.0600, 0.0300]
# Masked aggregate: [0.0600, 0.0300]  — exact!
# Exact match: True
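The simulation above draws real-valued masks from one shared RNG. Closer to how Bonawitz et al. describe the protocol, the sketch below derives each pairwise mask from a Diffie-Hellman shared secret and does all arithmetic on fixed-point integers modulo 2^32, so that with truly random masks each masked update on its own would be uniformly distributed. Assumptions: the DH parameters are toy values with no real security, NumPy's generator stands in for a cryptographic PRG, the fixed-point scale is arbitrary, and there is no dropout handling or self-mask; all names are hypothetical.

```python
import hashlib
import numpy as np

P = 2**127 - 1     # toy Diffie-Hellman modulus (NOT a secure choice)
G = 3
MOD = 2**32        # masked updates live in the integers mod 2^32
SCALE = 2**16      # fixed-point scale for encoding real-valued updates

def shared_seed(my_secret, their_public):
    """Both ends of a DH exchange derive the same 64-bit PRG seed."""
    shared = pow(their_public, my_secret, P)
    digest = hashlib.sha256(shared.to_bytes(16, "big")).digest()
    return int.from_bytes(digest[:8], "big")

def masked_update(i, update, secrets_, publics):
    """Client i's fixed-point update plus (or minus) one mask per peer, mod 2^32."""
    enc = np.round(update * SCALE).astype(np.int64) % MOD
    for j in range(len(publics)):
        if j == i:
            continue
        rng = np.random.default_rng(shared_seed(secrets_[i], publics[j]))
        mask = rng.integers(0, MOD, size=update.shape, dtype=np.int64)
        # Lower index adds, higher index subtracts, so each pair cancels
        enc = (enc + mask) % MOD if i < j else (enc - mask) % MOD
    return enc

updates = [np.array([0.12, -0.05]), np.array([-0.08, 0.15]),
           np.array([0.05, 0.03]), np.array([-0.03, -0.10])]
rand = np.random.default_rng(7)
secrets_ = [int(rand.integers(2, 2**62)) for _ in updates]   # toy private keys
publics = [pow(G, s, P) for s in secrets_]                   # published keys

masked = [masked_update(i, u, secrets_, publics) for i, u in enumerate(updates)]
total = sum(masked) % MOD                                  # server-side sum
total = np.where(total >= MOD // 2, total - MOD, total)    # back to signed ints
print(total / SCALE)   # approximately [0.06, 0.03]; only the sum is revealed
```

Working modulo 2^32 rather than with floats is what lets a uniform mask hide the value completely; the small error in the result comes from rounding each update to the fixed-point grid.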

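The real overhead of secure aggregation comes from the pairwise setup: every client agrees on a key with every other client, so a round needs O(K²) pairwise key agreements in total (O(K) per client). Bonawitz et al. report about 1.73x communication expansion over sending updates in the clear for 1,024 users with 2^20-dimensional vectors, which is comfortably manageable for cross-silo settings with K=10 to 100 clients. For cross-device rounds with K=1000+ clients, production systems such as Google's keep costs down through hierarchical aggregation, running secure aggregation within smaller groups of clients.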

6. Federated Learning Meets Differential Privacy

Secure aggregation hides individual updates from the server, but the final trained model itself still encodes information about individual training examples. Membership inference attacks can determine whether a specific data point was in the training set. Model inversion can reconstruct representative examples of a class. Secure aggregation protects the process; we need something that protects the outcome.

Differential privacy provides that mathematical guarantee. DP-FedAvg (McMahan et al. 2018) modifies the FedAvg loop in two critical ways:

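  1. Clip each client's model update to have L2 norm at most $S$ (bounding sensitivity: no single client can have outsized influence).
  2. Add Gaussian noise $\mathcal{N}(0, \sigma^2 S^2 I)$ to the sum of the clipped updates (equivalently, standard deviation $\sigma S / K$ on their average over $K$ clients), with the noise multiplier $\sigma$ chosen for the desired $(\varepsilon, \delta)$-DP guarantee.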

A crucial advantage in federated learning is privacy amplification by subsampling. Each round randomly selects a fraction C of clients. With C=0.001 (1,000 of 1 million devices), each round's privacy cost is far smaller than it would be if every device participated. Over thousands of rounds, a large enough population lets you stay within a single-digit budget (ε < 10) while maintaining useful model quality.
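A back-of-the-envelope version of this amplification, assuming a mechanism that would be ε-DP per round if every user participated, and using the standard subsampling bound ln(1 + q(e^ε − 1)) together with naive composition across rounds. Production accountants (for example Rényi-DP or moments accounting for the Gaussian mechanism) track the (ε, δ) cost far more tightly; the numbers below only show the order of magnitude, and the per-round ε of 1.0 is a hypothetical value.

```python
import math

def amplified_epsilon(eps, q):
    """Privacy cost of an eps-DP step run on a random q-fraction of users."""
    return math.log1p(q * math.expm1(eps))

eps_round = 1.0   # hypothetical per-round epsilon with full participation
q = 0.001         # C = 0.001: 1,000 of 1 million devices per round
rounds = 1000

eps_sub = amplified_epsilon(eps_round, q)
print(f"per round: {eps_sub:.5f}")                                  # 0.00172
print(f"{rounds} rounds, naive composition: {rounds * eps_sub:.2f}")  # 1.72
```

Without subsampling, the same naive composition over 1,000 rounds would give ε = 1000; sampling 0.1% of users per round shrinks that by almost three orders of magnitude.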

An important distinction: user-level vs. example-level DP. Standard DP-SGD (from our differential privacy post) protects individual training examples. In federated learning, you typically want user-level DP — protecting all of a user's data, not just one example. This requires clipping the entire user update, which means more noise is needed but provides a fundamentally stronger guarantee.
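The difference between the two clipping granularities shows up in a worst case where all of one user's examples push in the same direction. The sketch assumes per-example gradient vectors and treats the user's summed gradient as a stand-in for their model update; `clip` is a hypothetical helper.

```python
import numpy as np

def clip(v, S):
    """Scale v down so its L2 norm is at most S (leave it alone otherwise)."""
    norm = np.linalg.norm(v)
    return v if norm <= S else v * (S / norm)

S = 1.0
# Worst case: one user's 50 examples all push the same way (norm 5 each)
per_example = np.tile(np.array([3.0, 4.0, 0.0]), (50, 1))

# Example-level (DP-SGD style): clip each example's gradient, then sum
example_level = sum(clip(g, S) for g in per_example)
# User-level (DP-FedAvg style): clip the user's whole contribution once
user_level = clip(per_example.sum(axis=0), S)

print(np.linalg.norm(example_level))  # about 50: this user shifts the sum by 50 * S
print(np.linalg.norm(user_level))     # at most S, however much data the user holds
```

Under example-level clipping, hiding this user would require noise that scales with the number of examples they can hold; user-level clipping caps the whole contribution at S.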

def dp_fedavg(hospitals, rounds=30, local_epochs=5, lr=0.5,
              clip_norm=1.0, noise_scale=0.0):
    """DP-FedAvg: clipping + Gaussian noise for differential privacy."""
    w_global = np.zeros(2)
    history = []

    for t in range(rounds):
        clipped_updates = []
        for X_k, y_k in hospitals:
            w_local = w_global.copy()
            for _ in range(local_epochs):
                p = sigmoid(X_k @ w_local)
                grad = X_k.T @ (p - y_k) / len(y_k)
                w_local -= lr * grad

            # Clip update to bound sensitivity
            delta = w_local - w_global
            norm = np.linalg.norm(delta)
            if norm > clip_norm:
                delta = delta * clip_norm / norm
            clipped_updates.append(delta)

        # Average clipped updates, then add calibrated noise
        avg_delta = np.mean(clipped_updates, axis=0)
        if noise_scale > 0:
            noise = np.random.randn(2) * noise_scale * clip_norm / len(hospitals)
            avg_delta += noise

        w_global += avg_delta
        history.append(acc(X_all, y_all, w_global))

    return history

hist_no_dp = dp_fedavg(hospitals, noise_scale=0.0)   # no privacy
hist_mod_dp = dp_fedavg(hospitals, noise_scale=1.0)   # moderate
hist_hi_dp = dp_fedavg(hospitals, noise_scale=5.0)    # strong

print(f"No DP (epsilon=inf):  {hist_no_dp[-1]:.0%}")
print(f"Moderate (sigma=1.0): {hist_mod_dp[-1]:.0%}")
print(f"Strong (sigma=5.0):   {hist_hi_dp[-1]:.0%}")
# No DP (epsilon=inf):  84% — best accuracy, no privacy
# Moderate (sigma=1.0): 79% — small cost for meaningful privacy
# Strong (sigma=5.0):   65% — real cost for strong guarantees

This is the privacy-utility tradeoff in action. Stronger privacy (more noise) degrades accuracy. The art of production federated learning is navigating this tradeoff: using enough noise to provide meaningful privacy guarantees while preserving enough signal for the model to learn. Techniques like privacy amplification by subsampling, adaptive clipping, and careful learning rate scheduling help push the Pareto frontier outward.

7. Practical Federated Learning and Open Challenges

Federated learning isn't just theory. Google began using it in Gboard in 2017 and went on to train next-word prediction models across millions of Android phones with FedAvg, with secure aggregation and, in later models, differential privacy layered on top. The federated model matched the quality of the server-trained baseline, without raw typing data leaving the devices. Apple uses federated learning for Siri suggestions, QuickType predictions, Hey Siri speaker recognition, and Health feature personalization.

In healthcare, NVIDIA FLARE enables hospitals to collaboratively train diagnostic models without sharing patient data. The MELLODDY project trained drug discovery models across ten pharmaceutical companies — competitors who would never share proprietary molecular data but could share model updates. In finance, banks train anti-fraud models collaboratively while keeping transaction data strictly siloed.

But significant challenges remain:

Federated learning sits at the intersection of distributed systems, cryptography, differential privacy, and machine learning optimization. It's the answer to a fundamental question: how do we learn from data we can't see? As privacy regulation tightens and data gravity increases, this won't be a niche technique — it will be the default way large-scale models are trained.

8. References & Further Reading

DadOps cross-references: Gradient Descent from Scratch (SGD fundamentals that FedAvg distributes), Neural Networks from Scratch (the models being federated), Optimizers from Scratch (local optimizer choices), Differential Privacy from Scratch (DP-SGD extended to federated settings), Information Theory from Scratch (communication bounds), Knowledge Distillation from Scratch (federated distillation), Continual Learning from Scratch (client drift parallels catastrophic forgetting).