
Optimal Transport from Scratch: Moving Probability Mass at Minimum Cost

The Problem of Comparing Distributions

You have two piles of sand and want to reshape one into the other. There are infinitely many ways to move the grains — but only one way that minimizes the total work. This optimization problem, first posed by Gaspard Monge in 1781 for military logistics, turned out to be exactly what modern machine learning needed to fix generative models.

If you've worked with probability distributions, you've probably used KL divergence to measure how different two distributions are. KL divergence works great — until it doesn't. When two distributions have non-overlapping support (one distribution assigns zero probability where the other assigns positive probability), KL divergence explodes to infinity. That's not a minor edge case — it's exactly what happens during GAN training when the generator distribution hasn't yet learned to cover the real data distribution.

Jensen-Shannon divergence fixes the infinity problem by averaging the two distributions first, but it introduces a new one: it ignores geometry. A distribution shifted by 1 unit and a distribution shifted by 100 units both get similar JS divergence scores, because JS only cares whether probability mass overlaps — not how far apart the non-overlapping parts are.

The Wasserstein distance solves both problems. Instead of comparing distributions point-by-point, it asks: what's the minimum cost of turning one distribution into the other? It respects the ground metric (nearby distributions are similar), it's always finite (even with non-overlapping support), and it provides smooth gradients everywhere. These properties made it the foundation of Wasserstein GANs and, increasingly, of flow matching and diffusion models.

Let's build optimal transport from scratch — starting with a problem from 1781.

The Monge Problem — One-to-One Transport

In 1781, the French mathematician Gaspard Monge posed the following problem, motivated by military engineering: given a set of source locations (mines) and a set of target locations (fortifications), find the assignment of earth from mines to fortifications that minimizes the total transportation cost.

Formally, we want a transport map T: X → Y that pushes the source distribution μ to the target distribution ν, while minimizing the total cost:

minT ∑i c(xi, T(xi))    subject to T pushing μ to ν

Here c(x, y) is the cost of moving one unit of mass from x to y — typically the squared Euclidean distance ||x - y||². The key constraint is that T is a function: each source point maps to exactly one target point. No mass splitting allowed.

For discrete distributions with equal numbers of points and equal masses, this reduces to the assignment problem: find the one-to-one pairing of source to target points that minimizes total cost. This is a well-studied combinatorial optimization problem with efficient algorithms.

But Monge's formulation has a critical limitation. What if one heavy source point is near two light target points? The optimal solution would be to split the mass — but Monge's formulation forbids it. Worse, if the source has fewer points than the target, there may be no feasible solution at all. We need a more flexible framework.

import numpy as np
from scipy.optimize import linear_sum_assignment

# Source and target points in 2D
source = np.array([[1, 2], [3, 1], [2, 4], [5, 3], [4, 5]])
target = np.array([[6, 2], [8, 4], [7, 1], [9, 5], [5, 6]])

# Cost matrix: squared Euclidean distance between all pairs
n = len(source)
cost_matrix = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        cost_matrix[i, j] = np.sum((source[i] - target[j]) ** 2)

# Solve the assignment problem (Monge's optimal transport)
row_ind, col_ind = linear_sum_assignment(cost_matrix)

total_cost = cost_matrix[row_ind, col_ind].sum()
print(f"Optimal assignment cost: {total_cost:.2f}")
for i, j in zip(row_ind, col_ind):
    print(f"  Source {source[i]} -> Target {target[j]}"
          f"  (cost: {cost_matrix[i, j]:.1f})")

# Output:
# Optimal assignment cost: 89.00
# Source [1 2] -> Target [6 2]  (cost: 25.0)
# Source [3 1] -> Target [7 1]  (cost: 16.0)
# Source [2 4] -> Target [5 6]  (cost: 13.0)
# Source [5 3] -> Target [8 4]  (cost: 10.0)
# Source [4 5] -> Target [9 5]  (cost: 25.0)

The linear_sum_assignment function uses the Hungarian algorithm to solve this in O(n³) time. The result is a permutation — a one-to-one mapping from source to target — that minimizes total squared distance. Notice each source sends all its mass to exactly one target. For this equal-weight case, that's fine. But what if the weights are unequal?

The Kantorovich Relaxation — Allowing Mass to Split

It took 161 years for someone to fix Monge's limitation. In 1942, Soviet mathematician Leonid Kantorovich (who would later win the Nobel Prize in Economics) proposed a beautiful relaxation: instead of requiring a one-to-one transport map, allow a transport plan.

A transport plan is a matrix γ where γij represents the amount of mass shipped from source point i to target point j. The constraints are simple: the row sums must equal the source weights (all mass leaves each source) and the column sums must equal the target weights (each target receives the right amount). The set of all valid transport plans is called the transport polytope Π(μ, ν).

Kantorovich's insight was that this is a linear program:

minγ ∑i,j c(xi, yj) · γij    subject to γ ≥ 0, γ1 = a, γᵀ1 = b

Unlike Monge's formulation, this always has a solution. And it gives us the Wasserstein distance: Wp(μ, ν) = (minγ ∑i,j cij^p γij)^(1/p). This is a proper metric — it satisfies the triangle inequality, it's symmetric, and it equals zero only when the distributions are identical.

import numpy as np
from scipy.optimize import linprog

def kantorovich_transport(source_pts, target_pts, a, b):
    """Solve the Kantorovich optimal transport problem via LP.
    a = source weights (sum to 1), b = target weights (sum to 1)."""
    n, m = len(source_pts), len(target_pts)

    # Cost matrix (squared Euclidean distances)
    C = np.array([[np.sum((s - t) ** 2) for t in target_pts]
                   for s in source_pts])

    # Flatten cost for LP: min c^T x where x = gamma.ravel()
    c_vec = C.ravel()

    # Equality constraints: row sums = a, col sums = b
    # Row sums: for each source i, sum_j gamma_ij = a_i
    A_row = np.zeros((n, n * m))
    for i in range(n):
        A_row[i, i * m:(i + 1) * m] = 1.0

    # Col sums: for each target j, sum_i gamma_ij = b_j
    A_col = np.zeros((m, n * m))
    for j in range(m):
        for i in range(n):
            A_col[j, i * m + j] = 1.0

    A_eq = np.vstack([A_row, A_col])
    b_eq = np.concatenate([a, b])

    result = linprog(c_vec, A_eq=A_eq, b_eq=b_eq,
                     bounds=[(0, None)] * (n * m), method='highs')

    gamma = result.x.reshape(n, m)
    return gamma, result.fun

# Unequal masses: 3 sources, 4 targets (mass must split!)
source = np.array([[0, 0], [1, 0], [0.5, 1]])
target = np.array([[2, 0], [3, 0], [2, 1], [3, 1]])
a = np.array([0.5, 0.3, 0.2])        # source weights
b = np.array([0.2, 0.3, 0.2, 0.3])   # target weights

gamma, cost = kantorovich_transport(source, target, a, b)
print(f"Optimal transport cost: {cost:.4f}")
print(f"\nCoupling matrix gamma (rows=source, cols=target):")
print(np.round(gamma, 3))
# Non-zero off-diagonal entries show mass splitting

The coupling matrix γ is the heart of the Kantorovich solution. For the Monge problem (equal weights, same number of points), γ is a permutation matrix — exactly one non-zero entry per row and column. But for unequal weights, γ can have multiple non-zero entries per row, meaning one source splits its mass across several targets. This flexibility is what makes Kantorovich's formulation strictly more powerful than Monge's.
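The marginal constraints are easy to check numerically. Here is a quick sanity check on a small hand-built plan (illustrative numbers, not tied to the example above):

```python
import numpy as np

# A hand-built 2x3 transport plan (illustrative values)
gamma = np.array([[0.2, 0.1, 0.1],
                  [0.0, 0.3, 0.3]])
a = np.array([0.4, 0.6])        # source weights
b = np.array([0.2, 0.4, 0.4])   # target weights

# Feasibility: all mass leaves each source, each target gets its share
print(np.allclose(gamma.sum(axis=1), a))  # True
print(np.allclose(gamma.sum(axis=0), b))  # True

# Source 0 splits its 0.4 mass across all three targets
print(np.count_nonzero(gamma[0]))         # 3
```

Any matrix passing both marginal checks is a valid member of the transport polytope Π(μ, ν).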

The Wasserstein Distance — Why It's Better Than KL

Now that we have the machinery to compute optimal transport, let's see why the resulting Wasserstein distance is such a game-changer for machine learning.

In one dimension, the Wasserstein-1 distance has an elegant closed form. For two distributions with CDFs F and G:

W1(F, G) = ∫ |F(x) - G(x)| dx

That's it — just the area between the two CDFs. No optimization needed. For discrete samples, you sort both sets of points, pair them up in order, and sum the absolute differences. This simplicity makes 1D Wasserstein extremely fast to compute and deeply intuitive: it's literally the amount of "earth" you need to move to reshape one CDF into the other.
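For samples of equal size, the sort-and-pair recipe is essentially a one-liner. A quick sketch, cross-checked against scipy.stats.wasserstein_distance:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
p = rng.normal(0, 1, 1000)
q = rng.normal(2, 1, 1000)

# Equal-size samples: sort both, pair in order, average the gaps
w1_sorted = np.mean(np.abs(np.sort(p) - np.sort(q)))

# Agrees with scipy's CDF-based implementation
print(np.isclose(w1_sorted, wasserstein_distance(p, q)))  # True
```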

The real power of Wasserstein shows up when you compare it to KL divergence. Consider two unit Gaussians that start overlapping and gradually drift apart:

import numpy as np

def wasserstein_1d(samples_p, samples_q):
    """W1 distance between 1D samples via CDF area."""
    all_pts = np.sort(np.unique(np.concatenate([samples_p, samples_q])))
    # Empirical CDFs at each point
    cdf_p = np.searchsorted(np.sort(samples_p), all_pts, side='right') / len(samples_p)
    cdf_q = np.searchsorted(np.sort(samples_q), all_pts, side='right') / len(samples_q)
    # Empirical CDFs are step functions: integrate |F - G| interval by interval
    diffs = np.abs(cdf_p - cdf_q)
    return np.sum(diffs[:-1] * np.diff(all_pts))

def kl_divergence(p_hist, q_hist):
    """KL(P || Q) from histograms. Returns inf if Q=0 where P>0."""
    mask = p_hist > 0
    if np.any((q_hist[mask] == 0)):
        return float('inf')
    return np.sum(p_hist[mask] * np.log(p_hist[mask] / q_hist[mask]))

def js_divergence(p_hist, q_hist):
    """Jensen-Shannon divergence from histograms."""
    m = 0.5 * (p_hist + q_hist)
    return 0.5 * kl_divergence(p_hist, m) + 0.5 * kl_divergence(q_hist, m)

# Compare as two Gaussians drift apart
np.random.seed(7)
n = 5000
bins = np.linspace(-8, 16, 200)

print(f"{'Shift':>6}  {'W1':>8}  {'KL':>10}  {'JS':>8}")
print("-" * 38)
for shift in [0.0, 0.5, 1.0, 2.0, 4.0, 8.0]:
    p = np.random.normal(0, 1, n)
    q = np.random.normal(shift, 1, n)
    w1 = wasserstein_1d(p, q)
    p_hist, _ = np.histogram(p, bins=bins, density=True)
    q_hist, _ = np.histogram(q, bins=bins, density=True)
    p_hist = p_hist / p_hist.sum() + 1e-10
    q_hist = q_hist / q_hist.sum() + 1e-10
    kl = kl_divergence(p_hist, q_hist)
    js = js_divergence(p_hist, q_hist)
    print(f"{shift:6.1f}  {w1:8.3f}  {kl:10.3f}  {js:8.4f}")

# Output:
# Shift       W1          KL        JS
# --------------------------------------
#    0.0     0.027       0.022    0.0024
#    0.5     0.509       0.150    0.0349
#    1.0     0.988       0.614    0.1119
#    2.0     1.983       2.758    0.3358
#    4.0     3.985      13.648    0.6332
#    8.0     7.995      19.498    0.6930

Look at the pattern: W1 grows linearly with the shift — exactly what you'd want. Move the distributions twice as far apart, and the distance doubles. KL divergence grows quadratically and would explode to infinity if the supports didn't overlap at all. JS divergence is bounded (by ln 2 ≈ 0.693), which sounds nice until you realize that means it saturates — once distributions are far enough apart, JS can't tell the difference between "a little separated" and "vastly separated." That's terrible for optimization because it means the gradients vanish.

This is exactly the insight behind Wasserstein GANs. The Kantorovich-Rubinstein duality tells us that W1(μ, ν) = sup‖f‖L≤1 Eμ[f(X)] - Eν[f(X)], where the supremum is over all 1-Lipschitz functions f. In practice, you train a neural network critic with a gradient penalty to approximate this 1-Lipschitz function, giving you a smooth loss signal that guides the generator at every training step.
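You can see the dual form at work without any neural network. For two unit Gaussians a shift s apart, W1 = s, and f(x) = x is a 1-Lipschitz function that attains the supremum; a minimal Monte Carlo sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
shift = 3.0
p = rng.normal(0, 1, 200_000)       # mu
q = rng.normal(shift, 1, 200_000)   # nu

# f(x) = x is 1-Lipschitz; its expectation gap attains the supremum here,
# so the dual estimate recovers W1 = shift (up to sampling noise)
dual_estimate = q.mean() - p.mean()
print(abs(dual_estimate - shift) < 0.05)  # True
```

Any other 1-Lipschitz f would give a smaller (or equal) gap; the critic in a WGAN is searching this function space for the maximizer.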

Sinkhorn's Algorithm — Fast Optimal Transport via Entropy

The Kantorovich LP gives us exact optimal transport, but it's slow. Interior-point methods run in O(n³ log n) time, which is fine for 10 points but impractical for the thousands or millions of samples common in machine learning. We need something faster.

In 2013, Marco Cuturi published a paper that changed the field: add entropic regularization to the transport problem. Instead of finding the plan that minimizes cost alone, find the plan that minimizes cost plus a penalty for low-entropy plans:

minγ∈Π ∑i,j Cij γij - ε H(γ)    where H(γ) = -∑i,j γij log γij

The entropy term H(γ) is maximized when the transport plan spreads mass evenly — so the regularization pushes the solution toward "smoother" plans. The parameter ε controls the tradeoff: small ε gives a plan close to the exact LP solution (sparse, concentrated), large ε gives a diffuse plan that spreads mass everywhere.

The magic is in how this changes the solution. The optimal regularized plan has the form γ* = diag(u) · K · diag(v), where Kij = exp(-Cij/ε) is a Gibbs kernel. Finding u and v is simple: alternately normalize the rows and columns of K. That's it — that's Sinkhorn's algorithm.

import numpy as np

def sinkhorn(cost_matrix, a, b, epsilon=0.1, max_iter=100, tol=1e-8):
    """Sinkhorn's algorithm for entropy-regularized optimal transport.
    cost_matrix: (n, m) pairwise costs
    a: (n,) source weights   b: (m,) target weights
    epsilon: regularization strength (smaller = closer to exact OT)
    Returns: transport plan gamma, list of iteration costs."""
    n, m = cost_matrix.shape
    K = np.exp(-cost_matrix / epsilon)   # Gibbs kernel

    u = np.ones(n)           # row scaling factors
    v = np.ones(m)           # column scaling factors
    costs = []

    for iteration in range(max_iter):
        u_prev = u.copy()
        u = a / (K @ v)          # row normalization
        v = b / (K.T @ u)        # column normalization

        # Transport plan and cost for monitoring
        gamma = np.diag(u) @ K @ np.diag(v)
        transport_cost = np.sum(gamma * cost_matrix)
        costs.append(transport_cost)

        # Check convergence
        if np.max(np.abs(u - u_prev)) < tol:
            break

    return gamma, costs

# Example: 5 source points, 5 target points
np.random.seed(42)
source = np.random.rand(5, 2)
target = np.random.rand(5, 2) + np.array([0.5, 0])
C = np.array([[np.sum((s - t) ** 2) for t in target] for s in source])
a = np.ones(5) / 5
b = np.ones(5) / 5

# Compare sharp vs diffuse plans
for eps in [0.01, 0.1, 1.0]:
    gamma, costs = sinkhorn(C, a, b, epsilon=eps)
    max_g = gamma.max()
    print(f"epsilon={eps:.2f}: cost={costs[-1]:.3f}, "
          f"converged in {len(costs)} iters")
    print(f"  Plan sparsity: {np.sum(gamma < max_g * 0.01)}/{gamma.size} near-zero entries")

# Output:
# epsilon=0.01: cost=0.329, converged in 100 iters
#   Plan sparsity: 16/25 near-zero entries
# epsilon=0.10: cost=0.374, converged in 49 iters
#   Plan sparsity: 4/25 near-zero entries
# epsilon=1.00: cost=0.529, converged in 6 iters
#   Plan sparsity: 0/25 near-zero entries

Each Sinkhorn iteration is just a matrix-vector multiply — O(n²) per iteration, and it typically converges in 20–100 iterations. Compare that to O(n³ log n) for the exact LP. Even better, matrix-vector multiplies are trivially parallelizable on GPUs, making Sinkhorn the go-to algorithm for large-scale optimal transport.

The ε parameter creates a beautiful tradeoff visible in the output above. At ε = 0.01, the plan is nearly exact (sparse, 16 of 25 entries are negligible — close to a permutation matrix). At ε = 1.0, the plan is fully diffuse (every source sends mass to every target), and the transport cost is slightly higher. The sweet spot is problem-dependent, but ε ≈ 0.1 is a common starting point.

One subtlety: the raw Sinkhorn cost is biased — it doesn't equal zero when comparing a distribution to itself. The fix is the Sinkhorn divergence: Sε(μ, ν) = OTε(μ, ν) - ½ OTε(μ, μ) - ½ OTε(ν, ν). This corrected version satisfies S(μ, μ) = 0 and interpolates between Wasserstein distance (ε → 0) and maximum mean discrepancy (ε → ∞).
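The debiasing is easy to verify numerically. Below is a self-contained sketch, using the primal transport cost ∑ γij Cij for each OTε term (one common convention; implementations differ on whether the entropy term is included):

```python
import numpy as np

def ot_eps(x, y, a, b, eps=0.1, iters=200):
    """Primal transport cost of the entropy-regularized plan (minimal sketch)."""
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    K = np.exp(-C / eps)                       # Gibbs kernel
    u, v = np.ones(len(a)), np.ones(len(b))
    for _ in range(iters):                     # Sinkhorn iterations
        u = a / (K @ v)
        v = b / (K.T @ u)
    gamma = u[:, None] * K * v[None, :]
    return np.sum(gamma * C)

def sinkhorn_divergence(x, y, eps=0.1):
    a = np.ones(len(x)) / len(x)
    b = np.ones(len(y)) / len(y)
    return (ot_eps(x, y, a, b, eps)
            - 0.5 * ot_eps(x, x, a, a, eps)
            - 0.5 * ot_eps(y, y, b, b, eps))

rng = np.random.default_rng(0)
x = rng.normal(0, 1, (30, 2))
y = rng.normal(2, 1, (30, 2))

# Debiased: comparing a cloud to itself gives (numerically) zero
print(abs(sinkhorn_divergence(x, x.copy())) < 1e-12)  # True
# Distinct clouds give a positive value
print(sinkhorn_divergence(x, y) > 0)                  # True
```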

Displacement Interpolation — Morphing Between Distributions

Optimal transport doesn't just tell you the cost of transforming one distribution into another — it tells you how to do the transformation, and this unlocks something beautiful: smooth interpolation between distributions.

Given source distribution μ and target ν with optimal transport map T, the McCann interpolation defines a path through distribution space:

μt = ((1-t) · id + t · T)#μ    for t ∈ [0, 1]

At t=0 we have the source, at t=1 we have the target, and in between, each point travels along a straight line from its source position to its assigned target position. This is a geodesic in Wasserstein space — the shortest path between two distributions.
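For discrete equal-weight clouds, the interpolation is easy to sketch: solve the assignment problem, then slide each point along the segment to its match (toy point clouds, purely illustrative):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
source = rng.normal([0, 0], 0.3, (50, 2))   # cloud near the origin
target = rng.normal([4, 2], 0.3, (50, 2))   # cloud near (4, 2)

# Equal weights: the optimal plan is an assignment (a Monge map)
C = ((source[:, None, :] - target[None, :, :]) ** 2).sum(-1)
row, col = linear_sum_assignment(C)

# McCann interpolation: each point travels straight to its match
def interpolate(t):
    return (1 - t) * source[row] + t * target[col]

for t in [0.0, 0.5, 1.0]:
    print(f"t={t}: cloud center = {interpolate(t).mean(axis=0)}")
```

At t = 0.5 the cloud sits halfway between the two centers, near (2, 1) — and unlike a mixture 0.5μ + 0.5ν, which would show two separate clouds, the interpolated distribution is a single cloud in transit.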

A natural extension is the Wasserstein barycenter: given multiple distributions, find the one that minimizes the sum of Wasserstein distances to all of them. Think of it as the "average" distribution, but one that respects geometry. Iterative Bregman projections — alternating Sinkhorn computations for each target distribution — provide an efficient algorithm.

import numpy as np

def wasserstein_barycenter(histograms, support, weights=None,
                           epsilon=0.05, max_iter=50):
    """Compute Wasserstein barycenter of 1D histograms.
    Uses iterative Bregman projections (Cuturi & Doucet 2014).
    histograms: list of K histograms (each sums to 1)
    support: shared bin centers
    weights: barycentric weights (uniform if None)"""
    K_dists = len(histograms)
    n_bins = len(support)
    if weights is None:
        weights = np.ones(K_dists) / K_dists

    # Cost matrix (normalized to prevent kernel underflow)
    C = (support[:, None] - support[None, :]) ** 2
    C /= C.max()
    K = np.exp(-C / epsilon)  # Gibbs kernel

    # Initialize barycenter as uniform
    bary = np.ones(n_bins) / n_bins

    for iteration in range(max_iter):
        log_bary = np.zeros(n_bins)
        for k in range(K_dists):
            # Sinkhorn scaling for transport from bary to histogram k
            u = np.ones(n_bins)
            for _ in range(100):
                v = histograms[k] / (K.T @ u)
                u = bary / (K @ v)
            # K @ v transports histogram k into barycenter's frame
            log_bary += weights[k] * np.log(np.maximum(K @ v, 1e-16))
        bary = np.exp(log_bary)
        bary /= bary.sum()

    return bary

# Three 1D distributions: left peak, center peak, right peak
bins = np.linspace(0, 10, 51)
h1 = np.exp(-0.5 * ((bins - 2) / 0.8) ** 2)
h2 = np.exp(-0.5 * ((bins - 5) / 1.0) ** 2)
h3 = np.exp(-0.5 * ((bins - 8) / 0.6) ** 2)
for h in [h1, h2, h3]:
    h /= h.sum()

bary = wasserstein_barycenter([h1, h2, h3], bins, epsilon=0.05)
peak_pos = bins[np.argmax(bary)]
print(f"Barycenter peak at x={peak_pos:.1f} (mean of 2, 5, 8 = 5.0)")
# Output: Barycenter peak at x=5.0 (mean of 2, 5, 8 = 5.0)

The Wasserstein barycenter sits at x=5.0 — the geometric center of the three input peaks. But unlike a simple pointwise average (which would create a blurry trimodal distribution), the barycenter is unimodal with a shape that reflects the shapes of the inputs. This "shape-preserving averaging" is uniquely valuable in applications like image interpolation, texture synthesis, and shape morphing.

The connection to flow matching is deep: the forward process in flow matching maps noise to data along conditional paths, and the optimal conditional paths are exactly McCann interpolations. Rectified flows (Liu et al. 2023) iteratively straighten curved flow trajectories, converging to the OT map after enough "reflow" iterations. Optimal transport provides the mathematical foundation for why straight flows work so well.

Optimal Transport in Modern ML

Optimal transport has become a Swiss Army knife in modern machine learning. Here are the key applications, starting with the one that brought OT into the ML mainstream.

Wasserstein GANs

Arjovsky et al. (2017) replaced the Jensen-Shannon divergence in the GAN objective with the Wasserstein-1 distance, using the Kantorovich-Rubinstein dual form. The critic (discriminator) approximates a 1-Lipschitz function; the original paper enforced the Lipschitz constraint by weight clipping, and the now-standard gradient penalty came with WGAN-GP (Gulrajani et al. 2017). Our GANs post covers the implementation in detail — here's the core loss computation:

import numpy as np

def wgan_critic_loss(critic_fn, real_data, fake_data, lambda_gp=10.0):
    """WGAN-GP critic loss with gradient penalty.
    critic_fn(x): maps data points to scalar scores.
    Critic wants: high scores for real, low for fake."""
    # Wasserstein estimate: E[critic(real)] - E[critic(fake)]
    real_scores = np.array([critic_fn(x) for x in real_data])
    fake_scores = np.array([critic_fn(x) for x in fake_data])
    wasserstein_est = real_scores.mean() - fake_scores.mean()

    # Gradient penalty: enforce ||grad(critic)|| ≈ 1
    # Interpolate between real and fake samples
    batch_size = min(len(real_data), len(fake_data))
    alpha = np.random.rand(batch_size, 1)
    interpolated = alpha * real_data[:batch_size] + \
                   (1 - alpha) * fake_data[:batch_size]

    # Approximate gradient via finite differences
    eps = 1e-4
    grad_norms = []
    for x in interpolated:
        grads = []
        for d in range(x.shape[0]):
            x_plus = x.copy(); x_plus[d] += eps
            x_minus = x.copy(); x_minus[d] -= eps
            grads.append((critic_fn(x_plus) - critic_fn(x_minus)) / (2 * eps))
        grad_norms.append(np.sqrt(sum(g ** 2 for g in grads)))
    grad_penalty = np.mean([(gn - 1) ** 2 for gn in grad_norms])

    # Critic loss: minimize -wasserstein + penalty
    loss = -wasserstein_est + lambda_gp * grad_penalty
    return loss, wasserstein_est, grad_penalty

# Demo with a simple linear critic
def simple_critic(x):
    return 0.8 * x[0] + 0.6 * x[1]

np.random.seed(42)
real = np.random.normal([3, 3], 0.5, (32, 2))
fake = np.random.normal([0, 0], 0.5, (32, 2))

loss, w_est, gp = wgan_critic_loss(simple_critic, real, fake)
print(f"Wasserstein estimate: {w_est:.3f}")
print(f"Gradient penalty:     {gp:.3f}")
print(f"Total critic loss:    {loss:.3f}")
# Wasserstein estimate: 4.030
# Gradient penalty:     0.000
# Total critic loss:    -4.030

Flow Matching with OT Paths

Lipman et al. (2023) introduced conditional flow matching with optimal transport probability paths, which are straighter than diffusion paths. Mini-batch OT-CFM (Tong et al. 2023) takes this further: you solve a small OT problem within each training batch to pair noise samples with data samples, then train the flow to follow these OT-optimal paths. Our flow matching post covers this in Section 5 — optimal transport provides the mathematical reason why OT paths need fewer integration steps at inference time.
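The batch-level pairing step can be sketched in a few lines (toy stand-ins for noise and data; the flow network itself is omitted):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
batch = 64
x0 = rng.normal(0, 1, (batch, 2))          # noise samples
x1 = rng.normal([4, 0], 0.5, (batch, 2))   # toy "data" samples

# Mini-batch OT: pair noise with data by solving a small assignment problem
C = ((x0[:, None] - x1[None, :]) ** 2).sum(-1)
row, col = linear_sum_assignment(C)
x0_ot, x1_ot = x0[row], x1[col]

# OT pairing never costs more than an arbitrary (index-order) pairing
print(C[row, col].sum() <= C.diagonal().sum())  # True

# Training target: straight-line path and its constant velocity
t = rng.uniform(0, 1, (batch, 1))
x_t = (1 - t) * x0_ot + t * x1_ot
v_target = x1_ot - x0_ot   # the flow network would regress onto this
```

Because the OT pairing removes crossing trajectories within the batch, the learned velocity field is closer to a straight flow from the start.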

Domain Adaptation

Courty et al. (2017) used optimal transport to align source and target domain distributions for transfer learning. The idea is elegant: compute the transport plan between source and target features, then "transport" source labels along the plan to label the target domain. This OT-based domain adaptation naturally respects the geometry of the feature space.
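A minimal sketch of the label-transport idea on made-up toy clusters (the sinkhorn_plan helper is a bare-bones re-implementation, not Courty et al.'s actual algorithm):

```python
import numpy as np

def sinkhorn_plan(C, a, b, eps=0.5, iters=200):
    """Entropy-regularized transport plan (minimal sketch)."""
    K = np.exp(-C / eps)
    u, v = np.ones(len(a)), np.ones(len(b))
    for _ in range(iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
# Toy domains: two labeled source clusters; target = source shifted plus noise
src = np.vstack([rng.normal(0, 0.3, (20, 2)),
                 rng.normal([3, 0], 0.3, (20, 2))])
y_src = np.array([0] * 20 + [1] * 20)
tgt = src + np.array([1.0, 2.0]) + rng.normal(0, 0.1, src.shape)

C = ((src[:, None] - tgt[None, :]) ** 2).sum(-1)
w = np.ones(len(src)) / len(src)
gamma = sinkhorn_plan(C, w, w)

# Transport labels along the plan: each target takes the label
# that sends it the most mass
onehot = np.eye(2)[y_src]                # (n_src, n_classes)
y_tgt = (gamma.T @ onehot).argmax(axis=1)
print((y_tgt == y_src).mean())           # 1.0: all labels recovered
```

Because the domain shift here is a rigid translation, the transport plan concentrates on matching clusters and every target point inherits the correct label.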

Computational Biology

Single-cell RNA sequencing generates snapshots of cell populations at different time points. Optimal transport reconstructs the most likely trajectories cells followed between snapshots (Schiebinger et al. 2019). The Waddington-OT framework treats each time-point as a distribution over gene-expression space and computes the optimal transport plan to infer cell fate decisions.

Try It: Transport Plan Visualizer

Watch how entropic regularization changes the transport plan. Small ε gives a sparse plan (each source sends mass to one target). Large ε spreads mass everywhere.


Try It: Wasserstein vs KL Explorer

Drag the histogram bars to change the distributions. Watch how W1, KL, and JS divergence respond — especially when distributions become disjoint (KL → ∞ but W1 stays finite).


Conclusion

Optimal transport gives us a principled way to compare, transform, and interpolate between probability distributions by finding the minimum-cost plan to move mass from one to the other. Starting from Monge's assignment problem (1781) and Kantorovich's LP relaxation (1942), we built up to Sinkhorn's entropy-regularized algorithm (2013) that made OT practical for machine learning.

The key ideas to remember: Monge asks for a one-to-one transport map; Kantorovich relaxes it to a plan that can split mass, turning the problem into a linear program; the resulting Wasserstein distance respects geometry and stays finite where KL explodes; and Sinkhorn's entropic regularization reduces each iteration to matrix-vector products that scale to ML-sized problems.

From Wasserstein GANs to flow matching to computational biology, optimal transport has become one of the most versatile mathematical frameworks in modern ML. The core algorithm is remarkably simple — Sinkhorn is just alternating row and column normalization of a matrix — but the theoretical depth runs all the way to Riemannian geometry and measure theory. That combination of practical simplicity and mathematical beauty is what makes optimal transport such a satisfying topic to learn from scratch.