Speculative Decoding from Scratch: How LLMs Generate Text 2-3x Faster
The Autoregressive Bottleneck
Your GPU is coasting at less than 1% utilization while generating text. Not a typo — less than one percent.
Across this series, we've built every piece of the transformer inference pipeline: tokenization, embeddings, positional encoding, attention, and a KV cache to avoid redundant computation. We've explored every decoding strategy — greedy, beam search, top-k, nucleus sampling. But there's a deeper bottleneck we haven't addressed.
Autoregressive generation is fundamentally sequential. Token t+1 depends on token t, which depends on token t-1, all the way back to the prompt. Each token requires a full forward pass through billions of parameters. You can't skip ahead. You can't parallelize. One token at a time, every time.
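To make that concrete, here is a minimal sketch of the autoregressive loop (the `forward` stub stands in for a real transformer; the token IDs are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 50_000

def forward(tokens):
    """Stand-in for a full transformer forward pass: in a real model this
    touches every parameter just to produce one next-token distribution."""
    logits = rng.normal(size=VOCAB)
    z = np.exp(logits - logits.max())
    return z / z.sum()

prompt = [101, 2054, 2003]  # arbitrary token IDs
generated = list(prompt)
for _ in range(5):
    probs = forward(generated)          # one full pass over ALL parameters...
    next_token = int(np.argmax(probs))  # ...yields exactly ONE token
    generated.append(next_token)        # and token t+1 depends on token t

print(len(generated) - len(prompt))  # 5 tokens cost 5 sequential passes
```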
Or can you?
Speculative decoding is an elegant trick that generates 2-3x more tokens per second with mathematically identical output. No approximations, no quality loss — the exact same probability distribution, just faster. The idea: use a small, fast "draft" model to propose multiple tokens, then verify them all at once with the big model. When the draft is right (and it usually is), you get multiple tokens for the cost of one.
Let's build it from scratch.
Why Decoding is Memory-Bandwidth Bound
To understand why speculative decoding works, we need to understand why standard decoding is so slow. The answer lies in a concept called arithmetic intensity — the ratio of computation (FLOPs) to memory traffic (bytes loaded).
When generating a single token, the model loads all its weights from GPU memory, performs a relatively tiny amount of math, and produces one output. For a 7B parameter model in FP16:
```python
model_params = 7e9     # 7 billion parameters
bytes_per_param = 2    # FP16 = 2 bytes each
model_size_bytes = model_params * bytes_per_param  # 14 GB

# DECODE: generate one token
# Load all weights, perform ~2 FLOPs per parameter
decode_flops = 2 * model_params       # ~14 GFLOPs
decode_bytes = model_size_bytes       # ~14 GB loaded from memory
decode_intensity = decode_flops / decode_bytes
print(f"Decode: {decode_intensity:.1f} FLOP/byte")

# PREFILL: process a 1000-token prompt
# Load weights ONCE, but do N times more work
N = 1000
prefill_flops = N * 2 * model_params  # ~14 TFLOPs
prefill_bytes = model_size_bytes      # ~14 GB (same weights, loaded once)
prefill_intensity = prefill_flops / prefill_bytes
print(f"Prefill: {prefill_intensity:.0f} FLOP/byte")

# A100 GPU specs
a100_tflops = 312e12     # 312 TFLOPS peak compute (in FLOP/s)
a100_bandwidth = 2e12    # 2 TB/s memory bandwidth (in bytes/s)

# Utilization = min(1, intensity * bandwidth / compute)
decode_util = (decode_intensity * a100_bandwidth) / a100_tflops
print(f"\nDecode GPU utilization: {decode_util:.1%}")
print(f"Prefill GPU utilization: ~100%")
```
Output:

```
Decode: 1.0 FLOP/byte
Prefill: 1000 FLOP/byte

Decode GPU utilization: 0.6%
Prefill GPU utilization: ~100%
```
During decode, the GPU performs just 1 floating-point operation for every byte it loads. An A100 can compute 312 trillion operations per second, but can only move 2 trillion bytes per second. At 1 FLOP/byte, you're bottlenecked by memory bandwidth — the GPU's massive compute array sits idle, waiting for data to trickle in from memory. Only 0.6% of the hardware is doing useful work.
Prefill is the opposite. Processing a 1000-token prompt loads the weights once but performs 1000x more math per byte — the GPU is fully saturated. This is why "time to first token" feels fast but generation feels slow.
Here's the key insight: verifying K draft tokens looks like prefill, not decode. You feed K tokens into the target model in a single forward pass — one weight load, K times the computation. If K=5, the verification has 5x the arithmetic intensity of a single decode step, at essentially the same wall-clock cost. Loading 14 GB of weights takes ~7ms no matter if you're processing 1 token or 5. The extra FLOPs for 5 tokens are negligible because the GPU has overwhelming compute capacity sitting idle.
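A back-of-envelope sketch of that claim, reusing the same A100-style numbers (a simple roofline view where step time is the larger of weight-load time and compute time; real kernels are messier, so treat this as illustrative):

```python
# 7B FP16 model on A100-like hardware (numbers from the section above)
model_bytes = 14e9          # weights to load each step
bandwidth = 2e12            # bytes/s
compute = 312e12            # FLOP/s
flops_per_token = 2 * 7e9   # ~2 FLOPs per parameter per token

def step_time(num_tokens):
    """Roofline-style step time: limited by whichever resource is slower."""
    load_time = model_bytes / bandwidth                 # ~7 ms, fixed
    math_time = num_tokens * flops_per_token / compute  # tiny, grows with K
    return max(load_time, math_time)

print(f"decode 1 token:  {step_time(1) * 1e3:.2f} ms")
print(f"verify 5 tokens: {step_time(5) * 1e3:.2f} ms")  # same wall-clock cost
```

Both steps are dominated by the fixed ~7 ms weight load; the extra FLOPs for verifying 5 tokens vanish into the GPU's idle compute capacity.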
The Core Idea: Draft and Verify
Think of it like a junior developer and a senior developer working together. The junior writes code quickly — maybe not perfect, but fast. Instead of the senior writing everything line by line, they review the junior's entire batch at once. If the code is good, it ships instantly. If the senior spots a mistake, they fix just that one line and move on.
In speculative decoding, the "junior developer" is a small draft model (say, 160M parameters) and the "senior developer" is your large target model (say, 7B parameters). They share the same vocabulary. Here's the loop:
Step 1 — Draft: The small model generates K candidate tokens autoregressively. This is fast because the model is tiny — a 160M model runs roughly 40x faster than a 7B model per token.
Step 2 — Verify: Feed all K draft tokens into the large model in one forward pass. Because the model processes them in parallel (like prefill), this takes about the same time as generating a single token. The target model outputs probability distributions at all K+1 positions.
Step 3 — Accept or reject: For each draft token, compare what the draft model thought (probability q) to what the target model thinks (probability p). Accept good drafts, reject at the first disagreement, and sample a correction. If all K are accepted, grab a bonus (K+1)th token from the target's final distribution.
The result? Between 1 and K+1 tokens per target model call, instead of exactly 1. In pseudocode:
```python
def speculative_decode_step(prefix, draft_model, target_model, K):
    """One step: returns 1 to K+1 new tokens."""
    # ── Draft: K tokens from small model (cheap) ──
    draft_tokens, draft_probs = [], []
    context = list(prefix)
    for _ in range(K):
        q = draft_model.predict(context)  # fast!
        token = sample_from(q)
        draft_tokens.append(token)
        draft_probs.append(q)
        context.append(token)

    # ── Verify: score all K positions in ONE target pass ──
    # This is the magic — parallel, like prefill
    target_dists = target_model.verify(prefix, draft_tokens)
    # target_dists[i] gives the target's distribution at position i

    # ── Accept/reject via rejection sampling ──
    accepted = []
    for i in range(K):
        if accept(target_dists[i], draft_probs[i], draft_tokens[i]):
            accepted.append(draft_tokens[i])
        else:
            correction = sample_correction(target_dists[i], draft_probs[i])
            accepted.append(correction)
            return accepted  # stop at first rejection

    # All K accepted! Bonus token from position K+1
    bonus = sample_from(target_dists[K])
    accepted.append(bonus)
    return accepted  # K+1 tokens!
```
When the draft model closely matches the target — and for many tokens, it does — we get K+1 tokens for the cost of one target model call plus K very cheap draft calls. That's the source of the 2-3x speedup.
The Rejection Sampling Math
Here's the mathematical heart of speculative decoding, and it's beautiful. We drafted a token from distribution q(x) (the draft model), but we want the output to follow distribution p(x) (the target model) exactly. Not approximately — exactly. How?
The Acceptance Rule
For a draft token x, we accept it with probability:

`α(x) = min(1, p(x) / q(x))`

The intuition is simple:

- When `q(x) ≤ p(x)` — the draft underestimated this token's probability. The target likes it even more than the draft did. Always accept: α = 1.
- When `q(x) > p(x)` — the draft overestimated this token. The target thinks it's less likely. Accept with probability `p(x)/q(x) < 1`, proportional to how much the target agrees.
The Correction Distribution
When we reject, we don't just give up — we sample a correction token from the residual distribution:

`p_residual(x) = max(0, p(x) − q(x)) / Σ_y max(0, p(y) − q(y))`

This residual captures exactly the "gap" between what the target wants and what the draft provided. It contains probability mass only on tokens that the target model likes more than the draft — the tokens the draft was underweighting.
The Losslessness Proof
Here's the punchline. For any token x in the vocabulary, the total probability of outputting x is:

`P(output = x) = q(x) · min(1, p(x)/q(x)) + P(reject) · p_residual(x)`

The first term is the probability of drafting x and accepting it: `q(x) · min(1, p(x)/q(x)) = min(q(x), p(x))`. The second term is the probability of rejecting some draft token and then sampling x as the correction; the overall rejection probability exactly cancels the residual's normalizer, leaving `max(0, p(x) − q(x))`. Together they reconstruct p(x), because for any non-negative numbers a and b:

`min(a, b) + max(0, a − b) = a`
That's it. No approximation, no error bound, no "close enough." The output distribution is identical to running the target model alone. Let's verify this with concrete numbers:
```python
import numpy as np

def rejection_sample(p, q, draft_token):
    """
    Speculative decoding rejection sampling for one position.

    p: target distribution (array over vocabulary)
    q: draft distribution (array over vocabulary)
    draft_token: index sampled from q

    Returns: (accepted: bool, output_token: int)
    """
    # Accept with probability min(1, p/q) for the drafted token
    accept_prob = min(1.0, p[draft_token] / q[draft_token])
    if np.random.random() < accept_prob:
        return True, draft_token
    # Rejected — sample from residual distribution
    residual = np.maximum(0, p - q)
    residual = residual / residual.sum()
    correction = np.random.choice(len(p), p=residual)
    return False, correction
```
Let's walk through a concrete example with a tiny 4-token vocabulary:
```python
# Vocabulary: ["the", "cat", "sat", "dog"]
p = np.array([0.50, 0.20, 0.10, 0.20])  # target model
q = np.array([0.40, 0.30, 0.20, 0.10])  # draft model

# Draft model sampled "cat" (index 1) with q("cat") = 0.30
draft_token = 1

# Accept probability: min(1, p("cat")/q("cat")) = min(1, 0.20/0.30)
accept_prob = min(1.0, p[draft_token] / q[draft_token])
print(f"Accept probability for 'cat': {accept_prob:.3f}")  # 0.667

# If rejected, sample from residual:
residual = np.maximum(0, p - q)
print(f"Raw residual: {residual}")
# [max(0, 0.50-0.40), max(0, 0.20-0.30), max(0, 0.10-0.20), max(0, 0.20-0.10)]
# = [0.10, 0.00, 0.00, 0.10]

residual_norm = residual / residual.sum()
print(f"Correction distribution: {residual_norm}")
# = [0.50, 0.00, 0.00, 0.50]
# Correction chooses "the" or "dog" — tokens the draft underweighted

# Verify losslessness for each token:
for i, word in enumerate(["the", "cat", "sat", "dog"]):
    total = min(p[i], q[i]) + max(0, p[i] - q[i])
    print(f"P('{word}') = min({p[i]:.2f}, {q[i]:.2f}) + "
          f"max(0, {p[i]:.2f}-{q[i]:.2f}) = {total:.2f} == p = {p[i]:.2f} ✓")
```
Every token's output probability exactly equals the target distribution. The draft model's opinion has been completely washed out by the accept/reject step — it only affects speed, never quality.
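Beyond the per-token algebra, a quick Monte Carlo sketch can confirm the whole procedure empirically: draft from q, accept or reject, and compare the output histogram against p (seeded for reproducibility):

```python
import numpy as np

rng = np.random.default_rng(42)
p = np.array([0.50, 0.20, 0.10, 0.20])  # target model
q = np.array([0.40, 0.30, 0.20, 0.10])  # draft model

def spec_sample():
    """One draft-accept/reject round; returns the emitted token index."""
    draft = rng.choice(len(q), p=q)
    if rng.random() < min(1.0, p[draft] / q[draft]):
        return draft                                        # accepted
    residual = np.maximum(0, p - q)
    return rng.choice(len(p), p=residual / residual.sum())  # correction

counts = np.bincount([spec_sample() for _ in range(200_000)], minlength=4)
empirical = counts / counts.sum()
print(np.round(empirical, 3))  # close to p = [0.5, 0.2, 0.1, 0.2]
```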
Building the Full Algorithm
Now let's assemble the complete speculative decoding loop. This is the real implementation — handling the full K-token draft, sequential verification, and the bonus token:
```python
import numpy as np

def speculative_decode(prompt_tokens, draft_model, target_model,
                       K=5, max_new_tokens=100):
    """
    Complete speculative decoding loop.

    draft_model.predict(context) -> probability distribution over vocab
    target_model.verify(prefix, draft_tokens) -> list of (K+1) distributions
    """
    generated = list(prompt_tokens)

    while len(generated) - len(prompt_tokens) < max_new_tokens:
        start = len(generated)

        # ── Phase 1: Draft K tokens autoregressively (cheap) ──
        draft_tokens = []
        draft_dists = []
        ctx = list(generated)
        for _ in range(K):
            q = draft_model.predict(ctx)
            token = sample(q)
            draft_tokens.append(token)
            draft_dists.append(q)
            ctx.append(token)

        # ── Phase 2: Verify all K in ONE target model pass ──
        target_dists = target_model.verify(generated, draft_tokens)
        # target_dists[i] = target's distribution at position (start + i)
        # We get K+1 distributions: positions 0..K

        # ── Phase 3: Accept/reject with rejection sampling ──
        n_accepted = 0
        for i in range(K):
            p = target_dists[i]
            q = draft_dists[i]
            token = draft_tokens[i]
            accept_prob = min(1.0, p[token] / q[token])
            if np.random.random() < accept_prob:
                generated.append(token)  # accept draft
                n_accepted += 1
            else:
                # Reject — sample correction from residual
                residual = np.maximum(0, p - q)
                residual = residual / residual.sum()
                correction = np.random.choice(len(p), p=residual)
                generated.append(correction)  # append correction
                break  # stop verifying

        # ── Bonus: if ALL K accepted, free extra token ──
        if n_accepted == K:
            bonus = sample(target_dists[K])
            generated.append(bonus)  # K+1 tokens this step!

        # Roll back draft model's KV cache to match accepted length
        draft_model.rollback_cache(start + n_accepted + 1)

    return generated
```
A few critical details to note:
- The bonus token is free. When all K drafts are accepted, the target model's forward pass already computed the distribution at position K+1. We sample from it — no extra cost.
- KV cache management matters. Both the draft and target models maintain KV caches. On rejection, the draft model's cache needs to be rolled back to match only the accepted tokens. The target model's cache grows by however many tokens we accepted.
- Temperature and sampling compose. Speculative decoding works with any sampling strategy — temperature, top-k, nucleus sampling. Just apply the same transformations to both `p` and `q` before rejection sampling. Higher temperature actually increases acceptance rate, because flatter distributions are easier for the draft model to match.
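Here's a small sketch of that last point (logits invented for illustration). Applying the same temperature to both distributions, the expected acceptance rate is the total overlap Σ min(p, q), and it grows as temperature flattens both distributions:

```python
import numpy as np

def apply_temperature(logits, T):
    """Softmax with temperature T, numerically stabilized."""
    z = logits / T
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

target_logits = np.array([2.0, 1.0, 0.5, 0.1])  # made-up target logits
draft_logits = np.array([1.8, 1.2, 0.4, 0.2])   # made-up draft logits

for T in (0.7, 1.0, 1.5):
    p = apply_temperature(target_logits, T)
    q = apply_temperature(draft_logits, T)
    alpha = np.minimum(p, q).sum()  # expected acceptance = overlap of p and q
    print(f"T={T}: expected acceptance rate {alpha:.3f}")
```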
Speedup Analysis: When Does It Help?
How fast is speculative decoding? The expected number of tokens per step has an elegant closed form. If each draft token is accepted independently with probability α, then:

`E[tokens per step] = 1 + α + α² + ... + α^K = (1 − α^(K+1)) / (1 − α)`

This is a geometric series, and the intuition is beautiful: you always get at least 1 token (either a correction or the first accepted draft). You get a second token with probability α (first draft accepted). A third with probability α² (first two accepted). And so on, up to K+1 tokens with probability α^K (all K accepted, plus the bonus).
The wall-clock speedup depends on how cheap the draft model is. If the draft model takes a fraction c of the target model's latency per token, one speculative step costs K·c + 1 time units (K draft calls plus 1 target verify). The speedup is:

`speedup = E[tokens per step] / (K·c + 1)`
```python
def expected_speedup(alpha, K, c=0.1):
    """
    alpha: acceptance rate (0 to 1)
    K: speculation length (number of draft tokens)
    c: draft_latency / target_latency (typically 0.05-0.2)
    """
    if abs(alpha - 1.0) < 1e-10:
        expected_tokens = K + 1
    else:
        expected_tokens = (1 - alpha ** (K + 1)) / (1 - alpha)
    cost = K * c + 1  # K draft calls + 1 target verify
    return expected_tokens / cost

# Sweep parameters
print(f"{'α':>6} {'K=2':>8} {'K=3':>8} {'K=5':>8} {'K=8':>8}")
print("-" * 40)
for alpha in [0.5, 0.7, 0.8, 0.9, 0.95]:
    row = f"{alpha:>6.2f}"
    for K in [2, 3, 5, 8]:
        s = expected_speedup(alpha, K, c=0.1)
        row += f" {s:>7.2f}x"
    print(row)
```
With c = 0.1 (draft model 10x faster than target — typical for a 160M draft with a 7B target):
| α \ K | K = 2 | K = 3 | K = 5 | K = 8 |
|---|---|---|---|---|
| 0.50 | 1.46x | 1.44x | 1.31x | 1.11x |
| 0.70 | 1.83x | 1.95x | 1.96x | 1.78x |
| 0.80 | 2.03x | 2.27x | 2.46x | 2.41x |
| 0.90 | 2.26x | 2.65x | 3.12x | 3.40x |
| 0.95 | 2.38x | 2.85x | 3.53x | 4.11x |
Several patterns jump out:
- Acceptance rate is king. Going from α=0.7 to α=0.9 nearly doubles the speedup at K=5 (1.96x → 3.12x). A well-matched draft model matters more than tuning K.
- There's an optimal K. At low acceptance rates (α=0.5), K=2 is best — longer speculation wastes draft compute on tokens that get rejected. At high acceptance rates (α=0.9+), K=5-8 continues to pay off.
- Diminishing returns. Going from K=5 to K=8 barely helps at α=0.8 (2.46x → 2.41x — it actually gets slightly worse because the extra draft calls aren't worth it).
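Those trade-offs imply an optimal K for each acceptance rate. A quick sketch, reusing the same cost model with c = 0.1:

```python
def speedup(alpha, K, c=0.1):
    """Modeled speedup: expected tokens per step over step cost."""
    expected_tokens = (1 - alpha ** (K + 1)) / (1 - alpha)
    return expected_tokens / (K * c + 1)

for alpha in (0.5, 0.7, 0.8, 0.9):
    best_K = max(range(1, 17), key=lambda K: speedup(alpha, K))
    print(f"alpha={alpha}: best K = {best_K}, {speedup(alpha, best_K):.2f}x")
```

At α = 0.5 the search lands on K = 2, matching the table above; as α rises, the optimal K drifts upward.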
When speculative decoding shines: single-request inference, interactive chat, long-form generation — anywhere the model generates one token at a time with low batch size. When it doesn't help: high-batch-size serving, where multiple requests share the weight load and decode becomes more compute-bound.
Beyond Simple Draft-Verify
The basic draft-verify scheme is just the beginning. Researchers have found increasingly clever ways to speculate:
Tree-Structured Speculation (SpecInfer)
Instead of drafting one linear sequence of K tokens, generate a tree of candidates. At each position, propose the top-2 or top-3 most likely tokens, creating a branching tree of possible continuations. Verify the entire tree in one forward pass using modified attention masks. Even if one branch fails, an alternative branch may succeed — SpecInfer reports boosting per-step acceptance from 52% to 97% by exploring multiple paths simultaneously.
Medusa: Draft Heads Instead of Draft Models
Why use a separate draft model at all? Medusa adds K small MLP "heads" on top of the target model's last hidden layer. Each head predicts the token at position t+k directly — no autoregressive drafting needed. All heads predict in parallel during a single forward pass. Medusa-1 (frozen backbone, train only heads) achieves ~2.2x speedup. Medusa-2 adds LoRA fine-tuning to the backbone for 2.3-3.6x.
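A toy numpy sketch of the shape of the Medusa idea (random weights and invented sizes; real heads are trained MLPs with tree attention for verification):

```python
import numpy as np

rng = np.random.default_rng(0)
HIDDEN, VOCAB, K = 64, 100, 3

# The target model's last hidden state at the current position
hidden = rng.normal(size=HIDDEN)

# K untrained stand-in "heads": in Medusa these are small trained MLPs
heads = [rng.normal(size=(HIDDEN, VOCAB)) * 0.1 for _ in range(K)]

# All heads fire in parallel off the SAME state: no autoregressive drafting
draft = [int(np.argmax(hidden @ W)) for W in heads]
print(draft)  # K candidate tokens from one forward pass
```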
EAGLE: Feature-Level Autoregression
The current state of the art. EAGLE's key insight: autoregression at the feature level (the model's second-to-top hidden layer) is much easier than at the token level. Instead of predicting discrete tokens, EAGLE predicts continuous feature vectors, which are then used to estimate token probabilities. EAGLE-3 achieves 3.0-6.5x speedup with dynamic tree structures that adapt based on draft confidence.
Self-Speculative Decoding (LayerSkip)
Use the target model as its own draft by exiting early — skip the last N transformer layers, use the intermediate hidden state as a rough prediction. No separate model needed, no extra memory for a second KV cache. Achieves up to 2.16x speedup with zero additional parameters.
| Method | Extra Params | Speedup | Key Idea |
|---|---|---|---|
| Standard Draft | Small model | 2-3x | Separate small model drafts |
| Medusa | K MLP heads | 2-3.6x | Parallel heads on target |
| EAGLE-3 | Feature predictor | 3-6.5x | Feature-level autoregression |
| LayerSkip | None | 1.5-2.2x | Early exit as draft |
| N-gram | None | 1.5-2x | Match from recent context |
Speculative Decoding in Production
Speculative decoding has moved from research papers to production serving frameworks:
- vLLM — supports EAGLE-1/3, Medusa, and standard draft models, configurable via `--speculative-model` and `--num-speculative-tokens`.
- TGI (Hugging Face) — Medusa heads and n-gram speculation. The n-gram method is zero-model: it searches the existing context for matching prefixes and uses them as drafts. Excellent for repetitive content like code generation.
- Apple MLX — ReDrafter for Apple Silicon, achieving up to 2.3x speedup on Metal GPUs.
- SGLang & TensorRT-LLM — built-in support for multiple speculation methods with optimized CUDA kernels.
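The n-gram method is simple enough to sketch in a few lines. This is an illustration of the idea rather than TGI's actual implementation (token IDs are arbitrary):

```python
def ngram_draft(context, n=2, K=4):
    """Propose up to K draft tokens by matching the last n tokens
    against an earlier occurrence in the context."""
    key = tuple(context[-n:])
    # Scan backwards for the most recent earlier occurrence of the suffix
    for i in range(len(context) - n - 1, -1, -1):
        if tuple(context[i:i + n]) == key:
            return context[i + n:i + n + K]  # propose what followed it
    return []  # no match: fall back to ordinary decoding

ctx = [7, 3, 9, 4, 5, 6, 1, 3, 9]  # "... 3 9 ..." appeared earlier
print(ngram_draft(ctx, n=2, K=3))  # proposes [4, 5, 6]
```

The proposed tokens still go through the target model's verification, so a bad match costs only a rejection.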
A few practical considerations from production deployments:
- Speculative decoding is most valuable for low-batch, interactive use cases — chatbots, coding assistants, single-user inference. At high batch sizes with many concurrent requests, the memory-bandwidth bottleneck is already amortized across the batch.
- Higher temperature actually increases acceptance rates. A flatter target distribution is easier for the draft to match. Creative writing benefits more than factual Q&A.
- Memory overhead: you need KV cache storage for both the draft and target models (unless using self-speculative methods like LayerSkip). For long-context generation, this can be significant — the same memory tradeoffs we explored in the KV cache post become even more critical when you're maintaining two caches.
- Draft models can be aggressively quantized — they only need to approximate the target, so 4-bit or even 2-bit quantization works well for drafts.
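To get a feel for the last point, here is a naive per-tensor absmax 4-bit round trip on a random matrix (illustrative only; production schemes use per-group scales and smarter rounding):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(256, 256)).astype(np.float32)  # stand-in weight matrix

scale = np.abs(W).max() / 7                 # symmetric int4 range: -7..7
W_q = np.clip(np.round(W / scale), -7, 7).astype(np.int8)
W_hat = W_q.astype(np.float32) * scale      # dequantize

rel_err = np.linalg.norm(W - W_hat) / np.linalg.norm(W)
print(f"relative reconstruction error: {rel_err:.3f}")
# A draft can absorb this error: a noisier draft just lowers the acceptance
# rate, while the target model's verification keeps the output exact.
```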
See It in Action
The demo below simulates speculative decoding with a toy vocabulary. Each step uses one target model call. Watch how the draft-verify cycle produces multiple tokens per call, and how draft quality affects the acceptance rate and speedup.
Try It: Speculative Decoding Step by Step
The Inference Optimization Stack
Speculative decoding exploits a fundamental asymmetry: generating one token is memory-bandwidth bound (the GPU barely works), while verifying multiple tokens amortizes the same weight load across K positions (the GPU finally does useful work per byte). By drafting cheap guesses and verifying them in batch, we pay for roughly one forward pass and get multiple tokens. And the rejection sampling math guarantees the output is identical to the target model alone — lossless acceleration.
Together with the techniques we've built across this series, we now have a complete inference optimization stack:
- KV Cache — avoid recomputing attention over previous tokens
- Speculative Decoding — avoid wasting memory bandwidth on single-token passes
- Quantization — reduce the memory footprint of the model itself
These three techniques are complementary and multiplicative. KV cache cuts redundant computation, speculative decoding cuts wasted bandwidth, and quantization cuts memory size. Stack all three and a 7B model that once required a data center GPU runs on a laptop.
The full pipeline we've built across this series: tokenize → embed → position → normalize → attend (with KV cache) → FFN → softmax → loss → optimize → decode (with speculative decoding) → fine-tune → quantize. From raw text to optimized inference — every piece built from scratch.
References & Further Reading
- Leviathan et al. — "Fast Inference from Transformers via Speculative Decoding" (ICML 2023) — the original speculative decoding paper, introducing the rejection sampling framework
- Chen et al. — "Accelerating Large Language Model Decoding with Speculative Sampling" (DeepMind, 2023) — independent concurrent work with the same core insight, includes the losslessness proof
- Miao et al. — "SpecInfer: Accelerating LLM Serving with Tree-based Speculative Inference" (ASPLOS 2024) — tree-structured speculation with dramatic acceptance rate improvements
- Cai et al. — "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" (ICML 2024) — no draft model needed, heads predict future tokens in parallel
- Li et al. — "EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty" (ICML 2024) — feature-level autoregression for higher acceptance rates
- Li et al. — "EAGLE-3: Scaling Up Speculative Decoding" (2024) — current state of the art, 3-6.5x speedup
- Elhoushi et al. — "LayerSkip: Enabling Early-Exit Inference and Self-Speculative Decoding" (ACL 2024) — the model as its own draft via early exit
- DadOps — KV Cache from Scratch — prefill vs decode, the memory-bandwidth bottleneck
- DadOps — Decoding Strategies from Scratch — temperature, top-k, nucleus sampling
- DadOps — Quantization from Scratch — shrinking model size, enabling faster weight loads