
Differentiable Programming from Scratch: Making Everything Gradient-Friendly

1. What Is Differentiable Programming?

You already know backpropagation. You've traced gradients through a neural network, watched a loss curve descend, and marveled at how the chain rule makes it all work. But here's the question that opens up a whole new world: what if you could differentiate through any program, not just a neural network?

Think about it. A physics simulator is just a function: give it initial conditions, it returns a trajectory. A sorting algorithm is a function: give it scores, it returns a permutation. A 3D renderer is a function: give it a scene description, it returns an image. If you could compute gradients through these programs, you could optimize anything — find the launch angle that hits a target, learn a ranking that maximizes clicks, or recover a 3D scene from 2D photos.

That's differentiable programming: the paradigm where you write arbitrary computational programs and then optimize them using gradients. Yann LeCun helped popularize the term in 2018, half-jokingly calling it little more than a rebranding of modern deep learning. The more useful way to see it is as the natural generalization, where neural networks are just one special case of differentiable computation.

Our backpropagation post covered how reverse-mode automatic differentiation works for neural networks. This post goes further. We'll build two complete autodiff engines (forward-mode and reverse-mode), learn to differentiate through operations that seem non-differentiable (sorting, argmax, hard thresholds), differentiate through the solution of an equation without unrolling the solver, push gradients through a physics simulator, and survey how modern frameworks like JAX make all of this composable.

If you can write it as code, you can differentiate through it. And if you can differentiate through it, you can optimize it.

2. Forward-Mode AD — Dual Numbers

Here's an elegant mathematical trick. Extend the real numbers by inventing a new element ε with one special property: ε² = 0, but ε ≠ 0. A "dual number" has the form a + bε, where a is the real part and b is the derivative part.

The arithmetic rules follow directly from ε² = 0:

(a + bε) + (c + dε) = (a+c) + (b+d)ε
(a + bε) × (c + dε) = ac + (ad + bc)ε

The multiplication rule is the key insight. The bdε² term vanishes because ε² = 0, leaving exactly the product rule from calculus: d(ac)/dx = ad + bc when b = da/dx and d = dc/dx.

Now watch what happens when you evaluate any function at a + 1·ε. By Taylor expansion:

f(a + ε) = f(a) + f'(a)ε + f''(a)ε²/2 + … = f(a) + f'(a)ε

All the ε² and higher terms vanish. The derivative rides along in the ε component for free. You evaluate the function once, and you get both the value and its exact derivative — no finite differences, no symbolic manipulation, no computational graph.

The chain rule emerges automatically: f(g(a + ε)) = f(g(a) + g'(a)ε) = f(g(a)) + f'(g(a))·g'(a)ε. Composition just works.

Let's implement this. Our Dual class overloads Python's arithmetic operators so dual numbers flow through any computation transparently:

import math

class Dual:
    """A dual number a + b*epsilon for forward-mode autodiff."""
    def __init__(self, val, deriv=0.0):
        self.val = val      # function value
        self.deriv = deriv   # derivative value

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + o.val, self.deriv + o.deriv)

    def __radd__(self, other):
        return Dual(other).__add__(self)

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.val * o.deriv + self.deriv * o.val)

    def __rmul__(self, other):
        return Dual(other).__mul__(self)

    def __pow__(self, n):
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.deriv)

    def __neg__(self):
        return Dual(-self.val, -self.deriv)

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return Dual(other) + (-self)

def sin(x):
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), math.cos(x.val) * x.deriv)
    return math.sin(x)

def exp(x):
    if isinstance(x, Dual):
        e = math.exp(x.val)
        return Dual(e, e * x.deriv)
    return math.exp(x)

# Evaluate f(x) = x**3 + sin(x) and its derivative at x = 2.0
x = Dual(2.0, 1.0)  # seed derivative = 1 means df/dx
result = x ** 3 + sin(x)
print(f"f(2.0)  = {result.val:.6f}")   # 8 + sin(2) = 8.909297
print(f"f'(2.0) = {result.deriv:.6f}") # 3*4 + cos(2) = 11.583853

That's it: exact derivatives of arbitrary expressions in a single forward pass. Forward-mode computes one directional derivative per pass, building the Jacobian one column at a time. This is efficient when you have few inputs and many outputs: the full Jacobian of a function from ℝ³ to ℝ¹⁰⁰⁰ takes just three passes.
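To make "one column per pass" concrete, here is a minimal sketch that seeds each input in turn and reads off a full Jacobian. It re-declares a trimmed Dual so it runs on its own; the example map f and the helpers dsin and jacobian_forward are illustrative names, not part of any library.

```python
import math

class Dual:
    """Trimmed dual number val + deriv * eps (eps**2 == 0)."""
    def __init__(self, val, deriv=0.0):
        self.val, self.deriv = val, deriv

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + o.val, self.deriv + o.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.val * o.deriv + self.deriv * o.val)
    __rmul__ = __mul__

def dsin(x):
    return Dual(math.sin(x.val), math.cos(x.val) * x.deriv)

def f(x0, x1):
    """Example map from R^2 to R^3 (an arbitrary choice for illustration)."""
    return [x0 * x1, dsin(x0), x0 * x0 + 3 * x1]

def jacobian_forward(func, point):
    """Build the Jacobian one column per forward pass."""
    columns = []
    for j in range(len(point)):
        # Tangent = e_j: seed input j with derivative 1, every other input with 0
        args = [Dual(v, 1.0 if i == j else 0.0) for i, v in enumerate(point)]
        columns.append([out.deriv for out in func(*args)])
    # columns[j][i] = d f_i / d x_j; transpose so that J[i][j] has that meaning
    return [list(row) for row in zip(*columns)]

J = jacobian_forward(f, [2.0, 3.0])   # two inputs, so exactly two passes
for row in J:
    print(row)
```

The number of passes tracks the number of inputs only; adding more outputs to f would not add a single pass.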

3. Reverse-Mode AD — A Tiny Autograd Engine

Forward-mode shines with few inputs. But in deep learning, we typically have millions of inputs (parameters) and one output (scalar loss). Computing the gradient of a scalar with respect to a million parameters would require a million forward-mode passes — one per parameter. That's where reverse-mode saves the day: a single backward pass gives you all the gradients at once.

The idea: during the forward pass, record every operation into a computational graph. Then traverse the graph backwards, accumulating gradients from the output to every input. Our backpropagation post showed this for neural network layers. Here we build a general-purpose engine that can differentiate any composition of supported operations — not just layers of matrix multiplies.

class Value:
    """A node in a computational graph for reverse-mode autodiff."""
    def __init__(self, data, children=(), op=''):
        self.data = float(data)
        self.grad = 0.0
        self._backward = lambda: None
        self._children = set(children)
        self._op = op

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def __pow__(self, n):
        out = Value(self.data ** n, (self,), f'**{n}')
        def _backward():
            self.grad += n * self.data ** (n - 1) * out.grad
        out._backward = _backward
        return out

    def __neg__(self):
        return self * -1

    def __radd__(self, other):
        return self + other

    def __rmul__(self, other):
        return self * other

    def __sub__(self, other):
        return self + (-other)

    def sin(self):
        import math
        out = Value(math.sin(self.data), (self,), 'sin')
        def _backward():
            self.grad += math.cos(self.data) * out.grad
        out._backward = _backward
        return out

    def exp(self):
        import math
        e = math.exp(self.data)
        out = Value(e, (self,), 'exp')
        def _backward():
            self.grad += e * out.grad
        out._backward = _backward
        return out

    def backward(self):
        order = []
        visited = set()
        def topo(v):
            if v not in visited:
                visited.add(v)
                for c in v._children:
                    topo(c)
                order.append(v)
        topo(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

# Differentiate a non-neural-network function:
# f(x, y) = sin(x * y) + exp(x^2)
x = Value(1.0)
y = Value(2.0)
f = (x * y).sin() + (x ** 2).exp()

f.backward()
print(f"f(1,2)   = {f.data:.6f}")   # sin(2) + exp(1) = 3.627579
print(f"df/dx    = {x.grad:.6f}")   # y*cos(xy) + 2x*exp(x^2) = 4.604270
print(f"df/dy    = {y.grad:.6f}")   # x*cos(xy) = cos(2) = -0.416147

The engine builds a directed acyclic graph during the forward pass. Each node stores a _backward closure that implements the local chain rule for its operation. Calling backward() performs a topological sort and then visits nodes in reverse order, accumulating gradients via the += pattern (this handles fan-out correctly — when a value is used multiple times, its gradient is the sum of all contributions).
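The fan-out rule is easiest to see on a function that reuses its input. This sketch trims the engine down to + and * (it mirrors the Value class above, cut down so the snippet runs standalone). Since x appears three times in f(x) = x·x + x, its gradient only comes out right because each closure adds to x.grad rather than overwriting it; with plain assignment, the last contribution to run would silently win.

```python
class Value:
    """Stripped-down reverse-mode node (only + and *), enough to show fan-out."""
    def __init__(self, data, children=()):
        self.data = float(data)
        self.grad = 0.0
        self._backward = lambda: None
        self._children = children

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def _backward():
            # For x * x, self and other are the same node: both lines add to it
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        order, visited = [], set()
        def topo(v):
            if v not in visited:
                visited.add(v)
                for c in v._children:
                    topo(c)
                order.append(v)
        topo(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

x = Value(3.0)
f = x * x + x          # f(x) = x^2 + x, so df/dx = 2x + 1 = 7 at x = 3
f.backward()
print(x.grad)          # 7.0
```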

JVP vs VJP: When to Use Which

Forward-mode computes a Jacobian-vector product (JVP): given a tangent vector v, it computes J · v in one pass, building the Jacobian one column at a time. Reverse-mode computes a vector-Jacobian product (VJP): given a cotangent vector v, it computes vᵀ · J in one pass, building the Jacobian one row at a time.

| Mode | Computes | Cost for the full J | Best When |
|---|---|---|---|
| Forward (JVP) | J · v | O(n_inputs) passes | Few inputs, many outputs |
| Reverse (VJP) | vᵀ · J | O(n_outputs) passes | Many inputs, few outputs |

For deep learning (millions of parameters, scalar loss), reverse-mode wins overwhelmingly — one backward pass gives you everything. For control problems or sensitivity analysis (few parameters, vector-valued output), forward-mode is more efficient.
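A small NumPy sketch makes the row/column picture concrete. The map f(x) = tanh(Ax) and the matrix A are arbitrary choices for illustration, and jvp and vjp are hand-written for this one function rather than produced by an autodiff engine. Recovering the full 2×3 Jacobian takes three JVPs (one per input) or two VJPs (one per output), and both routes agree.

```python
import numpy as np

# Example map f: R^3 -> R^2, f(x) = tanh(A @ x). Its Jacobian is
# diag(1 - tanh(A x)^2) @ A, but jvp/vjp below never form it explicitly.
A = np.array([[1.0, -2.0, 0.5],
              [0.3, 0.8, -1.0]])
x = np.array([0.2, -0.1, 0.4])

def jvp(x, v):
    """Forward mode: push a tangent v (shape (3,)) through f, returning J @ v."""
    s = 1.0 - np.tanh(A @ x) ** 2      # derivative of tanh at A x
    return s * (A @ v)

def vjp(x, u):
    """Reverse mode: pull a cotangent u (shape (2,)) back through f, returning u^T J."""
    s = 1.0 - np.tanh(A @ x) ** 2
    return A.T @ (s * u)

n_out, n_in = A.shape
# Forward mode: one pass per input recovers one column each ...
J_cols = np.stack([jvp(x, e) for e in np.eye(n_in)], axis=1)
# ... reverse mode: one pass per output recovers one row each.
J_rows = np.stack([vjp(x, e) for e in np.eye(n_out)], axis=0)
print(np.allclose(J_cols, J_rows))               # True

# The two products are adjoint: u . (J v) == (u^T J) . v for any u, v
rng = np.random.default_rng(0)
u, v = rng.normal(size=n_out), rng.normal(size=n_in)
print(np.isclose(u @ jvp(x, v), vjp(x, u) @ v))  # True
```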

4. Making the Non-Differentiable Differentiable

So far, everything we've differentiated has been smooth — polynomials, sines, exponentials. But many useful operations are decidedly not smooth. The argmax function returns an integer index. Sorting produces a permutation. A hard threshold outputs 0 or 1 with no in-between. The gradient of these operations is zero almost everywhere, which means gradient-based optimization gets stuck.

The breakthrough insight: you don't need the true gradient — you just need a useful gradient signal that points optimization in the right direction. Three techniques make this work:

The Straight-Through Estimator (STE)

Popularized by Bengio et al. (2013), the STE is beautifully simple. During the forward pass, use the hard non-differentiable operation as-is. During the backward pass, pretend the operation was the identity function and pass the gradient straight through:

Forward: y = hard_threshold(x)
Backward: ∂L/∂x ≈ ∂L/∂y × 1

It's a lie, but a useful one. The gradient signal is biased, but it points in roughly the right direction. This technique powers binary neural networks, VQ-VAE's vector quantization, and discrete action selection in reinforcement learning.
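Here is a deliberately tiny illustration, assuming a single scalar x, the loss (y − 1)², and plain gradient descent (the starting point, learning rate, and step count are arbitrary). With the true derivative of the step function, x never moves; with the straight-through rule, a single update pushes it across the threshold.

```python
def hard_threshold(x):
    """Forward pass: a step function, 1 if x > 0 else 0."""
    return 1.0 if x > 0 else 0.0

def optimize(use_ste, x=-0.7, target=1.0, lr=0.5, steps=10):
    """Drive hard_threshold(x) toward target by gradient descent on x."""
    for _ in range(steps):
        y = hard_threshold(x)          # forward uses the hard op as-is
        dL_dy = 2.0 * (y - target)     # loss L = (y - target)^2
        if use_ste:
            dL_dx = dL_dy * 1.0        # STE: pretend dy/dx = 1
        else:
            dL_dx = dL_dy * 0.0        # true derivative of the step (for x != 0)
        x -= lr * dL_dx
    return x

x_true = optimize(use_ste=False)
x_ste = optimize(use_ste=True)
print(x_true, hard_threshold(x_true))   # stuck at the start, output still 0
print(x_ste, hard_threshold(x_ste))     # crossed zero, output is now 1
```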

Soft Relaxations with Temperature

A more principled approach: replace the hard operation with a smooth approximation controlled by a temperature parameter τ. The softmax function is the classic example:

soft_argmax(x, τ) = softmax(x / τ)

As τ → 0, the softmax sharpens toward a one-hot vector (hard argmax). As τ → ∞, it becomes uniform. At moderate τ, the gradients flow smoothly. You get a knob that trades off between faithfully approximating the discrete operation (low τ) and having well-behaved gradients (high τ).

import numpy as np

def hard_argmax(scores):
    """Non-differentiable: returns one-hot vector."""
    idx = np.argmax(scores)
    one_hot = np.zeros_like(scores)
    one_hot[idx] = 1.0
    return one_hot  # gradient is zero almost everywhere

def soft_argmax(scores, tau=1.0):
    """Differentiable relaxation with temperature."""
    shifted = scores - np.max(scores)  # numerical stability
    exps = np.exp(shifted / tau)
    return exps / np.sum(exps)  # smooth, gradient flows

scores = np.array([2.0, 5.0, 1.0, 3.0])

print("Hard argmax:", hard_argmax(scores))
# [0. 1. 0. 0.] -- no gradient information

print("Soft (tau=1.0):", np.round(soft_argmax(scores, 1.0), 3))
# [0.041 0.831 0.015 0.112] -- smooth, peaked at index 1

print("Soft (tau=0.1):", np.round(soft_argmax(scores, 0.1), 3))
# [0. 1. 0. 0.] -- nearly hard: still differentiable, but gradients almost vanish

print("Soft (tau=5.0):", np.round(soft_argmax(scores, 5.0), 3))
# [0.206 0.375 0.168 0.251] -- very smooth, gradients everywhere
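The temperature knob is visible directly in the gradient. The Jacobian of softmax(x/τ) is (diag(p) − ppᵀ)/τ, where p is the softmax output. The sketch below evaluates it on the same scores and checks it against central finite differences (the helper names are illustrative). Of the three temperatures, the gradient norm should be largest at τ = 1, nearly vanish at τ = 0.1, and shrink again at τ = 5 as the distribution flattens.

```python
import numpy as np

def soft_argmax(scores, tau=1.0):
    """Same relaxation as above: softmax(scores / tau)."""
    shifted = scores - np.max(scores)
    exps = np.exp(shifted / tau)
    return exps / np.sum(exps)

def soft_argmax_jacobian(scores, tau=1.0):
    """Analytic Jacobian: d p / d scores = (diag(p) - p p^T) / tau."""
    p = soft_argmax(scores, tau)
    return (np.diag(p) - np.outer(p, p)) / tau

def numeric_jacobian(fn, x, h=1e-6):
    """Central finite differences, one column per input coordinate."""
    cols = []
    for j in range(len(x)):
        step = np.zeros_like(x)
        step[j] = h
        cols.append((fn(x + step) - fn(x - step)) / (2 * h))
    return np.stack(cols, axis=1)

scores = np.array([2.0, 5.0, 1.0, 3.0])
norms, matches = {}, {}
for tau in [0.1, 1.0, 5.0]:
    J = soft_argmax_jacobian(scores, tau)
    J_fd = numeric_jacobian(lambda s: soft_argmax(s, tau), scores)
    norms[tau] = np.linalg.norm(J)               # Frobenius norm of the Jacobian
    matches[tau] = np.allclose(J, J_fd, atol=1e-6)
    print(f"tau={tau:<4} ||J||={norms[tau]:.2e}  matches finite differences: {matches[tau]}")
```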

The Gumbel-Softmax trick (Jang et al. 2017) extends this to differentiable sampling from categorical distributions by adding Gumbel noise before the softmax. NeuralSort (Grover et al. 2019) applies the same temperature-controlled relaxation to the sorting operator, producing differentiable permutation matrices. The pattern is always the same: replace the hard operation with a smooth parameterized family that converges to the hard operation as a temperature approaches zero.
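A minimal sketch of Gumbel-Softmax sampling, following the standard construction (Gumbel noise g = −log(−log u) with u uniform) rather than any reference implementation; the logits, τ = 0.5, seed, and sample count are arbitrary. The second half checks the Gumbel-max property the trick relies on: the argmax of the noisy logits is an exact sample from softmax(logits), and dividing by τ before the softmax never changes that argmax.

```python
import numpy as np

def gumbel_softmax_sample(logits, tau, rng):
    """One relaxed sample: softmax((logits + Gumbel noise) / tau)."""
    # Draw u in [1e-12, 1) so that both logarithms stay finite
    u = rng.uniform(low=1e-12, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))            # standard Gumbel(0, 1) noise
    z = (logits + gumbel) / tau
    z = z - np.max(z)                       # numerical stability
    e = np.exp(z)
    return e / np.sum(e)

logits = np.array([2.0, 5.0, 1.0, 3.0])
rng = np.random.default_rng(42)

sample = gumbel_softmax_sample(logits, tau=0.5, rng=rng)
print(np.round(sample, 3))   # a point on the probability simplex

# Empirical argmax frequencies should match softmax(logits)
n = 20000
counts = np.zeros_like(logits)
for _ in range(n):
    counts[np.argmax(gumbel_softmax_sample(logits, tau=0.5, rng=rng))] += 1
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(np.round(counts / n, 3), np.round(probs, 3))
```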

5. Implicit Differentiation — Differentiating Through Solvers

Here's a scenario that seems impossible at first. Suppose you have an equation g(p, x) = 0 that implicitly defines x* as a function of p. You solve this equation with some iterative method (Newton's method, bisection, fixed-point iteration) to find x*. Now you want dx*/dp. Do you need to differentiate through every iteration of the solver? That could be hundreds of steps, each storing activations for the backward pass.

The Implicit Function Theorem says no. If the Jacobian ∂g/∂x* is invertible at the solution, then:

$$\frac{dx^*}{dp} = -\left(\frac{\partial g}{\partial x}\bigg|_{x^*}\right)^{-1} \cdot \frac{\partial g}{\partial p}$$

You only need the partial derivatives of g evaluated at the solution point. It doesn't matter whether you found that solution in 3 iterations or 3000 — the gradient is the same, and it costs constant memory to compute.

Let's see this in action. We'll find the cube root of a parameter p by solving g(p, x) = x³ − p = 0, then differentiate the solution with respect to p:

def solve_cube_root(p, tol=1e-12):
    """Find x such that x^3 - p = 0, using Newton's method."""
    # Generic starting guess: the Newton iterations, not a closed form, find the root
    x = 1.0 if p >= 0 else -1.0
    for _ in range(100):
        g = x ** 3 - float(p)
        dg_dx = 3 * x ** 2
        if abs(dg_dx) < 1e-15:
            break
        x = x - g / dg_dx
        if abs(g) < tol:
            break
    return x

def implicit_grad(p):
    """Compute dx*/dp via the Implicit Function Theorem.
    g(p, x) = x^3 - p = 0
    dg/dx = 3x^2,  dg/dp = -1
    dx*/dp = -(dg/dx)^{-1} * (dg/dp) = -1/(3x^2) * (-1) = 1/(3x^2)
    """
    x_star = solve_cube_root(p)
    dg_dx = 3 * x_star ** 2    # partial of g w.r.t. x at solution
    dg_dp = -1.0                # partial of g w.r.t. p at solution
    return x_star, -dg_dp / dg_dx

# Test: d(p^{1/3})/dp = (1/3) * p^{-2/3}
for p in [1.0, 8.0, 27.0]:
    x_star, grad = implicit_grad(p)
    analytic = (1/3) * p ** (-2/3)
    print(f"p={p:5.1f}  x*={x_star:.6f}  "
          f"IFT grad={grad:.6f}  analytic={analytic:.6f}")
# p=  1.0  x*=1.000000  IFT grad=0.333333  analytic=0.333333
# p=  8.0  x*=2.000000  IFT grad=0.083333  analytic=0.083333
# p= 27.0  x*=3.000000  IFT grad=0.037037  analytic=0.037037

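The IFT gradients agree with the analytic derivative to every printed digit, and the gradient computation never differentiates through the solver's iterations. The same principle lets you treat a solver's output as a differentiable layer inside a larger model.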

If you've read our neural ODEs post, you'll recognize the same principle at work: the adjoint method differentiates through an ODE solution without storing intermediate states. Implicit differentiation is the algebraic cousin of the continuous adjoint — same theorem, different domains.
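For contrast, here is the expensive route the theorem lets you skip: pushing dual numbers through every Newton iteration. This sketch re-declares a small Dual with division, starts from the arbitrary guess x = 1, and uses p = 8. The unrolled derivative is wrong while the solver is still converging and only settles on the IFT value 1/(3x*²) = 1/12 once x has converged. Forward mode is used here only for brevity; a reverse-mode unroll would also have to store every iterate.

```python
class Dual:
    """Dual number val + deriv*eps, with just the ops Newton's method needs."""
    def __init__(self, val, deriv=0.0):
        self.val, self.deriv = val, deriv

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.val - o.val, self.deriv - o.deriv)

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.val * o.val, self.val * o.deriv + self.deriv * o.val)
    __rmul__ = __mul__

    def __truediv__(self, other):
        o = Dual.lift(other)
        # Quotient rule: (a/b)' = (a' b - a b') / b^2
        return Dual(self.val / o.val,
                    (self.deriv * o.val - self.val * o.deriv) / o.val ** 2)

def newton_cube_root_unrolled(p, iters):
    """Run a fixed number of Newton steps on x^3 - p = 0; p may be a Dual."""
    x = Dual(1.0)                      # starting guess, independent of p
    for _ in range(iters):
        x = x - (x * x * x - p) / (3 * x * x)
    return x

for iters in [3, 6, 30]:
    x = newton_cube_root_unrolled(Dual(8.0, 1.0), iters)   # seed dp/dp = 1
    print(f"iters={iters:2d}  x={x.val:.6f}  unrolled dx/dp={x.deriv:.6f}")

# The IFT only needs the converged solution x* = 2
print(f"IFT dx/dp = {1 / (3 * 2.0 ** 2):.6f}")
```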

6. Differentiable Physics and Rendering

Once you internalize the pattern (replace hard operations with smooth approximations, or use implicit differentiation to avoid backpropagating through an iterative solver), you start seeing opportunities everywhere. Two domains have been transformed by this thinking: physics simulation and 3D rendering.

Differentiable Rendering (NeRF)

NeRF (Mildenhall et al. 2020) represents a 3D scene as a neural network that maps a 3D coordinate and viewing direction to a color and density: F(x, y, z, θ, φ) → (r, g, b, σ). To render a pixel, you cast a ray through the scene and integrate along it using the volume rendering equation:

$$C(\text{ray}) = \sum_{i} T_i \left(1 - e^{-\sigma_i \delta_i}\right) c_i$$

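where $T_i = \exp\left(-\sum_{j<i} \sigma_j \delta_j\right)$ is the accumulated transmittance, $\sigma_i$ and $c_i$ are the density and color at the i-th sample along the ray, and $\delta_i$ is the distance between adjacent samples. This equation is a differentiable weighted sum: gradients flow from pixel colors all the way back to the neural network weights. The key insight: volume rendering is inherently differentiable, unlike rasterization, which has hard edges and occlusion boundaries.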
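The discrete sum takes only a few lines of NumPy. This sketch assumes four samples along a single ray with made-up densities, colors, and a uniform spacing δ = 0.5; composite and grad_sigma are illustrative names, not NeRF's code. The gradient with respect to each density comes from differentiating the sum by hand (raising σ_k makes sample k more opaque but dims everything behind it), and a central finite-difference check confirms it.

```python
import numpy as np

def composite(sigmas, colors, deltas):
    """Discrete volume rendering along one ray.

    sigmas: (N,) densities, colors: (N, 3) RGB, deltas: (N,) sample spacings.
    Returns the pixel color and the weights T_i * (1 - exp(-sigma_i * delta_i)).
    """
    alpha = 1.0 - np.exp(-sigmas * deltas)                              # segment opacity
    depth = np.concatenate([[0.0], np.cumsum(sigmas * deltas)[:-1]])   # sum over j < i
    T = np.exp(-depth)                                                  # transmittance
    weights = T * alpha
    return weights @ colors, weights

def grad_sigma(sigmas, colors, deltas):
    """Hand-derived dC/dsigma_k, shape (N, 3)."""
    _, w = composite(sigmas, colors, deltas)
    depth = np.concatenate([[0.0], np.cumsum(sigmas * deltas)[:-1]])
    T = np.exp(-depth)
    # Term 1: sample k itself becomes more opaque
    own = (deltas * T * np.exp(-sigmas * deltas))[:, None] * colors
    # Term 2: every sample i > k is dimmed, d w_i / d sigma_k = -delta_k * w_i
    wc = w[:, None] * colors
    suffix = np.cumsum(wc[::-1], axis=0)[::-1]                  # sum over i >= k
    behind = np.concatenate([suffix[1:], np.zeros((1, 3))])     # sum over i > k
    return own - deltas[:, None] * behind

sigmas = np.array([0.1, 0.5, 2.0, 0.3])
colors = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0],
                   [1.0, 1.0, 1.0]])
deltas = np.full(4, 0.5)

C, w = composite(sigmas, colors, deltas)
print("pixel:", np.round(C, 4), " weights sum:", round(float(w.sum()), 4))

# Central finite differences on each density, for comparison
h = 1e-6
fd = np.stack([(composite(sigmas + h * e, colors, deltas)[0]
                - composite(sigmas - h * e, colors, deltas)[0]) / (2 * h)
               for e in np.eye(4)])
print("analytic matches finite differences:",
      np.allclose(grad_sigma(sigmas, colors, deltas), fd, atol=1e-6))
```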

Differentiable Physics

A physics simulator is a sequence of state updates: $s_{t+1} = \mathrm{step}(s_t, a_t)$. If each step is differentiable, gradients chain through the entire trajectory via backpropagation through time. You can then optimize actions (or physical parameters) to achieve desired outcomes.

The catch: contact and collision physics involve discontinuities (an object either hits the ground or it doesn't). Differentiable physics engines smooth these out with soft contact models — replacing hard collisions with penalty forces that increase rapidly as objects approach. Frameworks like Brax (Google), DiffTaichi, and gradSim implement this approach for trajectory optimization, robot learning, and system identification.
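One way to picture the smoothing, as a sketch rather than the contact model of any engine named above: a linear penalty k·max(0, −y) pushes back only when an object penetrates the ground (y < 0), and its derivative jumps from −k to 0 at contact. Replacing max(0, ·) with a softplus of sharpness β gives a force that already starts rising just before contact and whose derivative, −k·sigmoid(−βy), is continuous everywhere. The stiffness k and sharpness β are hypothetical tuning values.

```python
import math

def hard_penalty_force(y, k=1000.0):
    """Upward force only while penetrating the ground (y < 0); kink at y = 0."""
    return k * max(0.0, -y)

def hard_penalty_dforce(y, k=1000.0):
    """Its derivative jumps from -k to 0 at the contact point."""
    return -k if y < 0 else 0.0

def soft_penalty_force(y, k=1000.0, beta=200.0):
    """Softplus-smoothed penalty: k * softplus(-beta * y) / beta."""
    z = -beta * y
    softplus = max(z, 0.0) + math.log1p(math.exp(-abs(z)))   # overflow-safe form
    return k * softplus / beta

def soft_penalty_dforce(y, k=1000.0, beta=200.0):
    """Derivative -k * sigmoid(-beta * y): continuous everywhere."""
    z = -beta * y
    sig = 1.0 / (1.0 + math.exp(-z)) if z >= 0 else math.exp(z) / (1.0 + math.exp(z))
    return -k * sig

for y in [0.05, 0.005, 0.0, -0.005, -0.05]:
    print(f"y={y:+.3f}  hard F={hard_penalty_force(y):7.3f} dF/dy={hard_penalty_dforce(y):8.1f}  "
          f"soft F={soft_penalty_force(y):7.3f} dF/dy={soft_penalty_dforce(y):8.1f}")
```

Larger β makes the soft force hug the hard one more tightly, at the cost of a steeper (harder to optimize) derivative near contact, which is the same temperature trade-off as in Section 4.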

Let's build a minimal example: a differentiable 2D projectile simulator where we optimize the launch angle to hit a target:

import math

class Value:
    """Minimal autograd (same engine from Section 3)."""
    def __init__(self, data, children=()):
        self.data = float(data)
        self.grad = 0.0
        self._backward = lambda: None
        self._children = set(children)
    def __add__(self, o):
        o = o if isinstance(o, Value) else Value(o)
        out = Value(self.data + o.data, (self, o))
        def _bw():
            self.grad += out.grad; o.grad += out.grad
        out._backward = _bw; return out
    def __mul__(self, o):
        o = o if isinstance(o, Value) else Value(o)
        out = Value(self.data * o.data, (self, o))
        def _bw():
            self.grad += o.data * out.grad; o.grad += self.data * out.grad
        out._backward = _bw; return out
    def __neg__(self): return self * -1
    def __sub__(self, o): return self + (-o)
    def __radd__(self, o): return self + o
    def __rmul__(self, o): return self * o
    def backward(self):
        order, visited = [], set()
        def topo(v):
            if v not in visited:
                visited.add(v)
                for c in v._children: topo(c)
                order.append(v)
        topo(self)
        self.grad = 1.0
        for v in reversed(order): v._backward()

def diff_sin(x):
    out = Value(math.sin(x.data), (x,))
    def _bw(): x.grad += math.cos(x.data) * out.grad
    out._backward = _bw; return out

def diff_cos(x):
    out = Value(math.cos(x.data), (x,))
    def _bw(): x.grad += -math.sin(x.data) * out.grad
    out._backward = _bw; return out

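def diff_recip(x):
    # Differentiable reciprocal 1/x, used for the landing interpolation below
    out = Value(1.0 / x.data, (x,))
    def _bw(): x.grad += -1.0 / (x.data ** 2) * out.grad
    out._backward = _bw; return out

# Differentiable projectile: optimize angle to hit target
target_x = 10.0
v0 = 15.0       # launch speed (fixed)
g = 9.81         # gravity
dt = 0.01        # time step

angle = Value(0.5)  # initial guess: ~28.6 degrees

for step in range(80):
    # Reset gradients
    angle.grad = 0.0

    # Simulate trajectory with Euler integration until the projectile drops
    # below the ground (1000 steps is far more than any flight at this speed)
    vx = v0 * diff_cos(angle)
    vy = v0 * diff_sin(angle)
    px = Value(0.0)
    py = Value(0.0)

    for t in range(1000):
        prev_px, prev_py = px, py
        px = px + vx * dt
        py = py + vy * dt
        vy = vy + Value(-g) * dt
        if py.data < 0:
            break

    # The step on which the ground is crossed is a discrete decision with no
    # gradient, so px alone ignores how flight time depends on the angle.
    # Interpolating to the zero crossing of py restores that dependence.
    frac = prev_py * diff_recip(prev_py - py)
    landing = prev_px + (px - prev_px) * frac

    # Loss: squared distance from target
    loss = (landing - target_x) * (landing - target_x)
    loss.backward()

    if step % 20 == 0:
        print(f"step {step:3d}  angle={math.degrees(angle.data):6.2f} deg  "
              f"landing={landing.data:6.2f}  loss={loss.data:.4f}")

    # Gradient descent on the angle. Near the target the loss curvature is in
    # the thousands, so a step size of 1e-3 would overshoot and oscillate.
    angle = Value(angle.data - 2e-4 * angle.grad)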

Gradients flow backwards through every Euler step of the simulation, through the trigonometric functions, all the way to the angle parameter. The optimizer adjusts the angle until the projectile lands on the target. This is differentiable programming in action: we wrote an ordinary physics simulation, and because every operation is differentiable, we can optimize through it.

7. Modern Frameworks

Building a tiny autograd engine is educational, but production systems need to handle tensors, GPU acceleration, control flow, and composition of transforms. Three fundamentally different approaches have emerged:

| Approach | Framework | How It Works | Control Flow |
|---|---|---|---|
| Operator overloading | PyTorch autograd | Record ops to a tape as they execute; replay the tape backwards | Full Python (dynamic graph rebuilt on each call) |
| Tracing + transformation | JAX | Trace the function with abstract values → IR (jaxpr) → transform → compile with XLA | Functional style; under jit, data-dependent control flow goes through lax.cond, lax.scan |
| Source-to-source | Zygote.jl | Transform the compiler's SSA IR to generate the backward pass at compile time | Most of the language (if, while, recursion, closures); array mutation is not supported |

PyTorch's tape-based approach is the most intuitive: Python runs naturally, and the autograd engine records every tensor operation. Call .backward() and the tape replays in reverse. The trade-off is memory: the tape keeps alive every intermediate tensor the backward pass will need.
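The record-then-replay idea fits in a few dozen lines. This is a toy in the spirit of the operator-overloading approach, not a description of PyTorch's actual engine; the Tape class and its methods are hypothetical. Because the tape lists operations in execution order, which is already a valid topological order, replaying it backwards needs no graph search. It also shows where the memory goes: every recorded value stays on the tape until the backward pass is done.

```python
import math

class Tape:
    """Toy Wengert list: ops are appended as they run and replayed in reverse."""
    def __init__(self):
        self.values = []    # value of every variable, indexed by variable id
        self.records = []   # (output_id, [(input_id, local_partial), ...]) in execution order

    def var(self, value):
        self.values.append(float(value))
        return len(self.values) - 1

    def _record(self, value, partials):
        out = self.var(value)
        self.records.append((out, partials))
        return out

    def add(self, a, b):
        return self._record(self.values[a] + self.values[b], [(a, 1.0), (b, 1.0)])

    def mul(self, a, b):
        va, vb = self.values[a], self.values[b]
        return self._record(va * vb, [(a, vb), (b, va)])

    def sin(self, a):
        va = self.values[a]
        return self._record(math.sin(va), [(a, math.cos(va))])

    def gradient(self, output):
        """Walk the tape backwards; execution order is already topological."""
        grads = [0.0] * len(self.values)
        grads[output] = 1.0
        for out, partials in reversed(self.records):
            for inp, local in partials:
                grads[inp] += local * grads[out]
        return grads

# f(x, y) = sin(x * y) + x * x
tape = Tape()
x, y = tape.var(1.0), tape.var(2.0)
f = tape.add(tape.sin(tape.mul(x, y)), tape.mul(x, x))
grads = tape.gradient(f)
print(tape.values[f], grads[x], grads[y])   # sin(2) + 1, 2*cos(2) + 2, cos(2)
```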

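JAX's tracing approach is the most composable. Since grad, jvp, vjp, and vmap are all function transformations on the same IR, they compose freely: jax.jacfwd(jax.grad(f)) gives a Hessian, and jax.vmap(jax.grad(f)) evaluates gradients for a whole batch of inputs at once.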
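Composability is easy to demonstrate even without JAX. In this toy, derivative is a hypothetical transformation that takes a function and returns its derivative via forward-mode dual numbers; because Dual parts may themselves be Duals, applying it twice gives f''. It is only a sketch: it does not tag nested perturbations, so it can go wrong for functions that close over an outer differentiated variable (the "perturbation confusion" problem).

```python
import math

class Dual:
    """Dual number whose parts may themselves be Duals (enables nesting)."""
    def __init__(self, val, deriv=0.0):
        self.val, self.deriv = val, deriv

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + o.val, self.deriv + o.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.val * o.deriv + self.deriv * o.val)
    __rmul__ = __mul__

def sin(x):
    # Recurse so that sin works on Duals whose parts are Duals
    if isinstance(x, Dual):
        return Dual(sin(x.val), cos(x.val) * x.deriv)
    return math.sin(x)

def cos(x):
    if isinstance(x, Dual):
        return Dual(cos(x.val), -1 * sin(x.val) * x.deriv)
    return math.cos(x)

def derivative(f):
    """A function transformation: returns a new function computing f'."""
    def df(x):
        return f(Dual(x, 1.0)).deriv
    return df

def f(x):
    return x * x * x + sin(x)

d2f = derivative(derivative(f))   # transformations compose like functions
print(derivative(f)(2.0))         # 3*2^2 + cos(2)
print(d2f(2.0))                   # 6*2 - sin(2)
```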

JAX's key internal insight: define a JVP rule for each primitive, then derive reverse mode automatically by transposing the linear JVP map. You write one derivative rule per primitive (plus a transpose rule for each linear primitive) instead of separate forward and reverse rules.

Zygote.jl operates at the compiler level, rewriting Julia's SSA intermediate representation to insert backward pass computations. This gives it broad language support: it can differentiate through if statements, while loops, recursion, closures, and struct access, although mutating arrays is a well-known limitation. The price is an implementation far more complex than tracing.

All three approaches implement the same mathematical operations (JVPs and VJPs) — the difference is when and how the backward pass is constructed. Taping does it at runtime, tracing does it at trace time, and source-to-source does it at compile time.

8. Conclusion

We've covered a lot of ground. We started from dual numbers that carry derivatives through a single forward pass, built a reverse-mode engine that records a computational graph and differentiates arbitrary expressions, made non-differentiable operations like argmax and sorting gradient-friendly, used the Implicit Function Theorem to differentiate through the solution of an equation without unrolling the solver, and pushed gradients through physics simulations and 3D renderers.

The unifying theme is simple: gradients are a universal optimization signal. If you can make a computation differentiable — whether by design, by smooth relaxation, or by implicit differentiation — you can optimize it. Neural networks are just the beginning. The frontier includes differentiable databases, differentiable compilers, and differentiable program synthesis. Wherever there's a parameterized computation and an objective, differentiable programming provides the bridge.

The tools are mature. JAX gives you composable transforms. PyTorch gives you dynamic graphs. Zygote differentiates nearly arbitrary Julia code. The hard part isn't the framework; it's recognizing which parts of your problem can be made differentiable, and choosing the right technique to get gradients flowing.

References & Further Reading