Attention Variants from Scratch: GQA, MQA, and Why Modern LLMs Share Heads
The KV Cache Problem
In our attention-from-scratch post, we built multi-head attention: split Q, K, V into h heads, compute attention independently, concatenate, project. It works beautifully. It's also completely impractical at the scale modern LLMs operate.
Here's why. During autoregressive generation, we store the K and V projections of every previous token in a KV cache (covered in our KV cache post). For standard multi-head attention, the cache size per token per layer is:

2 × n_heads × d_head values (one set each for K and V)

For a LLaMA-7B-scale model (32 heads, d_head=128, 32 layers, float16), a 128K-token context window requires:

2 × 32 × 128 × 32 layers × 2 bytes × 131,072 tokens = 64 GB

That's 64 GB just for the cache — several times the model weights themselves. At 128K context, the KV cache becomes the bottleneck, not the computation. Every major LLM deployed today uses an attention variant designed to shrink this cache. Let's build each one from scratch.
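The arithmetic is worth checking in a few lines (note the factor of 2 for storing both K and V):

```python
# Back-of-envelope KV cache size for a LLaMA-7B-scale model.
n_heads, d_head, n_layers = 32, 128, 32
bytes_per_value = 2          # float16
context_len = 128 * 1024     # 128K tokens

# Per token: K and V each store n_heads * d_head values, in every layer.
bytes_per_token = 2 * n_heads * d_head * n_layers * bytes_per_value
cache_gib = bytes_per_token * context_len / 2**30
print(f"{bytes_per_token // 1024} KiB per token, {cache_gib:.0f} GiB total")
# 512 KiB per token, 64 GiB total
```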
Multi-Query Attention: One K,V to Rule Them All
In 2019, Noam Shazeer proposed a radical simplification: what if all attention heads shared a single set of keys and values? Each head keeps its own query projection (so they still attend to different things), but they all read from the same K and V matrices.
MQA: Q_i = x · W_Q^(i) for each head i; K = x · W_K and V = x · W_V, shared by all heads
The KV cache drops from 2 × n_heads × d_head to just 2 × d_head per token per layer — a 32x reduction for a 32-head model, shrinking the cache above by the same factor.
The cost? Quality degrades slightly — about 0.4 to 0.8 points on ROUGE summarization benchmarks. For some applications that's acceptable; for others it's not. PaLM, Falcon, and StarCoder all shipped with MQA.
```python
import numpy as np

np.random.seed(42)

def multi_head_attention(x, n_heads, d_head):
    """Standard MHA: each head has its own Q, K, V."""
    seq_len, d_model = x.shape
    all_outputs = []
    kv_cache_size = 0
    for h in range(n_heads):
        W_Q = np.random.randn(d_model, d_head) * 0.1
        W_K = np.random.randn(d_model, d_head) * 0.1
        W_V = np.random.randn(d_model, d_head) * 0.1
        Q = x @ W_Q  # (seq_len, d_head)
        K = x @ W_K
        V = x @ W_V
        kv_cache_size += K.size + V.size  # each head stores its own K, V
        scores = Q @ K.T / np.sqrt(d_head)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        all_outputs.append(weights @ V)
    return np.concatenate(all_outputs, axis=-1), kv_cache_size

def multi_query_attention(x, n_heads, d_head):
    """MQA: all heads share a single K, V."""
    seq_len, d_model = x.shape
    # Single shared K, V
    W_K = np.random.randn(d_model, d_head) * 0.1
    W_V = np.random.randn(d_model, d_head) * 0.1
    K = x @ W_K
    V = x @ W_V
    kv_cache_size = K.size + V.size  # only one K, V pair
    all_outputs = []
    for h in range(n_heads):
        W_Q = np.random.randn(d_model, d_head) * 0.1
        Q = x @ W_Q
        scores = Q @ K.T / np.sqrt(d_head)
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        all_outputs.append(weights @ V)
    return np.concatenate(all_outputs, axis=-1), kv_cache_size

# Compare on a toy sequence
seq_len, d_model, n_heads, d_head = 16, 64, 8, 8
x = np.random.randn(seq_len, d_model)
mha_out, mha_cache = multi_head_attention(x, n_heads, d_head)
mqa_out, mqa_cache = multi_query_attention(x, n_heads, d_head)
print(f"MHA output shape: {mha_out.shape}, KV cache: {mha_cache} values")
print(f"MQA output shape: {mqa_out.shape}, KV cache: {mqa_cache} values")
print(f"Cache reduction: {mha_cache / mqa_cache:.0f}x")
# MHA output shape: (16, 64), KV cache: 2048 values
# MQA output shape: (16, 64), KV cache: 256 values
# Cache reduction: 8x
```
Eight heads, 8x cache reduction. For a model with 96 attention heads, MQA would give a 96x reduction. But there's a problem: forcing all heads to share the same keys and values limits how much they can specialize. The model loses some of its ability to attend to different aspects of the input simultaneously.
Grouped-Query Attention: The Sweet Spot
Ainslie et al. (2023) asked the natural question: instead of all-or-nothing, what about a middle ground? Grouped-Query Attention divides the n_heads query heads into g groups, where each group shares one K,V pair.
The KV cache becomes 2 × g × d_head, giving a reduction factor of n_heads / g. With 32 heads and 8 groups, you get a 4x cache reduction while keeping most of MHA's quality — typically within 1-2% on downstream tasks.
A remarkable practical finding from the GQA paper: you can uptrain an existing MHA checkpoint to GQA using only about 5% of the original pre-training compute — the model quickly learns to work with K,V projections shared within groups. (Meta's LLaMA 2-70B shipped with GQA built in.)
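The conversion step itself is simple enough to sketch: the GQA paper initializes each group's shared K (and likewise V) projection as the mean of the projections of the heads it replaces, then uptrains briefly. A minimal sketch, with random stand-ins for the trained per-head weights:

```python
import numpy as np

np.random.seed(0)
n_heads, n_groups, d_model, d_head = 8, 2, 32, 8
heads_per_group = n_heads // n_groups

# Per-head K projections from a trained MHA checkpoint (random stand-ins here).
W_K_heads = np.random.randn(n_heads, d_model, d_head)

# GQA initialization: mean-pool the K projections within each group.
W_K_groups = W_K_heads.reshape(
    n_groups, heads_per_group, d_model, d_head
).mean(axis=1)

print(W_K_heads.shape, "->", W_K_groups.shape)
# (8, 32, 8) -> (2, 32, 8)
```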
```python
# Continuing from the MHA/MQA code above (uses x, n_heads, d_head)
def grouped_query_attention(x, n_heads, d_head, n_groups):
    """GQA: heads are divided into groups sharing K, V."""
    seq_len, d_model = x.shape
    heads_per_group = n_heads // n_groups
    all_outputs = []
    kv_cache_size = 0
    for g in range(n_groups):
        # Each group has its own K, V
        W_K = np.random.randn(d_model, d_head) * 0.1
        W_V = np.random.randn(d_model, d_head) * 0.1
        K = x @ W_K
        V = x @ W_V
        kv_cache_size += K.size + V.size
        # Multiple query heads share this group's K, V
        for h in range(heads_per_group):
            W_Q = np.random.randn(d_model, d_head) * 0.1
            Q = x @ W_Q
            scores = Q @ K.T / np.sqrt(d_head)
            scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
            weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
            all_outputs.append(weights @ V)
    return np.concatenate(all_outputs, axis=-1), kv_cache_size

# Compare all three: MHA, GQA, MQA
n_heads, d_head = 8, 8
configs = [
    ("MHA (g=8)", n_heads),  # every head has its own K,V
    ("GQA (g=4)", 4),        # 4 groups, 2 heads per group
    ("GQA (g=2)", 2),        # 2 groups, 4 heads per group
    ("MQA (g=1)", 1),        # single shared K,V
]
print(f"{'Variant':<14} {'KV cache':>10} {'Reduction':>10}")
print("-" * 36)
mha_cache = None
for name, g in configs:
    _, cache = grouped_query_attention(x, n_heads, d_head, g)
    if mha_cache is None:
        mha_cache = cache
    ratio = mha_cache / cache
    print(f"{name:<14} {cache:>10} {ratio:>9.0f}x")
# Variant          KV cache  Reduction
# ------------------------------------
# MHA (g=8)            2048         1x
# GQA (g=4)            1024         2x
# GQA (g=2)             512         4x
# MQA (g=1)             256         8x
```
This is why GQA won. LLaMA 2 (70B), LLaMA 3, Mistral 7B, Mixtral 8x7B, and Gemma all use GQA with carefully chosen group counts. It's one of the most impactful architectural changes between the original Transformer and today's production models.
Sliding Window Attention: Local Is Enough
GQA shrinks how much we store per token. Sliding window attention shrinks how many tokens we store. Instead of attending to the entire sequence, each token only sees the previous w tokens.
Sliding window: token i attends to tokens max(0, i−w+1), ..., i — reducing attention cost from O(n²) to O(n·w)
This might seem like a severe limitation — how can the model understand long-range dependencies with a window of only 4,096 tokens? The answer is information propagation through layers. After L layers, each with window w, a token's effective receptive field is w × L.
For Mistral 7B: w=4,096 and L=32 layers. The effective receptive field is 4,096 × 32 = 131,072 tokens — far exceeding the typical context window. Information from distant tokens reaches any position by hopping through intermediate positions across layers, much like messages passing through a telephone chain.
Even better: the KV cache becomes bounded. With a rolling buffer that only keeps the most recent w positions, memory usage is constant regardless of sequence length. Mistral reported a 2x speed improvement over vanilla attention at 16K sequence length.
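The rolling buffer can be sketched in a few lines. This is a minimal, illustrative version (token i overwrites slot i % w); a production implementation like Mistral's also tracks token positions so the attention mask stays correct once the buffer wraps:

```python
import numpy as np

class RollingKVCache:
    """Fixed-size KV cache for sliding window attention (minimal sketch).

    Token i's K and V vectors land in slot i % window_size, so memory stays
    constant no matter how long generation runs.
    """
    def __init__(self, window_size, d_head):
        self.window_size = window_size
        self.k = np.zeros((window_size, d_head))
        self.v = np.zeros((window_size, d_head))
        self.n_seen = 0  # total tokens written so far

    def append(self, k_vec, v_vec):
        slot = self.n_seen % self.window_size  # overwrite the oldest entry
        self.k[slot] = k_vec
        self.v[slot] = v_vec
        self.n_seen += 1

    def current(self):
        """Return the cached K, V (at most window_size most recent tokens)."""
        n = min(self.n_seen, self.window_size)
        return self.k[:n], self.v[:n]  # note: not in temporal order once full

cache = RollingKVCache(window_size=4, d_head=8)
for i in range(10):  # simulate 10 decode steps
    cache.append(np.random.randn(8), np.random.randn(8))
K, V = cache.current()
print(K.shape)  # (4, 8) — memory bounded at the window size despite 10 tokens
```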
```python
def create_attention_mask(seq_len, window_size=None):
    """Create a causal attention mask, optionally with a sliding window."""
    # Causal mask: token i can attend to tokens 0..i
    mask = np.tril(np.ones((seq_len, seq_len)))
    if window_size is not None:
        # Sliding window: token i can only attend to tokens (i-w+1)..i
        window_mask = np.zeros((seq_len, seq_len))
        for i in range(seq_len):
            start = max(0, i - window_size + 1)
            window_mask[i, start:i+1] = 1.0
        mask = mask * window_mask
    return mask

def sliding_window_attention(x, d_head, window_size):
    """Attention with a sliding window mask; returns the attention weights."""
    seq_len, d_model = x.shape
    W_Q = np.random.randn(d_model, d_head) * 0.1
    W_K = np.random.randn(d_model, d_head) * 0.1
    W_V = np.random.randn(d_model, d_head) * 0.1
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V
    scores = Q @ K.T / np.sqrt(d_head)
    mask = create_attention_mask(seq_len, window_size)
    scores = np.where(mask == 1, scores, -1e9)  # mask disallowed positions
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights

# Show how window size affects the attention pattern
seq_len = 12
x = np.random.randn(seq_len, 32)
print("Attention masks (1 = can attend, 0 = masked):\n")
for w_name, w in [("Full", None), ("Window=4", 4), ("Window=2", 2)]:
    mask = create_attention_mask(seq_len, w)
    # Cache size: full attention stores all tokens, windowed stores only w
    cache = seq_len if w is None else min(seq_len, w)
    print(f"{w_name} (max cache = {cache} tokens per layer):")
    for row in mask[:6, :6]:
        print("  ", " ".join("█" if v else "·" for v in row))
    print()
# Full (max cache = 12 tokens per layer):
#   █ · · · · ·
#   █ █ · · · ·
#   █ █ █ · · ·
#   █ █ █ █ · ·
#   █ █ █ █ █ ·
#   █ █ █ █ █ █
#
# Window=4 (max cache = 4 tokens per layer):
#   █ · · · · ·
#   █ █ · · · ·
#   █ █ █ · · ·
#   █ █ █ █ · ·
#   · █ █ █ █ ·
#   · · █ █ █ █
#
# Window=2 (max cache = 2 tokens per layer):
#   █ · · · · ·
#   █ █ · · · ·
#   · █ █ · · ·
#   · · █ █ · ·
#   · · · █ █ ·
#   · · · · █ █
```
Notice how the window creates a diagonal band in the attention mask. Information can only travel w positions per layer — but stacking layers lets it propagate across the entire sequence. Mistral combines sliding window attention with GQA, getting both benefits: fewer KV entries per token AND fewer tokens in the cache.
Cross-Attention: Bridging Two Worlds
Every variant so far is a form of self-attention — a sequence attending to itself. But some of the most powerful models need to connect two different sequences: a source language and a target language, an image and a caption, an audio waveform and a transcript.
Cross-attention takes queries from one sequence and keys/values from another:

Q = x_decoder · W_Q (from the decoder / target)
K = x_encoder · W_K (from the encoder / source)
V = x_encoder · W_V (from the encoder / source)
The formula is identical to self-attention — the only difference is where Q, K, V come from. This simple change enables an entire category of models: encoder-decoder transformers for translation (T5, BART), vision-language models (Flamingo, BLIP-2), speech recognition (Whisper), and even image generation (Stable Diffusion, where the U-Net cross-attends to text embeddings).
Cross-attention has a built-in efficiency advantage: the encoder's K,V are computed once after encoding the source, then reused for every decoder generation step. Unlike self-attention's KV cache that grows with output length, the cross-attention cache is fixed.
```python
def cross_attention(x_decoder, x_encoder, d_head):
    """Cross-attention: queries from decoder, keys/values from encoder."""
    d_model = x_decoder.shape[1]
    W_Q = np.random.randn(d_model, d_head) * 0.1
    W_K = np.random.randn(d_model, d_head) * 0.1
    W_V = np.random.randn(d_model, d_head) * 0.1
    Q = x_decoder @ W_Q  # (dec_len, d_head) — what am I looking for?
    K = x_encoder @ W_K  # (enc_len, d_head) — what do I contain?
    V = x_encoder @ W_V  # (enc_len, d_head) — what information do I carry?
    # Decoder tokens attend to encoder tokens
    scores = Q @ K.T / np.sqrt(d_head)  # (dec_len, enc_len)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    output = weights @ V  # (dec_len, d_head)
    return output, weights

# Simulate translation: French encoder -> English decoder
np.random.seed(42)
enc_len, dec_len, d_model, d_head = 5, 4, 32, 8
x_encoder = np.random.randn(enc_len, d_model)  # "Le chat dort sur lit"
x_decoder = np.random.randn(dec_len, d_model)  # "The cat sleeps on"
output, weights = cross_attention(x_decoder, x_encoder, d_head)
src_tokens = ["Le", "chat", "dort", "sur", "lit"]
tgt_tokens = ["The", "cat", "sleeps", "on"]
print("Cross-attention weights (decoder -> encoder):")
print(f"{'':>10}", " ".join(f"{t:>6}" for t in src_tokens))
for i, tgt in enumerate(tgt_tokens):
    row = " ".join(f"{weights[i,j]:.3f}" for j in range(enc_len))
    print(f"{tgt:>10} {row}")
# Cross-attention weights (decoder -> encoder):
#                Le   chat   dort    sur    lit
#        The 0.152 0.186 0.232 0.196 0.234
#        cat 0.223 0.180 0.230 0.223 0.143
#     sleeps 0.230 0.178 0.217 0.163 0.213
#         on 0.269 0.185 0.137 0.168 0.241
```
In a trained translation model, you'd see "The" attending strongly to "Le", "cat" to "chat", and "sleeps" to "dort" — cross-attention learns an implicit word alignment. Here, with random weights, the pattern is roughly uniform, but the mechanism is identical.
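The fixed-cache property mentioned above is easy to demonstrate: project the encoder's K and V once, then reuse them at every decode step. A minimal sketch with illustrative shapes and random weights:

```python
import numpy as np

np.random.seed(0)
enc_len, d_model, d_head = 5, 32, 8
x_encoder = np.random.randn(enc_len, d_model)
W_Q = np.random.randn(d_model, d_head) * 0.1
W_K = np.random.randn(d_model, d_head) * 0.1
W_V = np.random.randn(d_model, d_head) * 0.1

# Encode once: these K, V never change during generation.
K_enc, V_enc = x_encoder @ W_K, x_encoder @ W_V

def decode_step(h_decoder):
    """One cross-attention step for a single decoder token (1, d_model)."""
    q = h_decoder @ W_Q                     # (1, d_head)
    scores = q @ K_enc.T / np.sqrt(d_head)  # (1, enc_len) — size is fixed
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return w @ V_enc

# Generate 10 tokens: the cross-attention cache (K_enc, V_enc) never grows.
for step in range(10):
    out = decode_step(np.random.randn(1, d_model))
print(K_enc.shape, out.shape)  # (5, 8) (1, 8)
```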
The Efficiency Frontier
Let's put all the variants together. Each occupies a different point on the quality-vs-efficiency tradeoff:
| Variant | KV Cache (vs MHA) | Quality Impact | Used By |
|---|---|---|---|
| MHA | 1x (baseline) | Baseline | GPT-2, BERT, LLaMA 1 |
| GQA | n_heads/g reduction | Within 1-2% of baseline | LLaMA 2/3, Mistral, Gemma |
| MQA | n_heads× reduction | −0.4 to −0.8 ROUGE | PaLM, Falcon, StarCoder |
| SWA | Bounded by window w | Minimal (propagates via layers) | Mistral 7B, Longformer |
| MLA | ~93% reduction | Equal or better | DeepSeek-V2/V3 |
The frontier keeps advancing. DeepSeek's Multi-Head Latent Attention (MLA) compresses K,V into a low-dimensional latent vector — not by sharing heads, but by projecting into a learned subspace. It achieves 93% cache reduction while actually improving quality over MHA, likely because the low-rank compression acts as regularization. And these variants compose: Mistral uses GQA + sliding window + Flash Attention together.
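MLA's core idea can be sketched with plain low-rank projections. This is a deliberate simplification (the real architecture also needs a separate decoupled key path to remain compatible with rotary position embeddings, and the dimensions here are illustrative, not DeepSeek's):

```python
import numpy as np

np.random.seed(0)
seq_len, d_model, d_latent, n_heads, d_head = 16, 64, 8, 4, 16
x = np.random.randn(seq_len, d_model)

# Down-project each token into a small shared latent vector...
W_down = np.random.randn(d_model, d_latent) * 0.1
c = x @ W_down                     # (seq_len, d_latent) — this is all we cache

# ...and up-project the latent into per-head K and V at attention time.
W_up_k = np.random.randn(d_latent, n_heads * d_head) * 0.1
W_up_v = np.random.randn(d_latent, n_heads * d_head) * 0.1
K = c @ W_up_k                     # reconstructed keys for all heads
V = c @ W_up_v                     # reconstructed values for all heads

mha_cache = 2 * seq_len * n_heads * d_head  # what MHA would store
mla_cache = c.size                          # what MLA stores instead
print(f"MHA cache: {mha_cache} values, MLA cache: {mla_cache} values "
      f"({1 - mla_cache / mha_cache:.0%} reduction)")
# MHA cache: 2048 values, MLA cache: 128 values (94% reduction)
```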
These architectural variants are orthogonal to algorithmic optimizations like Flash Attention (which changes how attention is computed, not what it computes). A production LLM typically uses GQA for cache efficiency + Flash Attention for compute efficiency + quantized KV cache for memory efficiency — all stacking multiplicatively.
From Paper to Production
The original Transformer gave us multi-head attention — beautiful, powerful, and expensive. The journey since then has been a series of engineering insights about what can be shared, compressed, or windowed without losing the magic.
- MQA showed that sharing all K,V across heads barely hurts quality — heads mostly need different queries, not different keys
- GQA found the sweet spot: share within groups, keep quality within 1-2%, and you can even uptrain existing models
- Sliding window exploited the fact that attention is mostly local — and multi-layer stacking gives global reach anyway
- Cross-attention showed that connecting two sequences uses the same Q,K,V mechanism with a simple twist in where the inputs come from
- MLA pushed further by compressing K,V into a learned latent space — turning cache reduction into a training signal
Together, these variants are what make 128K-context models possible on a single GPU. The attention mechanism itself hasn't changed since 2017 — it's still softmax(QKᵀ/√d_k)V. What changed is how we organize Q, K, and V to balance quality, memory, and speed. And that engineering, invisible to the end user, is what separates a research prototype from a production LLM.
References & Further Reading
- Shazeer — Fast Transformer Decoding: One Write-Head is All You Need (2019) — The original Multi-Query Attention paper
- Ainslie et al. — GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (EMNLP 2023) — Grouped-Query Attention and the uptraining recipe
- Beltagy et al. — Longformer: The Long-Document Transformer (2020) — Sliding window + global attention for long documents
- Jiang et al. — Mistral 7B (2023) — Sliding window + GQA in a production model
- DeepSeek-AI — DeepSeek-V2 (2024) — Multi-Head Latent Attention with joint KV compression
- Vaswani et al. — Attention Is All You Need (2017) — The original Transformer with MHA and cross-attention