KV Cache from Scratch: Why LLMs Don't Recompute Everything
The Hidden Cost of Token-by-Token
Your 7B parameter model takes 14 GB of VRAM just to hold its weights. Generate a long enough response, and a mysterious cache eats another 14 GB. What cache? Where did it come from? And why is it so hungry?
If you've followed along with our attention from scratch post, you know how attention works: project your input into queries, keys, and values, compute dot products, apply softmax, and produce a weighted sum. Clean and elegant. But there's a detail that most tutorials gloss over.
LLMs generate text one token at a time. To produce token 500, the model needs to attend over all 499 previous tokens. To produce token 501, it attends over all 500. And so on, one more token every step. See the problem? A naive implementation recomputes the key and value projections for every previous token at every step. That's a staggering amount of wasted work.
The KV cache is the fix. It's conceptually simple — cache the stuff that doesn't change — but it's also the single most important optimization in LLM inference. Without it, the chatbots we use every day would be impractically slow. With it, generation speed becomes bounded by memory bandwidth rather than raw compute.
Let's build it from scratch and watch it work.
The Naive Approach: Recomputing Everything
First, let's set up a minimal single-head attention layer. We'll use small dimensions so we can trace every computation by hand:
```python
import numpy as np

np.random.seed(42)

# Tiny model: 1 layer, 1 head, embed_dim=8, head_dim=8
d_model = 8
d_head = 8

# Random projection matrices (normally these are learned)
W_q = np.random.randn(d_model, d_head) * 0.1
W_k = np.random.randn(d_model, d_head) * 0.1
W_v = np.random.randn(d_model, d_head) * 0.1

def attention(Q, K, V):
    """Standard scaled dot-product attention with a causal mask."""
    scores = Q @ K.T / np.sqrt(d_head)
    # Causal mask: each position can only attend to itself and earlier positions
    seq_len = scores.shape[0]
    mask = np.triu(np.ones((seq_len, seq_len)), k=1) * -1e9
    scores = scores + mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V
```
Now here's the naive generation loop. At every step, we recompute Q, K, and V for the entire sequence — including all the tokens we already processed:
```python
def naive_generate(embeddings, num_new_tokens):
    """Generate tokens by recomputing full attention each step."""
    seq = list(embeddings)  # start with prompt embeddings
    projections = 0         # running count of per-token Q/K/V projections
    for step in range(num_new_tokens):
        x = np.array(seq)   # shape: (current_len, d_model)
        current_len = len(seq)
        # Project ALL tokens — even the ones we projected last step
        Q = x @ W_q  # (current_len, d_head)
        K = x @ W_k  # (current_len, d_head)
        V = x @ W_v  # (current_len, d_head)
        projections += 3 * current_len  # 3 projections × current_len tokens
        out = attention(Q, K, V)        # (current_len, d_head)
        # Use the last position's output as the "next token" embedding
        next_token = out[-1]
        seq.append(next_token)
    return np.array(seq), projections
```
Look at what happens at step 100: we project all 100 tokens to get K and V, even though the K and V for tokens 0 through 98 are identical to what we computed at step 99. At step 200, we recompute all 200. The projection cost alone grows as 1 + 2 + 3 + ... + n = n(n+1)/2. That's O(n²) total work for n tokens — and we haven't even counted the attention matrix yet.
At every generation step, the keys and values for all previous tokens are exactly the same as last step. We're recomputing them for nothing.
The Key Insight: Cache What Doesn't Change
Here's the crucial observation. When the model processes token at position t, its key and value projections depend only on that token's embedding:
K_t = embedding_t × W_K
V_t = embedding_t × W_V

Nothing about token 500's key depends on token 501. Once computed, K_t and V_t are immutable. So instead of recomputing them every step, we cache them and reuse them forever.
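You can check this immutability directly. The sketch below is self-contained (it uses its own toy `W_k` rather than the matrices above): it projects keys for a 4-token prefix and again after a 5th token arrives, and confirms the shared rows are identical.

```python
import numpy as np

np.random.seed(0)
d_model, d_head = 8, 8
W_k = np.random.randn(d_model, d_head) * 0.1  # stand-in key projection

seq = np.random.randn(5, d_model)  # 5 token embeddings
K_four = seq[:4] @ W_k             # keys computed when 4 tokens exist
K_five = seq @ W_k                 # keys computed after a 5th token arrives

# The first 4 rows match exactly: adding a token never changes old keys
assert np.allclose(K_four, K_five[:4])
```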
The KV-cached generation loop only projects the new token at each step, then appends its K and V to a growing cache:
```python
def cached_generate(embeddings, num_new_tokens):
    """Generate tokens with KV caching — only project the new token."""
    # Prefill: process the entire prompt at once
    prompt = np.array(embeddings)
    K_cache = prompt @ W_k  # (prompt_len, d_head)
    V_cache = prompt @ W_v  # (prompt_len, d_head)
    Q_all = prompt @ W_q
    projections = 3 * len(embeddings)  # initial projections
    out = attention(Q_all, K_cache, V_cache)
    next_emb = out[-1]
    seq = list(embeddings) + [next_emb]
    for step in range(num_new_tokens - 1):
        x_new = next_emb.reshape(1, -1)  # (1, d_model)
        # Project ONLY the new token — one vector, not the whole sequence
        q_new = x_new @ W_q  # (1, d_head)
        k_new = x_new @ W_k  # (1, d_head)
        v_new = x_new @ W_v  # (1, d_head)
        projections += 3     # 3 projections × 1 token
        # Append to cache
        K_cache = np.vstack([K_cache, k_new])  # grows by 1 row
        V_cache = np.vstack([V_cache, v_new])
        # Attention: new query against ALL cached keys and values
        scores = q_new @ K_cache.T / np.sqrt(d_head)  # (1, cache_len)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        next_emb = (weights @ V_cache).flatten()
        seq.append(next_emb)
    return np.array(seq), projections
```
The difference is dramatic. Let's count projections for generating 100 tokens after a 10-token prompt:
- Naive: Projects 10 + 11 + 12 + ... + 109 = 5,950 token projections
- Cached: Projects 10 (prefill) + 99 × 1 = 109 token projections
That's a 54× reduction in projection work. And the gap widens as sequences grow longer: for 1000 new tokens it's 509,500 vs 1,009 — a 505× difference.
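Those counts are easy to verify with a few lines of arithmetic (a quick sketch; the function names are mine):

```python
def naive_projections(prompt_len, new_tokens):
    # step s projects all (prompt_len + s) tokens
    return sum(prompt_len + s for s in range(new_tokens))

def cached_projections(prompt_len, new_tokens):
    # prefill projects the prompt once; each later step projects 1 token
    return prompt_len + (new_tokens - 1)

print(naive_projections(10, 100))    # 5950
print(cached_projections(10, 100))   # 109
print(naive_projections(10, 1000))   # 509500
print(cached_projections(10, 1000))  # 1009
```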
Both approaches produce identical output. The KV cache doesn't approximate anything — it's mathematically exact. We're just not computing things we already know.
Prefill vs Decode: Two Phases of Inference
You might have noticed our cached version does something different with the prompt than with generated tokens. This reflects how real LLM inference works — it naturally splits into two phases:
Prefill processes the entire prompt in a single parallel pass. All prompt tokens are known upfront, so we can compute their Q, K, and V projections in one matrix multiply and fill the cache in bulk. This phase is compute-bound — it's essentially the same operation as training.
Decode generates tokens one at a time. Each step computes Q, K, V for just the new token, appends K and V to the cache, then does attention against the full cache. This phase is memory-bandwidth-bound — the arithmetic is small (one query against the cache), but loading the entire cache from GPU memory for each token is expensive.
```python
def prefill(prompt_embeddings):
    """Process the full prompt in parallel, return the KV cache."""
    x = np.array(prompt_embeddings)
    K_cache = x @ W_k  # (prompt_len, d_head)
    V_cache = x @ W_v  # (prompt_len, d_head)
    Q = x @ W_q
    out = attention(Q, K_cache, V_cache)
    return K_cache, V_cache, out[-1]

def decode_step(new_embedding, K_cache, V_cache):
    """Generate one token using the KV cache."""
    x = new_embedding.reshape(1, -1)
    q = x @ W_q
    k = x @ W_k
    v = x @ W_v
    K_cache = np.vstack([K_cache, k])
    V_cache = np.vstack([V_cache, v])
    scores = q @ K_cache.T / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    output = (weights @ V_cache).flatten()
    return output, K_cache, V_cache

# Usage: clean separation of phases
prompt = [np.random.randn(d_model) for _ in range(10)]
K, V, last_out = prefill(prompt)
for _ in range(50):
    last_out, K, V = decode_step(last_out, K, V)
```
This two-phase split has a profound consequence: during decode, the bottleneck shifts from "how fast can we multiply matrices?" to "how fast can we read the cache from memory?" The GPU's arithmetic units sit partially idle while waiting for the KV cache to stream from HBM (high-bandwidth memory) into the compute cores. This is why tricks to shrink the cache matter so much — less cache means less data to move.
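To get a feel for the numbers, here's a back-of-envelope lower bound on decode latency from cache traffic alone. The bandwidth figure is an assumption (roughly A100-class HBM), and real kernels must also read the model weights, so this is only a floor:

```python
# KV cache bytes for LLaMA 2 7B at 4K context, FP16:
# 2 (K and V) × layers × heads × head_dim × seq_len × 2 bytes
cache_bytes = 2 * 32 * 32 * 128 * 4096 * 2
hbm_bytes_per_s = 2e12  # ~2 TB/s, assumed accelerator bandwidth

# Every decode step must stream the whole cache through the compute units
min_seconds_per_token = cache_bytes / hbm_bytes_per_s
print(f"cache: {cache_bytes / 1e9:.1f} GB; floor: {min_seconds_per_token * 1e3:.2f} ms/token")
# prints: cache: 2.1 GB; floor: 1.07 ms/token
```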
The Memory Problem: When the Cache Outgrows the Model
The KV cache saves enormous compute, but it pays for that savings with memory. The formula for total cache size is:

cache bytes = 2 × num_layers × num_kv_heads × head_dim × seq_len × batch_size × bytes_per_value

The factor of 2 accounts for both K and V. Let's plug in real numbers for LLaMA 2 7B (32 layers, 32 heads, head_dim = 128, FP16):

2 × 32 × 32 × 128 × 2 bytes = 524,288 bytes ≈ 0.5 MB per token

Half a megabyte per token. That adds up fast:
| Sequence Length | Batch 1 | Batch 8 | Batch 32 |
|---|---|---|---|
| 512 tokens | 256 MB | 2 GB | 8 GB |
| 2,048 tokens | 1 GB | 8 GB | 32 GB |
| 4,096 tokens | 2 GB | 16 GB | 64 GB |
| 28,672 tokens | 14 GB | 112 GB | 448 GB |
The last row is where the KV cache equals the model's own weight footprint. At roughly 28K tokens, the cache and the weights each take 14 GB — double the model's footprint in total. Modern models with 128K+ context windows make this problem much worse.
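The table above can be reproduced with a small helper (a sketch; the function and parameter names are mine):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    """Total KV cache size: the factor of 2 covers both K and V."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

# LLaMA 2 7B (32 layers, 32 KV heads, head_dim 128, FP16)
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
print(per_token)  # 524288 bytes, i.e. 0.5 MB per token

gb = kv_cache_bytes(32, 32, 128, seq_len=28_672) / 2**30
print(f"{gb:.0f} GB")  # 14 GB, equal to the FP16 weights
```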
This is the core tension in LLM inference: the KV cache eliminates redundant compute, but it creates a memory wall. Every token you generate makes the cache bigger, and you can never throw it away (because the next token might attend to any previous one). Something has to give.
Shrinking the Cache: Multi-Query and Grouped-Query Attention
In standard Multi-Head Attention (MHA), each of the h attention heads has its own Q, K, and V projections. That means the KV cache stores h separate K vectors and h separate V vectors per token, per layer. For a 32-head model, that's 32 sets of keys and 32 sets of values.
Noam Shazeer asked a provocative question in 2019: do all those heads really need their own keys and values?
Multi-Query Attention (MQA)
MQA takes the extreme position: all query heads share a single set of keys and a single set of values. Each head still has its own Q projection (so they attend to different things), but they all read from the same K and V. The cache shrinks by a factor of h — for a 32-head model, that's 32× smaller.
The downside? Quality takes a hit. With every head forced to share keys and values, the model loses some of its ability to attend to different aspects of the input simultaneously.
Grouped-Query Attention (GQA)
GQA, introduced by Ainslie et al. in 2023, finds the sweet spot. Instead of one KV set (MQA) or h KV sets (MHA), it uses G groups where G is between 1 and h. Each group of query heads shares one set of keys and values:
```python
def grouped_query_attention_with_cache(x_new, K_cache, V_cache,
                                       W_q_heads, W_k_groups, W_v_groups,
                                       n_heads=32, n_kv_groups=8):
    """
    GQA: 32 query heads share 8 KV groups (4 query heads per group).
    Cache stores only 8 K and 8 V vectors per token, not 32.
    """
    head_dim = W_k_groups[0].shape[1]
    heads_per_group = n_heads // n_kv_groups  # 32 / 8 = 4
    # Project new token's K and V — only n_kv_groups projections, not n_heads
    k_groups = [x_new @ W_k_groups[g] for g in range(n_kv_groups)]
    v_groups = [x_new @ W_v_groups[g] for g in range(n_kv_groups)]
    # Append to cache (one K, V per group)
    new_K_cache = [np.vstack([K_cache[g], k_groups[g].reshape(1, -1)])
                   for g in range(n_kv_groups)]
    new_V_cache = [np.vstack([V_cache[g], v_groups[g].reshape(1, -1)])
                   for g in range(n_kv_groups)]
    # Each query head attends using its group's cached K and V
    head_outputs = []
    for h in range(n_heads):
        g = h // heads_per_group  # which KV group this head uses
        q_h = (x_new @ W_q_heads[h]).reshape(1, -1)  # (1, head_dim)
        scores = q_h @ new_K_cache[g].T / np.sqrt(head_dim)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        head_outputs.append((weights @ new_V_cache[g]).flatten())
    output = np.concatenate(head_outputs)
    return output, new_K_cache, new_V_cache
```
Here's how the three approaches compare for real models:
| Model | Attention | Query Heads | KV Heads | Cache / Token |
|---|---|---|---|---|
| LLaMA 2 7B | MHA | 32 | 32 | 0.50 MB |
| Mistral 7B | GQA-8 | 32 | 8 | 0.125 MB |
| LLaMA 2 70B | GQA-8 | 64 | 8 | 0.31 MB |
| Hypothetical MQA | MQA | 32 | 1 | 0.016 MB |
Mistral 7B uses GQA with 8 KV heads instead of 32 — a 4× reduction in cache size compared to LLaMA 2 7B. That's the difference between 2 GB and 512 MB at 4K context. GQA won out across the industry because it achieves quality close to full MHA while keeping inference memory close to MQA.
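The cache-per-token column follows directly from the same size formula (a sketch; head_dim = 128 and FP16 are assumed throughout, and the helper name is mine):

```python
def cache_per_token_mb(n_layers, n_kv_heads, head_dim=128, dtype_bytes=2):
    """Per-token KV cache in MB: 2 (K and V) × layers × KV heads × dims × bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes / 2**20

print(cache_per_token_mb(32, 32))  # LLaMA 2 7B  (MHA):   0.5
print(cache_per_token_mb(32, 8))   # Mistral 7B  (GQA-8): 0.125
print(cache_per_token_mb(80, 8))   # LLaMA 2 70B (GQA-8): 0.3125
print(cache_per_token_mb(32, 1))   # hypothetical MQA:    0.015625
```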
RoPE and the Cache: Why Position Encoding Plays Nice
If you've read our positional encoding post, you know that RoPE (Rotary Position Embeddings) encodes position by rotating Q and K vectors in 2D subspaces. An important question arises: where does RoPE fit with the KV cache?
The answer is elegant. RoPE is applied before caching:
```python
def apply_rope(vec, position):
    """Minimal RoPE, defined here so the example is self-contained:
    rotate consecutive dimension pairs by position-dependent angles."""
    out = vec.astype(float).copy()
    for i in range(0, d_head, 2):
        theta = position / (10000 ** (i / d_head))
        c, s = np.cos(theta), np.sin(theta)
        x0, x1 = out[i], out[i + 1]
        out[i] = x0 * c - x1 * s
        out[i + 1] = x0 * s + x1 * c
    return out

def decode_step_with_rope(new_emb, position, K_cache, V_cache):
    """RoPE is applied to K before it enters the cache."""
    q = new_emb @ W_q
    k = new_emb @ W_k
    v = new_emb @ W_v
    # Apply RoPE rotation based on this token's absolute position
    q = apply_rope(q, position)  # rotate query by position angle
    k = apply_rope(k, position)  # rotate key by position angle
    # Cache the ROTATED key — position is baked in permanently
    K_cache = np.vstack([K_cache, k.reshape(1, -1)])
    V_cache = np.vstack([V_cache, v.reshape(1, -1)])  # V is NOT rotated
    # Attention: rotated query against rotated keys.
    # The dot product Q_s · K_t naturally captures relative distance (s − t)
    scores = q.reshape(1, -1) @ K_cache.T / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return (weights @ V_cache).flatten(), K_cache, V_cache
```
This works because each token's rotation depends only on its own position. The key for token 42 is always rotated by the same angle, regardless of how many tokens come after it. When query at position 100 attends to this cached key, the dot product naturally captures the relative distance (100 − 42 = 58) through the rotation algebra — no re-rotation needed.
Note that RoPE is applied to K but not to V. Values participate in a weighted sum, not a dot product — they don't need positional information.
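The relative-distance property is easy to verify in a 2-D toy case (a sketch; using a single rotation angle per position is a simplification of full RoPE, which rotates many dimension pairs at different frequencies):

```python
import numpy as np

def rot(v, pos):
    """Rotate a 2-D vector by an angle proportional to its position."""
    c, s = np.cos(pos * 0.1), np.sin(pos * 0.1)
    return np.array([v[0] * c - v[1] * s, v[0] * s + v[1] * c])

q = np.array([1.0, 2.0])
k = np.array([0.5, -1.0])

# Positions (100, 42) and (60, 2) share the same relative distance 58...
d1 = rot(q, 100) @ rot(k, 42)
d2 = rot(q, 60) @ rot(k, 2)
assert np.isclose(d1, d2)  # ...so the dot products are identical
```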
The Frontier: Production Optimizations
The basic KV cache we built is exactly what early transformer implementations used. But production systems go further:
PagedAttention (used by vLLM) borrows the virtual memory concept from operating systems. Instead of pre-allocating one contiguous block per sequence, it divides the cache into fixed-size pages (e.g., 16 tokens each) and allocates them on demand. A block table maps logical positions to physical memory — just like an OS page table. This reduces memory waste from 60–80% down to under 4%, enabling 2–4× more throughput.
KV Cache Quantization stores cached keys and values in lower precision. FP8 halves the cache with negligible quality loss. INT4 cuts it by 4× with only slight degradation. Since the cache is read-only during decode (keys and values don't change once cached), quantization is particularly safe here — there's no error accumulation from repeated operations.
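Here's what symmetric INT8 quantization of a cached key block looks like in miniature (an illustrative sketch, not any particular library's actual scheme; real systems typically use finer-grained scales):

```python
import numpy as np

np.random.seed(1)
K = np.random.randn(16, 8).astype(np.float32)   # a cached key block, FP32

scale = np.abs(K).max() / 127.0                 # one scale for the block
K_int8 = np.round(K / scale).astype(np.int8)    # 1 byte per value instead of 4
K_restored = K_int8.astype(np.float32) * scale  # dequantize on read

# Rounding error is bounded by half a quantization step
assert np.abs(K - K_restored).max() <= scale / 2 + 1e-6
```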
Sliding Window Attention, used by Mistral, limits each token to attending over only the last W tokens (e.g., W = 4096). The cache never grows beyond W entries, putting a hard cap on memory. Information from beyond the window still propagates through the layers — layer 2's window overlaps with layer 1's, creating an effective receptive field much larger than W.
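A rolling cache with a hard cap is a one-line change to the append step (a sketch with a toy window of 4; Mistral's actual window is 4096):

```python
import numpy as np

W = 4       # sliding window size (toy value)
d_head = 8
K_cache = np.empty((0, d_head))

for t in range(10):
    k_new = np.random.randn(1, d_head)
    # Append, then keep only the last W rows: the cache never exceeds W entries
    K_cache = np.vstack([K_cache, k_new])[-W:]

print(K_cache.shape)  # (4, 8)
```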
Each of these deserves its own deep dive, but they all build on the same foundation we just built: cache the keys and values, and be smart about how you store them.
The Pipeline So Far
The KV cache is the bridge between "attention works in theory" and "attention is fast enough to use in practice." It's a purely mechanical optimization — no approximation, no quality loss — that eliminates O(n²) redundant computation by caching O(n) immutable values.
Our elementary series pipeline now looks like this:
tokenize → embed → position → attend (with KV cache) → softmax → loss → optimize → decode
The elegant math of attention meets the hard reality of memory limits. GQA shrinks the cache. PagedAttention manages it. Quantization compresses it. And still, for every new token, the cache grows by another row. That tension between compute savings and memory cost is the beating heart of LLM inference engineering.
References & Further Reading
- Vaswani et al. — "Attention Is All You Need" (2017) — The original transformer paper where K, V projections were introduced
- Noam Shazeer — "Fast Transformer Decoding: One Write-Head is All You Need" (2019) — The Multi-Query Attention paper that started the cache reduction revolution
- Ainslie et al. — "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023) — Grouped-Query Attention, the practical middle ground used by LLaMA 2 and Mistral
- Kwon et al. — "Efficient Memory Management for Large Language Model Serving with PagedAttention" (2023) — The vLLM paper that brought OS-style virtual memory to KV caches
- Su et al. — "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021) — RoPE, the position encoding scheme that caches cleanly
- DadOps — Attention from Scratch — Where we built Q, K, V attention from zero
- DadOps — Decoding Strategies from Scratch — Token-by-token generation, where the KV cache lives
- DadOps — Positional Encoding from Scratch — RoPE and how it interacts with caching