State Space Models from Scratch: How Mamba Learned to Rival Transformers
The Missing Architecture
This series has built two families of sequence models, and each one made a painful tradeoff.
RNNs process tokens one at a time, maintaining a hidden state that compresses the entire history into a fixed-size vector. They're O(n) per sequence — beautiful and memory-efficient. But that sequential dependency means you can't parallelize training, and vanilla RNNs suffer from vanishing gradients that make long-range learning nearly impossible. LSTMs added gating to help, but the serial bottleneck remained.
Transformers solved parallelism with the attention mechanism: every token attends to every other token, computing a weighted sum of values. Training is fully parallel. But that pairwise interaction costs O(n²) in both compute and memory. For a sequence of 4,096 tokens, that's over 16 million attention computations — and the cost quadruples every time you double the context length.
What if there were a third option? An architecture that trains in parallel like a transformer but generates with a constant-size state like an RNN — and does both in O(n) time?
That architecture exists. It comes not from deep learning, but from control theory — the branch of engineering that models circuits, rockets, and autopilots. It's called a State Space Model (SSM), and since 2021 it has gone from a mathematical curiosity to the engine behind Mamba, Jamba, Zamba, and a growing family of models that are genuinely competitive with transformers on language, audio, genomics, and long-context tasks.
The lineage looks like this: HiPPO (2020) discovered how to initialize the state matrix for long-range memory. S4 (2021) turned that into a practical architecture. Mamba (2023) made the system content-aware and beat transformers at their own game. In this post, we'll build all of them from scratch in pure Python and NumPy.
The key insight that makes the whole thing work: a single state space model can be viewed as an ODE, a recurrence, and a convolution — three mathematical representations of the same system, each useful for a different purpose.
The Continuous-Time State Space Model
A state space model is defined by four matrices and a system of two equations. If you've taken a controls or signals class, this will look familiar. If not, don't worry — the intuition is straightforward.
Given an input signal u(t) and a hidden state x(t) ∈ ℝ^N, the continuous-time SSM is:
x'(t) = Ax(t) + Bu(t) (state equation)
y(t) = Cx(t) + Du(t) (output equation)
Each matrix has a clear job:
- A (N×N) governs how the state evolves on its own — like momentum. Left alone, the state drifts according to A.
- B (N×1) maps input to state — like a steering wheel. The input nudges the state through B.
- C (1×N) maps state to output — a readout projection.
- D (1×1) is a skip connection from input to output. We'll ignore it (set D=0) since the interesting computation flows through the state.
The hidden state x(t) is a compressed summary of the entire input history. Every input u(t) leaves a trace in the state, filtered through the dynamics of A. The output y(t) is just a linear readout of that compressed representation.
Why does "linear" matter? Because the state equation is linear in x — no activation functions, no nonlinearities. This linearity is what unlocks the mathematical tricks that make SSMs fast. It doesn't mean the overall model is limited: we'll stack SSM layers with nonlinearities between them, just like stacking attention layers in a transformer.
Let's simulate a continuous SSM using Euler's method. We'll use a 4-dimensional state with oscillatory dynamics — two pairs of coupled dimensions that create rotating, decaying state trajectories:
import numpy as np
# A 4-dimensional continuous-time SSM with oscillatory dynamics
N = 4 # state dimension
# State matrix: two 2x2 blocks, each a damped rotation
A = np.array([
[-0.5, 1.0, 0.0, 0.0], # block 1: fast oscillation
[-1.0, -0.5, 0.0, 0.0],
[ 0.0, 0.0, -0.1, 0.3], # block 2: slow oscillation
[ 0.0, 0.0, -0.3, -0.1]
])
B = np.array([[1.0], [0.0], [0.5], [0.0]]) # input coupling
C = np.array([[1.0, 0.0, 1.0, 0.0]]) # readout
# Simulate with Euler's method: x(t+dt) ≈ x(t) + dt * x'(t)
dt, T = 0.01, 10.0
steps = int(T / dt)
u = np.zeros(steps)
u[100:200] = 1.0 # pulse input from t=1.0 to t=2.0
x = np.zeros((steps, N))
y = np.zeros(steps)
for k in range(1, steps):
dx = A @ x[k-1] + B.squeeze() * u[k-1]
x[k] = x[k-1] + dt * dx
y[k] = (C @ x[k]).item()
print(f"State at t=5.0: [{', '.join(f'{v:.3f}' for v in x[500])}]")
print(f"Output range: [{y.min():.3f}, {y.max():.3f}]")
print("State decays after pulse ends — A's eigenvalues have negative real parts, ensuring stability")
The pulse drives the state away from zero, but the negative real parts of A's eigenvalues pull it back. The two oscillatory blocks create a rich, multi-scale response from a single pulse — this is the power of a continuous-time dynamical system. But computers don't work in continuous time. We need to discretize.
Discretization: From Continuous to Computable
Neural networks process sequences as discrete tokens: word 1, word 2, word 3. We need to convert our continuous ODE into a step-by-step recurrence that a computer can execute. The bridge between continuous and discrete is discretization.
The idea: given a step size Δ (the time between consecutive tokens), we can derive a discrete system that produces the same output at each timestep. The most common method is the Zero-Order Hold (ZOH): assume the input u(t) is constant between steps (it's held at u_k from time kΔ to (k+1)Δ). Under this assumption, the matrix exponential gives us exact formulas:
Ā = exp(ΔA)
B̄ = (ΔA)⁻¹(Ā − I) · ΔB
After discretization, the continuous ODE becomes a simple recurrence:
xₖ = Ā xₖ₋₁ + B̄ uₖ
yₖ = C xₖ
This is an RNN. Multiply the old state by a matrix, add the input, read out the result. The only difference from a vanilla RNN is where the matrices come from — here they're derived from a continuous-time system via discretization, rather than learned directly.
The step size Δ is crucial. Small Δ means fine-grained resolution (many steps per "event"), large Δ means coarse approximation but faster processing. Mamba will later make Δ input-dependent, letting the model choose its own resolution per token. But for now, it's a fixed hyperparameter.
In practice, computing the exact matrix exponential exp(ΔA) is expensive. There are three common approximations:
- Euler: Ā ≈ I + ΔA — simplest, but can be unstable for large Δ
- Bilinear (Tustin): Ā = (I − ΔA/2)⁻¹(I + ΔA/2) — stable, preserves frequency response. Used by S4.
- Simplified ZOH: Ā = exp(ΔA), B̄ ≈ ΔB — exact A discretization with simplified B. Used by Mamba.
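To see the tradeoffs concretely, here is a scalar sanity check (the value a = −2 is illustrative; for a one-dimensional system the exact ZOH state matrix is simply exp(Δa)):

```python
import numpy as np

# Compare the three discretizations on a 1-D system x' = a*x + b*u, a < 0.
a = -2.0

for delta in [0.01, 0.1, 0.5, 1.0]:
    exact = np.exp(delta * a)                              # exp(delta * A)
    euler = 1 + delta * a                                  # I + delta * A
    bilinear = (1 + delta * a / 2) / (1 - delta * a / 2)   # Tustin
    print(f"delta={delta:4.2f}  exact={exact:.4f}  "
          f"euler={euler:+.4f}  bilinear={bilinear:+.4f}")

# Euler drifts badly (and flips sign once delta > 1/|a|), while bilinear
# stays inside (-1, 1) for any positive delta whenever a < 0.
```

At small Δ all three agree; as Δ grows, Euler degrades fastest while bilinear and exact ZOH remain stable.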
Let's implement the simplified ZOH discretization and verify that the discrete recurrence matches our continuous simulation:
def discretize_zoh(A, B, delta):
"""Discretize a continuous SSM using simplified Zero-Order Hold.
Returns discrete matrices A_bar, B_bar."""
N = A.shape[0]
    # Exact: A_bar = expm(delta * A). We use a truncated Taylor series:
    # for small delta*A, expm(dA) ≈ I + dA + (dA)^2/2 + (dA)^3/6
dA = delta * A
A_bar = np.eye(N) + dA + dA @ dA / 2 + dA @ dA @ dA / 6
B_bar = delta * B # simplified ZOH (Mamba-style)
return A_bar, B_bar
# Discretize with step size delta = 0.1 (10x coarser than Euler sim)
delta = 0.1
A_bar, B_bar = discretize_zoh(A, B, delta)
# Run discrete recurrence on the same pulse signal
L = 100 # 100 steps at delta=0.1 covers T=10.0
u_disc = np.zeros(L)
u_disc[10:20] = 1.0 # pulse from t=1.0 to t=2.0 (same as continuous)
x_d = np.zeros((L, N))
y_d = np.zeros(L)
for k in range(1, L):
x_d[k] = A_bar @ x_d[k-1] + B_bar.squeeze() * u_disc[k]
y_d[k] = (C @ x_d[k]).item()
# Compare with continuous simulation (sampled at same timesteps)
y_continuous_sampled = y[::10][:L] # sample every 10th point
max_err = np.max(np.abs(y_d[:len(y_continuous_sampled)] - y_continuous_sampled))
print(f"Max error vs continuous: {max_err:.6f}")
print(f"Discrete state shape: {x_d.shape} — just an RNN!")
print(f"Each step: multiply {N}x{N} matrix + add input. That's it.")
The discrete recurrence tracks the continuous simulation closely. The error comes from the simplified B discretization — using the exact formula would make it even smaller. But the simplified version is what Mamba uses, and it works well in practice because the model learns to compensate.
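To put a number on that claim, here is a sketch comparing the exact ZOH B̄ with the simplified ΔB on a small damped-rotation system (the 2×2 matrix is illustrative, not the 4×4 A from the simulation above):

```python
import numpy as np

# Exact vs simplified B_bar on a small damped-rotation system.
A = np.array([[-0.5, 1.0], [-1.0, -0.5]])
B = np.array([[1.0], [0.0]])
delta = 0.1

# Exact A_bar = expm(delta*A) via eigendecomposition (A is diagonalizable here)
evals, V = np.linalg.eig(A)
A_bar = (V @ np.diag(np.exp(delta * evals)) @ np.linalg.inv(V)).real

# Exact ZOH: (dA)^-1 (A_bar - I) * dB simplifies to A^-1 (A_bar - I) B
B_exact = np.linalg.inv(A) @ (A_bar - np.eye(2)) @ B
B_simple = delta * B  # Mamba-style simplification

rel_gap = np.linalg.norm(B_exact - B_simple) / np.linalg.norm(B_exact)
print("exact  B_bar:", np.round(B_exact.ravel(), 4))
print("simple B_bar:", B_simple.ravel())
print(f"relative gap: {rel_gap:.3f}; shrinks as delta -> 0 (first-order approximation)")
```

The simplified form is just the first term of the Taylor expansion of the exact B̄, which is why the gap closes as Δ shrinks.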
The Dual Computation Trick: Recurrence ↔ Convolution
Here's the insight that makes SSMs special. Because the system is linear, we can unroll the recurrence into a closed-form expression. Watch what happens when we expand the first few timesteps:
- y₀ = CB̄ u₀
- y₁ = CĀB̄ u₀ + CB̄ u₁
- y₂ = CĀ²B̄ u₀ + CĀB̄ u₁ + CB̄ u₂
See the pattern? Each output yₖ is a weighted sum of all previous inputs, where the weight for input uⱼ at time k is CĀᵏ⁻ʲB̄. This is a convolution. The convolution kernel is:
K̄ = (CB̄, CĀB̄, CĀ²B̄, ..., CĀᴸ⁻¹B̄)
Once we have this kernel, we can compute the entire output sequence y = K̄ * u using the Fast Fourier Transform in O(L log L) time. No sequential dependencies, fully parallelizable. This is how SSMs train efficiently.
But here's the beautiful duality: at inference time, we can switch back to the recurrence. Processing one new token just means one matrix multiply and one addition — O(1) per token, with a constant-size state. No growing KV cache, no quadratic blowup.
One model, two computation modes. Train with convolution (parallel). Infer with recurrence (sequential). This is the core reason SSMs exist.
def ssm_kernel(A_bar, B_bar, C, L):
"""Compute the SSM convolution kernel of length L.
K[i] = C @ A_bar^i @ B_bar"""
kernel = np.zeros(L)
A_power = np.eye(A_bar.shape[0]) # A_bar^0 = I
for i in range(L):
kernel[i] = (C @ A_power @ B_bar).item()
A_power = A_power @ A_bar
return kernel
K = ssm_kernel(A_bar, B_bar, C, L)
# Method 1: Convolution via FFT (parallel — training mode)
# Pad to avoid circular convolution artifacts
pad_len = 2 * L
y_conv = np.real(np.fft.ifft(
np.fft.fft(K, pad_len) * np.fft.fft(u_disc, pad_len)
))[:L]
# Method 2: Sequential recurrence (already computed as y_d)
max_diff = np.max(np.abs(y_conv - y_d))
print(f"Convolution vs recurrence max diff: {max_diff:.2e}")
print(f"Kernel K shape: {K.shape}")
print(f"K[0]={K[0]:.4f}, K[1]={K[1]:.4f}, K[2]={K[2]:.4f} — decaying impulse response")
print(f"Train: O(L log L) via FFT | Infer: O(1) per step via recurrence")
The convolution and recurrence outputs match to floating-point precision. This isn't an approximation — it's a mathematical identity. The kernel K̄ is a decaying impulse response: it tells you how much each past input contributes to the current output. Early entries are large (recent inputs matter more), and later entries decay toward zero (the system gradually forgets).
There's a catch, though. Computing the kernel requires Āᵏ for every k up to L. Naively, that's O(L) matrix products at O(N³) apiece, i.e. O(N³L) work. For a state dimension of N=64 and sequence length L=4096, that's brutal. S4's breakthrough was finding a way to compute this kernel in roughly O(N + L) operations using special structure in the A matrix. Let's see how.
Try It: Continuous → Discrete → Convolve
Three views of the same SSM. The slider controls the discretization step size Δ — watch how coarser steps approximate the continuous trajectory and change the convolution kernel.
S4 and HiPPO: Remembering Long Sequences
If you initialize A with random values, your SSM will suffer the same fate as vanilla RNNs: the state decays exponentially, and the model can't remember anything more than a few timesteps back. This is the vanishing gradient problem in a new disguise.
The breakthrough came from Albert Gu's HiPPO framework (High-order Polynomial Projection Operators, 2020). HiPPO answers a specific question: what's the optimal way to compress the entire history of an input signal into a fixed-size state vector?
The answer: project the input history onto a basis of Legendre polynomials. Each element of the hidden state stores a coefficient of a different Legendre polynomial, capturing the input history at a different timescale. Element 0 stores the "average" of the input. Element 1 stores the "slope." Element 2 stores the "curvature." Higher elements capture finer details. Together, they form an optimal polynomial approximation of the full input history.
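To build intuition for what these coefficients store, here is an offline sketch using NumPy's Legendre module: fit a signal's history by least squares and watch the reconstruction sharpen as coefficients are added. (HiPPO performs this projection online through the ODE; the batch fit here is only for intuition.)

```python
import numpy as np
from numpy.polynomial import legendre

# Approximate a signal's history with a handful of Legendre coefficients.
t = np.linspace(-1, 1, 200)                  # history, rescaled to [-1, 1]
signal = np.sin(3 * t) + 0.5 * np.sign(t)    # something non-trivial

for deg in [1, 3, 7, 15]:
    coeffs = legendre.legfit(t, signal, deg)  # project onto P_0 .. P_deg
    recon = legendre.legval(t, coeffs)
    err = np.sqrt(np.mean((signal - recon) ** 2))
    print(f"{deg + 1:2d} coefficients -> RMS error {err:.4f}")

# More coefficients retain finer detail, just like a larger state dimension N.
```

Two coefficients capture only the average and slope; sixteen capture the shape of the whole history, including the jump.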
The HiPPO-LegS (Legendre-Scaled) matrix that achieves this has a beautiful closed-form structure:
Aₙₖ = −√(2n+1) · √(2k+1) for n > k
Aₙₙ = −(n+1) (diagonal)
Bₙ = √(2n+1)
Notice the negative signs everywhere — they ensure stability (eigenvalues with negative real parts, so the state doesn't explode). The matrix is lower triangular: element n depends on all elements k < n, creating a cascading information flow from coarse to fine timescales.
S4 (Efficiently Modeling Long Sequences with Structured State Spaces, Gu et al. 2021) built on HiPPO by showing that the matrix has a special Normal-Plus-Low-Rank (NPLR) structure that enables O(N + L) kernel computation via a Cauchy kernel formula. This was the breakthrough that made SSMs competitive on real benchmarks like the Long Range Arena (sequences up to 16K tokens).
Later work simplified things further: DSS and S5 showed you can just use a diagonal complex A matrix initialized near the eigenvalues of HiPPO — simpler, nearly as effective, and much easier to implement. Let's compare HiPPO initialization against random:
def make_hippo(N):
"""Build the HiPPO-LegS matrix for state dimension N.
Compresses input history into Legendre polynomial coefficients."""
A = np.zeros((N, N))
B = np.zeros((N, 1))
for n in range(N):
B[n, 0] = np.sqrt(2 * n + 1)
A[n, n] = -(n + 1) # diagonal
for k in range(n):
A[n, k] = -np.sqrt(2*n + 1) * np.sqrt(2*k + 1) # lower triangle
return A, B
N_state = 16
A_hippo, B_hippo = make_hippo(N_state)
# Memory test: input a burst, then silence. Who remembers?
L_test = 300
signal = np.zeros(L_test)
signal[20:40] = np.sin(np.linspace(0, 2 * np.pi, 20)) # burst at steps 20-40
delta_test = 0.05
C_read = np.ones((1, N_state)) / np.sqrt(N_state) # normalized readout
# HiPPO SSM
Ah, Bh = discretize_zoh(A_hippo, B_hippo, delta_test)
xh = np.zeros(N_state)
y_hippo = np.zeros(L_test)
for k in range(L_test):
xh = Ah @ xh + Bh.squeeze() * signal[k]
y_hippo[k] = (C_read @ xh).item()
# Random SSM (diagonal with fast-decaying eigenvalues)
np.random.seed(0)
A_rand = np.diag(-np.abs(np.random.randn(N_state)) * 5.0 - 3.0) # stable, fast decay
Ar, Br = discretize_zoh(A_rand, B_hippo, delta_test)
xr = np.zeros(N_state)
y_rand = np.zeros(L_test)
for k in range(L_test):
xr = Ar @ xr + Br.squeeze() * signal[k]
y_rand[k] = (C_read @ xr).item()
print(f"Signal energy concentrated at steps 20-40")
print(f"At step 100 (60 steps AFTER signal ended):")
print(f" HiPPO output: {abs(y_hippo[100]):.6f}")
print(f" Random output: {abs(y_rand[100]):.6f}")
print(f"HiPPO remembers — random forgets.")
HiPPO retains information about the burst long after it ended, because each state dimension captures a different timescale of the input's history. The random matrix's state decays to near-zero within a few dozen steps. This is exactly why HiPPO initialization was critical to making SSMs work for long sequences — it's the SSM equivalent of residual connections for deep networks.
The Problem with Fixed Dynamics
S4 was a breakthrough, but it has a fundamental limitation: the matrices A, B, and C are fixed for every input token. The system is Linear Time-Invariant (LTI) — the same filter is applied regardless of what the input actually says.
For audio processing or time-series forecasting, this can be fine. But for language, it's a dealbreaker. Consider the sentence "The dragon breathed fire." The word "dragon" is far more important for understanding the sentence than "the." An LTI system treats both identically — the same A dynamics, the same B coupling, the same C readout. It's like wearing sunglasses with fixed tint in every lighting condition: sometimes it works, often it doesn't.
Attention doesn't have this problem. Every attention head computes query-key-value projections from the input itself, so the processing is inherently content-aware. That's a major reason transformers work so well for language — they adapt their computation to the content.
The question became: can we make an SSM content-aware while keeping its O(n) efficiency? The answer is yes, but it comes at a cost — making the parameters input-dependent breaks the convolution trick. The system is no longer time-invariant, so we can't precompute a single kernel K̄. This is the tension that Mamba resolves.
Mamba: The Selective State Space Model
In December 2023, Albert Gu and Tri Dao published Mamba, and the ML world took notice. Mamba-3B matched the quality of transformer models twice its size, with 5× higher inference throughput. The core idea: make B, C, and Δ input-dependent, while keeping A fixed.
Here's how it works. At each timestep t, the input xt is projected to produce three sets of parameters:
- Bt = Linear(xt) — controls what to write into the state
- Ct = Linear(xt) — controls what to read from the state
- Δt = softplus(Linear(xt)) — controls how much to update the state
A stays fixed (log-parameterized, diagonal). This is a deliberate design choice: A governs the baseline decay rate of the state, which should be stable and learnable as a global parameter. The input-dependent parts — B, C, Δ — handle the content-aware processing.
The selection mechanism is elegant. After discretization with input-dependent Δt:
- Large Δ: Ā ≈ 0 (the old state is mostly erased), B̄ is large (the new input is strongly absorbed). The model is paying attention to this token — it overwrites the state.
- Small Δ: Ā ≈ I (the old state is preserved), B̄ is small (the new input barely registers). The model is ignoring this token — the state passes through unchanged.
This is exactly the same idea as an LSTM's forget gate, but derived from a continuous-time perspective rather than an ad hoc gating mechanism. The network learns which tokens matter by modulating how much it updates its state.
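The gate is easy to see numerically. For a single negative diagonal entry of A (value illustrative), the fraction of the old state that survives one step is exp(Δa):

```python
import numpy as np

# One diagonal entry of A (a < 0, value illustrative). The retained fraction
# of the old state after discretization is A_bar = exp(delta * a).
a = -1.0
for delta in [0.01, 0.5, 5.0]:
    a_bar = np.exp(delta * a)
    print(f"delta={delta:>5}: A_bar={a_bar:.4f} -> keeps {a_bar:.0%} of the old state")

# Small delta: A_bar ~ 1, the token is ignored and the state passes through.
# Large delta: A_bar ~ 0, the state is overwritten by the new input.
```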
How does this compare to attention? Attention computes explicit pairwise interactions between every token pair — O(n²) cost. Mamba updates a fixed-size state per token — O(n) cost. The tradeoff: Mamba can't do arbitrary token-to-token lookup (there's no "attend to token 47 specifically"), but it can process sequences of unlimited length in linear time. The state is the cache, and it's constant-size regardless of context length.
The full Mamba block wraps the selective SSM in a gated architecture: input → linear expand → (conv1d + SiLU + selective SSM) ⊗ (linear + SiLU gate) → linear project. This structure parallels the transformer block (attention + FFN), with RMSNorm for normalization.
def selective_scan(x_seq, A_log, W_B, W_C, W_delta, b_delta):
"""Mamba's selective scan — input-dependent state space model.
Args:
x_seq: (L, D) input sequence
A_log: (D, N) log-parameterized state decay (fixed, not input-dependent)
W_B: (D, N) projection for input-dependent B
W_C: (D, N) projection for input-dependent C
W_delta: (D, 1) projection for input-dependent step size
b_delta: scalar bias for step size
"""
L, D = x_seq.shape
N = A_log.shape[1]
h = np.zeros((D, N)) # hidden state — constant size!
outputs = np.zeros((L, D))
deltas = np.zeros(L)
for t in range(L):
xt = x_seq[t]
# Input-dependent projections — the "selection"
B_t = xt @ W_B # (N,)
C_t = xt @ W_C # (N,)
delta_t = np.log1p(np.exp(xt @ W_delta + b_delta)).mean() # softplus
# Discretize for THIS timestep (not globally!)
A_bar = np.exp(delta_t * A_log) # (D, N) element-wise
B_bar = delta_t * B_t # (N,) simplified ZOH
# State update
h = A_bar * h + np.outer(xt, B_bar) # (D, N)
# Output
outputs[t] = (h * C_t).sum(axis=1) # (D,)
deltas[t] = delta_t
return outputs, deltas
# Demo: 8 tokens with varying importance
D, N, L = 4, 8, 8
np.random.seed(123)
A_log = -np.abs(np.random.randn(D, N)) # negative for stability
W_B = np.random.randn(D, N) * 0.1
W_C = np.random.randn(D, N) * 0.1
W_delta = np.random.randn(D, 1) * 0.5
b_delta = -1.5 # moderate default step
x_seq = np.random.randn(L, D) * 0.5
x_seq[2] *= 5.0 # "dragon" — high magnitude → should get high delta
x_seq[5] *= 5.0 # "fire" — high magnitude → should get high delta
y, deltas = selective_scan(x_seq, A_log, W_B, W_C, W_delta, b_delta)
print("Delta per token:", [f"{d:.3f}" for d in deltas])
print("Tokens 2 and 5 get higher delta — the model 'selects' them")
print(f"State shape: ({D}, {N}) = {D*N} values — constant regardless of L")
Notice how tokens 2 and 5 (with higher magnitude, simulating "important" words) get larger Δ values, meaning the state updates more for those tokens. In a trained model, the projections W_B, W_C, and W_delta learn to assign high Δ to content words and low Δ to function words — an emergent attention-like behavior, but at O(n) cost.
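The gated block structure described above (expand → conv1d + SiLU + SSM, multiplied by a SiLU gate, then projected back) can be sketched structurally. This is a shape-level sketch only: the weights are random, and the selective scan is replaced by a fixed-decay recurrence for brevity; real Mamba also wraps the block in RMSNorm and a residual connection.

```python
import numpy as np

def silu(z):
    return z / (1 + np.exp(-z))

def mamba_block(x_seq, E=2):
    """x_seq: (L, D). Expand -> conv + SiLU + scan, gated -> project back to D."""
    L, D = x_seq.shape
    rng = np.random.default_rng(0)
    W_in = rng.normal(0, 0.1, (D, 2 * E * D))   # expand into two branches
    W_out = rng.normal(0, 0.1, (E * D, D))      # project back down
    u = x_seq @ W_in
    main, gate = u[:, :E * D], u[:, E * D:]
    # Causal width-2 conv as a stand-in for Mamba's depthwise conv1d
    conv = main.copy()
    conv[1:] += 0.5 * main[:-1]
    h = np.zeros(E * D)
    scanned = np.zeros_like(conv)
    for t in range(L):                          # stand-in for the selective scan
        h = 0.9 * h + silu(conv[t])
        scanned[t] = h
    return (scanned * silu(gate)) @ W_out       # gate, then project

y_block = mamba_block(np.random.default_rng(1).normal(size=(8, 4)))
print("block output shape:", y_block.shape)     # same (L, D) as the input
```

In a real implementation, the fixed-decay loop is replaced by `selective_scan`, and the weights are learned end to end.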
Try It: Selection Mechanism Explorer
See how Mamba's selection mechanism assigns different Δ values to each token. High Δ (red) = "pay attention" = large state update. Low Δ (blue) = "ignore" = state passes through.
The Hardware-Aware Parallel Scan
We said that making parameters input-dependent breaks the convolution trick. No global kernel K̄ exists, because Bt, Ct, and Δt are different at every timestep — the system is no longer time-invariant. Are we stuck with slow sequential processing?
Not quite. The linear recurrence xk = ak xk−1 + bk has a special property: it's associative. If we define the operator:
(a1, b1) ∘ (a2, b2) = (a2 · a1, a2 · b1 + b2)
then combining two adjacent recurrence steps into one is just applying this operator. And because it's associative, we can compute all prefix combinations using a parallel prefix scan — the same algorithm that GPUs use for cumulative sums and sorting. Instead of O(L) sequential steps, we get O(log L) depth with O(L) total work.
Gu and Dao took this further with GPU kernel engineering. The full Mamba state tensor has shape (batch, length, dimension, state_size) — far too large to materialize in GPU HBM (high-bandwidth memory). Their custom CUDA kernel performs all three operations — discretize, scan, output projection — fused in GPU SRAM (fast but tiny on-chip memory), avoiding the memory bottleneck entirely. The backward pass recomputes states from inputs rather than storing them, trading compute for memory.
This is systems engineering meeting machine learning: the algorithm is designed around the hardware constraints, not the other way around. Let's implement a simplified parallel scan to see the associative property in action:
def sequential_scan(a, b):
"""Sequential: x_k = a_k * x_{k-1} + b_k, with x_{-1} = 0"""
L = len(a)
x = np.zeros(L)
x[0] = b[0]
for k in range(1, L):
x[k] = a[k] * x[k-1] + b[k]
return x
def parallel_scan(a, b):
"""Parallel prefix scan using the associative operator:
(a1, b1) ∘ (a2, b2) = (a2*a1, a2*b1 + b2)
Solves x_k = a_k*x_{k-1} + b_k in O(log L) parallel depth."""
L = len(a)
aa, bb = a.copy(), b.copy()
for d in range(int(np.ceil(np.log2(L)))):
stride = 1 << d # 1, 2, 4, 8, ...
aa_prev, bb_prev = aa.copy(), bb.copy()
for i in range(stride, L):
# Combine element (i-stride) into element i
bb[i] = aa_prev[i] * bb_prev[i - stride] + bb_prev[i]
aa[i] = aa_prev[i] * aa_prev[i - stride]
return bb # bb[k] = x_k
# Verify
L = 16
np.random.seed(7)
a = np.random.uniform(0.8, 0.99, L) # decay coefficients
b = np.random.randn(L) * 0.3 # inputs
x_seq = sequential_scan(a, b)
x_par = parallel_scan(a, b)
print(f"Sequential: {np.round(x_seq[:6], 4)}")
print(f"Parallel: {np.round(x_par[:6], 4)}")
print(f"Max error: {np.max(np.abs(x_seq - x_par)):.2e}")
print(f"Depth: O(log {L}) = {int(np.ceil(np.log2(L)))} steps vs {L} sequential")
Both methods produce identical results, but the parallel scan has O(log L) depth instead of O(L). On a GPU with thousands of cores, this translates directly to wall-clock speedup. Combined with the fused SRAM kernel, Mamba's selective scan runs nearly as fast as an optimized attention implementation — and unlike attention, it doesn't slow down quadratically with sequence length.
The Complexity Race: Concrete Numbers
Let's put real numbers on the comparison. Consider three architectures processing a sequence of length L with model dimension D and SSM state dimension N:
| Operation | Transformer (Attention) | S4 (Convolution) | Mamba (Selective Scan) |
|---|---|---|---|
| Training FLOPs | O(L²D) | O(L log L · DN) | O(LDN) |
| Inference per token | O(LD) + KV cache | O(DN) | O(DN) |
| State size per sequence | O(LD) — growing KV cache | O(DN) — fixed | O(DN) — fixed |
| Content-aware? | Yes (Q/K/V from input) | No (fixed kernel) | Yes (B, C, Δ from input) |
Now plug in realistic values. With L=1024, D=512, and N=16:
- Attention: L² × D = 1024² × 512 ≈ 537 million multiply-adds
- Mamba: L × D × N = 1024 × 512 × 16 ≈ 8.4 million multiply-adds
- That's 64× fewer operations for Mamba.
Going from L=1024 to L=4096, attention costs 16× more (quadratic scaling) while Mamba costs only 4× more (linear scaling). The gap between them doubles with every doubling of sequence length. At L=16384, attention is 256× more expensive than at L=1024, while Mamba is just 16×.
The memory story is equally dramatic. Transformers need a KV cache that stores all previous keys and values — it grows linearly with context length and dominates GPU memory for long sequences. Mamba stores a fixed-size state h ∈ ℝ^(D×N): 512 × 16 = 8,192 values. That's it. Generate a million tokens, and the state is still 8,192 values.
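A back-of-envelope sketch makes the contrast vivid (the shapes are illustrative, loosely modeled on a 7B-class transformer; exact numbers vary by architecture):

```python
# KV cache vs SSM state in float16 bytes (illustrative shapes:
# a 32-layer model with D=4096, and an SSM state of N=16 per channel per layer).
layers, D, N, bytes_fp16 = 32, 4096, 16, 2

ssm_state = layers * D * N * bytes_fp16  # constant, independent of context length
print(f"SSM state: {ssm_state / 1e6:.1f} MB at ANY context length")

for L in [1_024, 32_768, 1_048_576]:
    kv_cache = layers * 2 * L * D * bytes_fp16  # keys + values, every layer
    print(f"L={L:>9,}: KV cache {kv_cache / 1e9:7.2f} GB "
          f"({kv_cache // ssm_state:,}x the SSM state)")
```

The SSM state stays at a few megabytes while the KV cache climbs into hundreds of gigabytes at million-token contexts.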
Real-world results from the Mamba paper: Mamba-3B matches Transformer-7B quality on language modeling benchmarks, with 5× higher inference throughput on long sequences. The linear scaling means the gap only grows with context length.
Try It: Attention vs SSM — The Complexity Race
Drag the slider to increase sequence length and watch the computational cost diverge. Left: attention's O(n²) matrix fills in quadratically. Right: SSM's O(n) state updates linearly.
Putting It Together: A Character-Level SSM
Let's tie everything together by building a minimal language model using the selective SSM. We'll embed characters, pass them through a Mamba-style SSM layer, and predict the next character. This is the same kind of model we built with RNNs in the earlier series post — but now powered by a selective state space.
# Minimal character-level SSM language model
text = "to be or not to be that is the question " * 5
chars = sorted(set(text))
V = len(chars) # vocab size
ch2ix = {c: i for i, c in enumerate(chars)}
data = np.array([ch2ix[c] for c in text])
# Model hyperparameters
D_emb, N_st = 16, 8
np.random.seed(42)
# Learnable parameters
W_emb = np.random.randn(V, D_emb) * 0.1 # character embeddings
A_log = -np.abs(np.random.randn(D_emb, N_st)) # state decay (fixed per layer)
W_B = np.random.randn(D_emb, N_st) * 0.05 # B projection
W_C = np.random.randn(D_emb, N_st) * 0.05 # C projection
W_dt = np.random.randn(D_emb, 1) * 0.05 # delta projection
b_dt = np.full(1, -3.0) # delta bias
W_out = np.random.randn(D_emb, V) * 0.1 # output head
def softmax(z):
e = np.exp(z - z.max())
return e / e.sum()
def forward_ssm(indices):
"""Forward pass: embed → selective SSM → predict next char."""
L = len(indices)
h = np.zeros((D_emb, N_st)) # SSM hidden state
total_loss = 0.0
for t in range(L - 1):
x = W_emb[indices[t]] # (D_emb,)
# Selective SSM step
B_t = x @ W_B # (N_st,)
C_t = x @ W_C # (N_st,)
dt = np.log1p(np.exp(x @ W_dt + b_dt)).mean() # softplus
A_bar = np.exp(dt * A_log) # (D_emb, N_st)
h = A_bar * h + np.outer(x, dt * B_t) # state update
y = (h * C_t).sum(axis=1) # readout (D_emb,)
# Next-character prediction
logits = y @ W_out # (V,)
probs = softmax(logits)
total_loss -= np.log(probs[indices[t + 1]] + 1e-8)
return total_loss / (L - 1)
loss = forward_ssm(data)
print(f"Initial loss: {loss:.3f} (random baseline: {np.log(V):.3f})")
print(f"Vocab: {V} chars | Embedding: {D_emb}d | State: {D_emb}x{N_st} = {D_emb*N_st}")
print(f"The hidden state h is FIXED SIZE — no KV cache, no growing memory.")
print(f"Generate a million tokens? Still just {D_emb*N_st} state values.")
The initial loss should be near the random baseline of log(vocab_size), confirming the model is working but untrained. In a full implementation with backpropagation and an Adam optimizer, this model would learn to predict characters with decreasing loss. The key point: the SSM's hidden state is a fixed (16×8 = 128)-value summary of the entire input history. Compare this to a transformer, which would need a KV cache that grows with every token generated.
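To see the constant-size state in action, here is a standalone generation sketch reusing the shapes above (parameters are freshly randomized here so the snippet runs on its own; the output is gibberish because nothing is trained — the point is the fixed-size state):

```python
import numpy as np

# O(1)-per-token generation with an untrained selective SSM.
np.random.seed(42)
chars = sorted(set("to be or not to be that is the question "))
V, D, N = len(chars), 16, 8
W_emb = np.random.randn(V, D) * 0.1
A_log = -np.abs(np.random.randn(D, N))
W_B = np.random.randn(D, N) * 0.05
W_C = np.random.randn(D, N) * 0.05
W_dt, b_dt = np.random.randn(D, 1) * 0.05, -3.0
W_out = np.random.randn(D, V) * 0.1

h = np.zeros((D, N))  # the ONLY thing carried between tokens
ix, out = 0, []
for _ in range(20):   # generate 20 characters
    x = W_emb[ix]
    dt = np.log1p(np.exp(x @ W_dt + b_dt)).mean()         # softplus -> step size
    h = np.exp(dt * A_log) * h + np.outer(x, dt * (x @ W_B))
    logits = (h * (x @ W_C)).sum(axis=1) @ W_out
    ix = int(np.argmax(logits))  # greedy decode; a trained model would sample
    out.append(chars[ix])

print("generated:", repr("".join(out)))
print(f"state during generation: {h.size} values, no matter how long we run")
```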
Connections to the Series
State Space Models sit at the intersection of nearly every topic we've covered. Here's how they connect:
- Recurrent Neural Networks from Scratch — SSMs are the spiritual successor to RNNs. Both maintain a hidden state updated per token. Mamba's selection mechanism is principled gating — the same idea as LSTM's forget gate, but derived from continuous-time dynamics rather than ad hoc design.
- Attention from Scratch — Attention computes pairwise token interactions explicitly at O(n²) cost. SSMs compress all history into a fixed-size state at O(n) cost. Two fundamentally different approaches to the same problem: modeling dependencies in sequences.
- KV Cache from Scratch — Transformers need a KV cache that grows with context length. SSMs need no cache — the state is the cache, and it's constant-size. This is why Mamba has 5× inference throughput on long sequences.
- Positional Encoding from Scratch — SSMs don't need explicit positional encoding. Position information is baked into the state dynamics through A and HiPPO — the continuous-time formulation naturally handles variable-length sequences.
- Normalization from Scratch — Mamba uses RMSNorm between layers, the same normalization technique that modern transformers use. Simple, effective, and slightly cheaper than LayerNorm.
- Feed-Forward Networks from Scratch — Mamba blocks alternate between selective SSM layers and gated MLPs, paralleling how transformer blocks alternate attention and FFN. The gated MLP structure (SiLU activation, linear gate) is shared between both architectures.
- Optimizers from Scratch — SSMs use the same Adam optimizer as transformers, but often train faster because there's no quadratic attention bottleneck to eat up FLOPs during training.
- Mixture of Experts from Scratch — Jamba (AI21, 2024) combines Mamba layers with attention layers and MoE routing — a hybrid architecture that gets the best of all worlds: linear-time SSM for most layers, targeted attention for tasks that need it, and sparse MoE for parameter efficiency.
- Speculative Decoding from Scratch — SSMs enable faster draft generation for speculative decoding because per-token state updates are cheaper than attention. A small SSM draft model paired with a large transformer verifier is a natural fit.
- Softmax & Temperature from Scratch — Mamba uses softplus for Δ parameterization — a smooth approximation to ReLU, ensuring the step size is always positive. Contrast this with softmax's use in attention, where the goal is a probability distribution over tokens.
- Contrastive Learning from Scratch — SSM encoders can replace transformer encoders in contrastive frameworks like SimCLR and CLIP, offering linear-time feature extraction over sequences. The encoder architecture is swappable — what matters is the contrastive objective.
- Micrograd from Scratch — The autograd engine that started this series. SSMs add a new computational pattern to differentiate through: the scan operation. Custom backward passes for scans (recomputing states rather than storing them) are one of Mamba's key engineering innovations.
References & Further Reading
- Gu et al. — "HiPPO: Recurrent Memory with Optimal Polynomial Projections" (NeurIPS 2020) — The foundation: how to initialize state matrices for optimal long-range memory.
- Gu et al. — "Efficiently Modeling Long Sequences with Structured State Spaces" (S4, ICLR 2022) — The breakthrough that made SSMs competitive, introducing the NPLR structure for efficient kernel computation.
- Gu & Dao — "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (2023) — The selective SSM that rivals transformers. Introduced input-dependent dynamics and the hardware-aware scan.
- Dao & Gu — "Transformers are SSMs: Generalized Models and Efficient Algorithms through Structured State Space Duality" (Mamba-2, 2024) — Shows the deep connection between attention and state spaces, unifying both under a common framework.
- Sasha Rush — "The Annotated S4" — A beautifully annotated implementation of S4 in JAX. The gold standard for understanding the original algorithm.
- Maarten Grootendorst — "A Visual Guide to Mamba and State Space Models" — Excellent visual explanations of the SSM pipeline and selection mechanism.
- Hazy Research Blog — S4 and SSM series — The research group behind HiPPO, S4, and Mamba. Deep dives into the math and engineering.
- Lieber et al. — "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21, 2024) — Combining Mamba layers with attention and MoE for a production-scale hybrid architecture.
The shift from quadratic attention to linear state spaces may turn out to be one of the most important architectural changes in deep learning's history. Or transformers may adapt and reclaim the throne (FlashAttention already narrowed the gap for moderate sequence lengths). Either way, understanding SSMs gives you a fundamentally different lens for thinking about sequence modeling — one rooted in the elegant mathematics of dynamical systems, not just pattern matching over token pairs.
The missing architecture isn't missing anymore.