Normalization from Scratch: Why Every Transformer Layer Needs a Reset Button
The Exploding Signal
You've built the full transformer pipeline. You can tokenize input, embed it, add positional encoding, run attention with a KV cache, apply softmax, compute loss, optimize, decode, fine-tune, and quantize. There's one critical piece we've been hand-waving: normalization.
Without it, deep networks are a ticking time bomb. Watch what happens when you stack 50 linear layers and pass a signal through:
import numpy as np
np.random.seed(42)
d = 128
x = np.random.randn(4, d)
print(f"Input: std(x) = {x.std():.4f}")
for i in range(50):
    W = np.random.randn(d, d) * 0.15
    x = x @ W
    if i in [0, 9, 19, 29, 49]:
        print(f"Layer {i+1:2d}: std(x) = {x.std():.6e}")
# Input: std(x) = 0.9781
# Layer 1: std(x) = 1.733205e+00
# Layer 10: std(x) = 1.994266e+02
# Layer 20: std(x) = 3.580051e+04
# Layer 30: std(x) = 6.600226e+06
# Layer 50: std(x) = 2.935333e+11
The standard deviation starts at a modest 0.98 and explodes to 293 billion by layer 50. Each matrix multiplication amplifies the signal slightly, and across 50 layers that tiny amplification compounds into catastrophe. These are the kinds of numbers that overflow to inf in float16 and produce NaN gradients during backpropagation.
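To see why those magnitudes are fatal in half precision, here's a quick side check (my own illustration, not part of the pipeline): float16 tops out around 65,504, so the layer-50 activations overflow to inf the moment they are cast, and any inf − inf that follows becomes NaN.

```python
import numpy as np

# float16 overflows past ~65504 -- far below the 2.9e11 we reached at layer 50
print(np.finfo(np.float16).max)            # largest finite float16 (~6.55e4)
big = np.array([2.935e11], dtype=np.float32)
half = big.astype(np.float16)              # cast, as in mixed-precision training
print(half)                                # [inf]
with np.errstate(invalid="ignore"):
    print(half - half)                     # [nan]: inf - inf, the seed of NaN gradients
```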
This is the "internal covariate shift" problem that Sergey Ioffe and Christian Szegedy identified in 2015: as each layer's weights update during training, the distribution of inputs to the next layer shifts. The deeper you go, the more these shifts compound. Training a 50-layer network without normalization is like trying to hit a moving target while standing on another moving target.
The fix is deceptively simple: after each layer, reset the signal. Force the activations back to a well-behaved distribution before the next layer sees them. The question is how.
Batch Normalization — The Original Fix
Ioffe and Szegedy's 2015 paper introduced Batch Normalization (BatchNorm), and the results were striking: networks trained faster, tolerated higher learning rates, and the notorious vanishing/exploding gradient problem was dramatically reduced.
The idea is a four-step recipe applied to each feature independently:
1. Compute the mean across the batch: μ = (1/N) ∑ xi
2. Compute the variance across the batch: σ² = (1/N) ∑ (xi − μ)²
3. Normalize: x̂ = (x − μ) / √(σ² + ε)
4. Scale and shift: y = γ · x̂ + β
Steps 1–3 are standard statistical normalization — subtract the mean, divide by the standard deviation. The tiny ε (typically 1e-5) prevents division by zero. Step 4 is the clever part: γ (scale) and β (shift) are learnable parameters. This means the network can undo the normalization if that turns out to be optimal. They're initialized to γ=1, β=0 (identity transform), so normalization is the default.
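The "can undo it" claim is easy to verify numerically. In this toy check (my own, with γ and β set by hand rather than learned), choosing γ equal to the per-feature standard deviation and β equal to the per-feature mean inverts steps 1–3 almost exactly:

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(16, 8) * 5 + 3               # shifted, scaled input
mean, var = x.mean(axis=0), x.var(axis=0)
x_hat = (x - mean) / np.sqrt(var + 1e-5)         # steps 1-3: normalize
y = np.sqrt(var) * x_hat + mean                  # step 4 with gamma=std, beta=mean
print(np.allclose(y, x, atol=1e-3))              # True: normalization undone
```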
def batch_norm(x, gamma, beta, eps=1e-5):
"""Batch Normalization: normalize across the batch dimension."""
# x shape: (batch_size, features)
mean = x.mean(axis=0, keepdims=True) # Mean per feature
var = x.var(axis=0, keepdims=True) # Variance per feature
x_hat = (x - mean) / np.sqrt(var + eps) # Normalize
return gamma * x_hat + beta # Scale and shift
# Test it
np.random.seed(42)
x = np.random.randn(32, 128) * 5 + 3 # Shifted, scaled input
gamma = np.ones(128) # Learnable scale (init: 1)
beta = np.zeros(128) # Learnable shift (init: 0)
out = batch_norm(x, gamma, beta)
print(f"Before BatchNorm: mean = {x.mean():.4f}, std = {x.std():.4f}")
print(f"After BatchNorm: mean = {out.mean():.4f}, std = {out.std():.4f}")
# Before BatchNorm: mean = 3.0828, std = 4.9848
# After BatchNorm: mean = -0.0000, std = 1.0000
The input had mean 3.08 and standard deviation 4.98 — a shifted, stretched distribution. After BatchNorm, the output is centered at zero with unit variance. Every feature has been individually reset to a standard normal shape.
There's one important detail for deployment: at inference time, you don't have a batch to compute statistics over. So during training, BatchNorm maintains a running exponential moving average of the batch mean and variance. At inference time, these frozen statistics are used instead. This training/inference discrepancy will matter later.
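The running-statistics bookkeeping can be sketched like this (a minimal illustration; the momentum value of 0.1 is a common framework default, e.g. in PyTorch, assumed here rather than taken from the text):

```python
import numpy as np

def batch_norm_train(x, gamma, beta, running_mean, running_var,
                     momentum=0.1, eps=1e-5):
    """Training mode: normalize with batch stats, update running averages."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    # Exponential moving average of the batch statistics (updated in place)
    running_mean[:] = (1 - momentum) * running_mean + momentum * mean
    running_var[:] = (1 - momentum) * running_var + momentum * var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batch_norm_infer(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Inference mode: use frozen running statistics -- no batch required."""
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta

np.random.seed(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
running_mean, running_var = np.zeros(d), np.ones(d)
for _ in range(200):                        # simulate 200 training steps
    batch = np.random.randn(32, d) * 2 + 5  # data with mean 5, variance 4
    batch_norm_train(batch, gamma, beta, running_mean, running_var)
print(running_mean.round(1))                # close to the true mean (~5)
print(running_var.round(1))                 # close to the true variance (~4)
# A single sample at inference time -- no batch statistics needed:
single = np.random.randn(1, d) * 2 + 5
print(batch_norm_infer(single, gamma, beta, running_mean, running_var).shape)
```

A batch of one works fine at inference because the statistics are frozen; that freezing is exactly the train/inference discrepancy described above.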
Why BatchNorm Fails for Transformers
BatchNorm was a revolution for convolutional networks, but it hits three walls when you try to use it inside a transformer:
Problem 1: Variable sequence lengths. In a transformer, different sequences in the same batch usually have different lengths. Shorter sequences get padded. When BatchNorm computes statistics across the batch, those padding tokens contaminate the mean and variance with meaningless zeros.
Problem 2: Cross-sequence dependency. BatchNorm's statistics depend on which other sequences happen to be in the batch. This means the same token in the same position can produce different outputs depending on what other documents the model is processing in parallel. That's semantically meaningless — my token representation shouldn't change because someone else's input changed.
Problem 3: Small batches. Transformers are memory-hungry. A 7B model might fit a batch of 4 sequences on a single GPU. With only 4 samples, batch statistics are noisy and unreliable.
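Problem 1 is just as easy to demonstrate with a toy batch (my own construction, not a real tokenizer's output): pad one sequence with zero rows and the batch statistics get dragged toward zero.

```python
import numpy as np

np.random.seed(0)
d = 4
# One "real" sequence of 6 token vectors; one of 2 real tokens + 4 padding rows
seq_a = np.random.randn(6, d) + 3.0
seq_b = np.vstack([np.random.randn(2, d) + 3.0,
                   np.zeros((4, d))])       # zero padding to length 6
batch = np.vstack([seq_a, seq_b])           # BatchNorm would see all 12 rows

real_rows = np.vstack([seq_a, seq_b[:2]])   # the 8 rows that carry meaning
print(f"Mean of real tokens:  {real_rows.mean():.3f}")  # near 3
print(f"Mean BatchNorm sees:  {batch.mean():.3f}")      # dragged toward 0 by padding
```

A third of the statistics here come from rows that mean nothing. Masking the padding out is possible, but it's bookkeeping that per-token normalization simply never needs.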
Let's see problem 2 in action. Watch what happens when we feed the exact same token through BatchNorm with different batch companions:
np.random.seed(42)
token = np.array([[1.0, 2.0, 3.0, 4.0]]) # Our target token
# Batch 1: token alongside small-valued companions
batch1 = np.vstack([token, np.random.randn(3, 4) * 0.1])
# Batch 2: same token alongside large-valued companions
batch2 = np.vstack([token, np.random.randn(3, 4) * 10.0])
gamma4 = np.ones(4)
beta4 = np.zeros(4)
out1 = batch_norm(batch1, gamma4, beta4)[0] # First row = our token
out2 = batch_norm(batch2, gamma4, beta4)[0]
print(f"Same token: {token[0]}")
print(f"Output (batch 1): [{', '.join(f'{v:.4f}' for v in out1)}]")
print(f"Output (batch 2): [{', '.join(f'{v:.4f}' for v in out2)}]")
print(f"Different outputs? {not np.allclose(out1, out2)}")
# Same token: [1. 2. 3. 4.]
# Output (batch 1): [1.7263, 1.7310, 1.7293, 1.7305]
# Output (batch 2): [-0.1124, 0.6788, 1.0722, 1.5325]
# Different outputs? True
The token [1, 2, 3, 4] produces wildly different normalized outputs depending on its batch companions: [1.73, 1.73, 1.73, 1.73] versus [-0.11, 0.68, 1.07, 1.53]. That's a dealbreaker. Normalization should be deterministic for a given input — I shouldn't need to know what other sequences are in the batch to predict my output.
The solution: normalize each token independently, using its own features instead of borrowing statistics from other samples.
Layer Normalization — Normalize Across Features
Ba, Kiros, and Hinton proposed Layer Normalization in 2016, and the fix is surprisingly simple: instead of computing mean and variance across the batch dimension (axis 0), compute them across the feature dimension (axis −1). Each token normalizes itself, using only its own activation values.
LayerNorm: normalize across features (axis=−1) — "how do this token's features behave?"
def layer_norm(x, gamma, beta, eps=1e-5):
"""Layer Normalization: normalize across the feature dimension."""
# x shape: (batch_size, features) or (batch, seq, features)
mean = x.mean(axis=-1, keepdims=True) # Mean per token
var = x.var(axis=-1, keepdims=True) # Variance per token
x_hat = (x - mean) / np.sqrt(var + eps) # Normalize
return gamma * x_hat + beta # Scale and shift
# Same test: same token with different batch companions
np.random.seed(42)
token = np.array([[1.0, 2.0, 3.0, 4.0]])
batch1 = np.vstack([token, np.random.randn(3, 4) * 0.1])
batch2 = np.vstack([token, np.random.randn(3, 4) * 10.0])
gamma4 = np.ones(4)
beta4 = np.zeros(4)
out1 = layer_norm(batch1, gamma4, beta4)[0]
out2 = layer_norm(batch2, gamma4, beta4)[0]
print(f"Same token: {token[0]}")
print(f"Output (batch 1): [{', '.join(f'{v:.4f}' for v in out1)}]")
print(f"Output (batch 2): [{', '.join(f'{v:.4f}' for v in out2)}]")
print(f"Identical outputs? {np.allclose(out1, out2)}")
# Same token: [1. 2. 3. 4.]
# Output (batch 1): [-1.3416, -0.4472, 0.4472, 1.3416]
# Output (batch 2): [-1.3416, -0.4472, 0.4472, 1.3416]
# Identical outputs? True
The outputs are now identical regardless of batch composition. Each token's normalization depends only on its own features — exactly what we need.
| Property | BatchNorm | LayerNorm |
|---|---|---|
| Normalizes across | Batch (axis=0) | Features (axis=−1) |
| Depends on other samples? | Yes | No |
| Train/inference behavior | Different (running stats) | Identical |
| Handles variable seq lengths? | Poorly | Perfectly |
| Small batch friendly? | No (noisy stats) | Yes (batch-independent) |
LayerNorm became the standard normalization for transformers. The original "Attention Is All You Need" paper used it, GPT used it, BERT used it. For years, it was the default — until someone asked: do we really need all four steps?
RMSNorm — The Elegant Simplification
In 2019, Zhang and Sennrich published a short paper with a bold hypothesis: the re-centering step in LayerNorm (subtracting the mean) isn't actually necessary. What matters is the re-scaling — bringing activations back to a consistent magnitude. They proposed Root Mean Square Layer Normalization (RMSNorm):
RMSNorm(x) = (x / RMS(x)) · γ
That's it. No mean subtraction. No beta parameter. Just divide by the root-mean-square and scale by a learnable γ. The entire operation in NumPy fits in three lines:
def rms_norm(x, gamma, eps=1e-5):
"""RMS Normalization: normalize by root-mean-square (no centering)."""
rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
return (x / rms) * gamma # No beta parameter
# Compare all three on the same input
np.random.seed(42)
x = np.random.randn(4, 128) * 3 + 1 # Shifted, scaled
gamma = np.ones(128)
beta = np.zeros(128)
bn_out = batch_norm(x, gamma, beta)
ln_out = layer_norm(x, gamma, beta)
rms_out = rms_norm(x, gamma)
print("Output statistics (first sample):")
print(f" BatchNorm: mean = {bn_out[0].mean():+.4f}, std = {bn_out[0].std():.4f}")
print(f" LayerNorm: mean = {ln_out[0].mean():+.4f}, std = {ln_out[0].std():.4f}")
print(f" RMSNorm: mean = {rms_out[0].mean():+.4f}, std = {rms_out[0].std():.4f}")
# Output statistics (first sample):
# BatchNorm: mean = -0.0578, std = 0.9449
# LayerNorm: mean = +0.0000, std = 1.0000
# RMSNorm: mean = +0.2730, std = 0.9620
Notice the key difference: LayerNorm forces mean to exactly zero, while RMSNorm's output retains a nonzero mean of +0.27. But all three produce similar standard deviations near 1.0 — the activations are stabilized in all cases. Zhang and Sennrich's insight was that zero-centering is a luxury, not a necessity.
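One way to see why dropping the centering is often harmless: if a token's features already have zero mean, the variance equals the mean square, so LayerNorm (with β = 0) and RMSNorm coincide exactly. A quick check, with both functions redefined so the snippet runs on its own:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-5):
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * gamma

np.random.seed(0)
x = np.random.randn(4, 128)
x = x - x.mean(axis=-1, keepdims=True)      # force each row to zero mean
gamma, beta = np.ones(128), np.zeros(128)
# With zero-mean rows, var == mean(x^2), so the two normalizations agree
print(np.allclose(layer_norm(x, gamma, beta), rms_norm(x, gamma)))  # True
```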
The computational savings are concrete:
- One fewer reduction: No mean computation (one pass over the data instead of two)
- One fewer subtraction: No (x − mean) step
- One fewer parameter: No β vector (128 parameters saved per norm layer)
In a LLaMA-65B model, there are roughly 160 normalization operations per forward pass. Shaving a reduction and a subtraction off each one compounds into a measurable speedup — the original paper reported 7–64% faster normalization depending on the implementation.
Who uses RMSNorm? Virtually every modern open-weight LLM: LLaMA 1/2/3, Mistral, Gemma, Qwen. LayerNorm is still used in older architectures (BERT, GPT-2/3, the original Transformer), but RMSNorm has won the efficiency argument for new models.
Pre-Norm vs Post-Norm — Where You Put It Matters
We've covered what normalization does. Now the architectural question: where does it go? This seemingly minor placement decision turned out to have dramatic consequences for training stability.
The original Transformer (2017) used Post-Norm — normalize after adding the residual:
Post-Norm: x = LayerNorm(x + Sublayer(x))
GPT-2 (2019) switched to Pre-Norm — normalize before the sublayer, then add the raw output to the residual stream:
Pre-Norm: x = x + Sublayer(LayerNorm(x))
This seemingly tiny change enabled training much deeper networks without careful learning rate warmup.
The intuition is about gradient highways. In Pre-Norm, the residual connection is "clean" — there's a direct additive path from the output all the way back to the input. During backpropagation, the gradient can flow through this highway unchanged. Each sublayer reads normalized activations and writes a small update back to the raw stream.
In Post-Norm, every gradient must pass through a LayerNorm Jacobian at every layer. The LayerNorm Jacobian is a projection matrix that can distort gradient magnitudes. Stack enough of these projections and gradients near the input can explode or vanish.
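The projection claim can be checked numerically (my own construction, not from the text). LayerNorm's output is unchanged when every feature is shifted by the same constant, so its Jacobian maps the all-ones direction to zero; the direction of the normalized output itself is (up to ε) annihilated too. A finite-difference sketch on a single 8-dimensional vector:

```python
import numpy as np

def layer_norm_vec(x, eps=1e-5):
    """LayerNorm on a single vector, gamma=1, beta=0."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

np.random.seed(0)
d = 8
x = np.random.randn(d)

# Finite-difference Jacobian: J[i, j] = d out_i / d x_j
h = 1e-6
J = np.zeros((d, d))
for j in range(d):
    e = np.zeros(d)
    e[j] = h
    J[:, j] = (layer_norm_vec(x + e) - layer_norm_vec(x - e)) / (2 * h)

ones = np.ones(d)
print(np.abs(J @ ones).max())               # ~0: uniform shifts are projected out
print(np.abs(J @ layer_norm_vec(x)).max())  # ~0 (up to eps): so is the output direction
```

Two of the d directions carry no gradient at all; stack 50 of these Jacobians, as Post-Norm does, and the surviving components get repeatedly rescaled, which is where the distortion comes from.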
Let's see this in practice. We'll pass the same input through 50 residual layers using both architectures, alongside a no-normalization baseline:
np.random.seed(42)
d = 128
gamma_ln = np.ones(d)
beta_ln = np.zeros(d)
def residual_postnorm(x, weights):
    """Post-Norm: x = LayerNorm(x + sublayer(x))"""
    stds = []
    for W in weights:
        sublayer = x @ W
        x = layer_norm(x + sublayer, gamma_ln, beta_ln)
        stds.append(x.std())
    return stds

def residual_prenorm(x, weights):
    """Pre-Norm: x = x + sublayer(LayerNorm(x))"""
    stds = []
    for W in weights:
        sublayer = layer_norm(x, gamma_ln, beta_ln) @ W
        x = x + sublayer
        stds.append(x.std())
    return stds

def residual_nonorm(x, weights):
    """No-normalization baseline"""
    stds = []
    for W in weights:
        x = x + x @ W
        stds.append(x.std())
    return stds
weights = [np.random.randn(d, d) * 0.05 for _ in range(50)]
x = np.random.randn(4, d)
none_stds = residual_nonorm(x.copy(), weights)
post_stds = residual_postnorm(x.copy(), weights)
pre_stds = residual_prenorm(x.copy(), weights)
print("Layer | No Norm | Post-Norm | Pre-Norm")
print("-------|----------------|----------------|----------")
for i in [0, 9, 24, 49]:
    print(f" {i+1:3d} | {none_stds[i]:14.4e} | {post_stds[i]:14.4f} | {pre_stds[i]:.4f}")
# Layer | No Norm | Post-Norm | Pre-Norm
# -------|----------------|----------------|----------
# 1 | 1.1552e+00 | 1.0000 | 1.1535
# 10 | 3.8675e+00 | 1.0000 | 2.0211
# 25 | 3.4345e+01 | 1.0000 | 2.9962
# 50 | 1.1243e+03 | 1.0000 | 4.2319
Three very different behaviors:
- No normalization: activations explode to 1,124 — training will collapse
- Post-Norm: activations are locked at exactly 1.0 everywhere (LayerNorm resets them every layer)
- Pre-Norm: activations grow slowly (1.0 → 4.2) as the residual stream accumulates contributions
Post-Norm's constant std=1.0 looks ideal for activations, but it comes at a cost: every gradient must fight through a LayerNorm projection to reach earlier layers. Pre-Norm's slowly growing activations are a feature, not a bug — each sublayer's contribution is preserved in the residual stream, and gradients flow directly back through the clean addition without being projected by normalization.
Xiong et al. (2020) proved this formally using mean field theory: Post-Norm gradient magnitudes depend on the product of LayerNorm Jacobians, which can grow unboundedly near the output at initialization. Pre-Norm gradients decompose into a sum of terms, each bounded by the residual connection.
The "clean residual stream" mental model: in Pre-Norm, the residual stream is an unprocessed accumulator. Each sublayer reads a normalized copy, transforms it, and deposits a small update. The stream itself is never normalized — it just grows. Anthropic's mechanistic interpretability work has popularized this framing as a way to understand what transformer layers actually compute.
The practical consequences are significant:
| Architecture | Normalization | Placement |
|---|---|---|
| Original Transformer (2017) | LayerNorm | Post-Norm |
| GPT-2 / GPT-3 (2019) | LayerNorm | Pre-Norm |
| BERT (2018) | LayerNorm | Post-Norm |
| LLaMA / Mistral / Gemma | RMSNorm | Pre-Norm |
The highlighted row is the modern consensus: Pre-Norm with RMSNorm. It combines the training stability of Pre-Norm with the computational efficiency of RMSNorm. Every major open-weight LLM released since 2023 uses this combination.
The Side-by-Side Showdown
Let's put all three normalization methods head to head on the same input tensor — a realistic transformer shape of (batch=4, seq=16, hidden=128):
import time
np.random.seed(42)
x = np.random.randn(4, 16, 128) # (batch, seq, hidden)
gamma = np.ones(128)
beta = np.zeros(128)
def batch_norm_3d(x, gamma, beta, eps=1e-5):
    B, S, D = x.shape
    x_flat = x.reshape(-1, D)
    mean = x_flat.mean(axis=0, keepdims=True)
    var = x_flat.var(axis=0, keepdims=True)
    x_hat = (x_flat - mean) / np.sqrt(var + eps)
    return (gamma * x_hat + beta).reshape(B, S, D)

def layer_norm_3d(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

def rms_norm_3d(x, gamma, eps=1e-5):
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * gamma
bn_out = batch_norm_3d(x, gamma, beta)
ln_out = layer_norm_3d(x, gamma, beta)
rms_out = rms_norm_3d(x, gamma)
# Timing: 10,000 iterations each
n = 10000
x_bench = np.random.randn(4, 16, 128)
t0 = time.perf_counter()
for _ in range(n): batch_norm_3d(x_bench, gamma, beta)
t_bn = time.perf_counter() - t0
t0 = time.perf_counter()
for _ in range(n): layer_norm_3d(x_bench, gamma, beta)
t_ln = time.perf_counter() - t0
t0 = time.perf_counter()
for _ in range(n): rms_norm_3d(x_bench, gamma)
t_rms = time.perf_counter() - t0
fastest = min(t_bn, t_ln, t_rms)
print(f"Input shape: (4, 16, 128)")
print(f"{'Metric':<20} {'BatchNorm':>12} {'LayerNorm':>12} {'RMSNorm':>12}")
print("-" * 60)
print(f"{'Output mean':<20} {bn_out.mean():>12.4f} {ln_out.mean():>12.4f} {rms_out.mean():>12.4f}")
print(f"{'Output std':<20} {bn_out.std():>12.4f} {ln_out.std():>12.4f} {rms_out.std():>12.4f}")
print(f"{'Parameters':<20} {'256':>12} {'256':>12} {'128':>12}")
print(f"{'Time (10K iters)':<20} {t_bn:>11.1f}s {t_ln:>11.1f}s {t_rms:>11.1f}s")
print(f"{'Relative speed':<20} {t_bn/fastest:>11.1f}x {t_ln/fastest:>11.1f}x {t_rms/fastest:>11.1f}x")
# Input shape: (4, 16, 128)
# Metric BatchNorm LayerNorm RMSNorm
# ------------------------------------------------------------
# Output mean 0.0000 -0.0000 -0.0037
# Output std 1.0000 1.0000 1.0000
# Parameters 256 256 128
# Time (10K iters) ~2.0s ~2.0s ~0.8s
# Relative speed ~2.5x ~2.4x ~1.0x
All three normalize effectively — standard deviation is 1.0 across the board. But RMSNorm is the clear winner on efficiency in this NumPy benchmark: roughly 2.5× faster than both BatchNorm and LayerNorm, with half the parameters. The exact ratio won't carry over to fused GPU kernels, but the direction does: when you're running 160 normalization operations per forward pass in a 65-billion-parameter model, shaving a reduction off each one adds up to real wall-clock savings.
RMSNorm in Practice — The LLaMA Block
Let's see how all of this fits together in a real architecture. Here's a minimal implementation of a LLaMA-style transformer block — the exact pattern used by LLaMA, Mistral, and Gemma:
def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def swiglu(x, W_gate, W_up, W_down):
    """SwiGLU activation: the FFN used in LLaMA."""
    gate = x @ W_gate
    up = x @ W_up
    silu = gate * (1.0 / (1.0 + np.exp(-gate)))  # SiLU = x * sigmoid(x)
    return (silu * up) @ W_down

def multi_head_attention(x, Wq, Wk, Wv, Wo, n_heads):
    """Simplified multi-head attention."""
    B, S, D = x.shape
    head_dim = D // n_heads
    Q = (x @ Wq).reshape(B, S, n_heads, head_dim).transpose(0, 2, 1, 3)
    K = (x @ Wk).reshape(B, S, n_heads, head_dim).transpose(0, 2, 1, 3)
    V = (x @ Wv).reshape(B, S, n_heads, head_dim).transpose(0, 2, 1, 3)
    scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(head_dim)
    attn = softmax(scores)
    out = (attn @ V).transpose(0, 2, 1, 3).reshape(B, S, D)
    return out @ Wo

class LLaMABlock:
    """A single LLaMA-style transformer block: Pre-Norm RMSNorm."""
    def __init__(self, d_model, n_heads, d_ff):
        scale = 0.02
        self.gamma_attn = np.ones(d_model)
        self.gamma_ffn = np.ones(d_model)
        self.Wq = np.random.randn(d_model, d_model) * scale
        self.Wk = np.random.randn(d_model, d_model) * scale
        self.Wv = np.random.randn(d_model, d_model) * scale
        self.Wo = np.random.randn(d_model, d_model) * scale
        self.W_gate = np.random.randn(d_model, d_ff) * scale
        self.W_up = np.random.randn(d_model, d_ff) * scale
        self.W_down = np.random.randn(d_ff, d_model) * scale
        self.n_heads = n_heads

    def forward(self, x):
        # Pre-Norm attention: normalize THEN attend, add residual
        h = rms_norm(x, self.gamma_attn)
        h = multi_head_attention(h, self.Wq, self.Wk, self.Wv,
                                 self.Wo, self.n_heads)
        x = x + h  # Clean residual connection
        # Pre-Norm FFN: normalize THEN transform, add residual
        h = rms_norm(x, self.gamma_ffn)
        h = swiglu(h, self.W_gate, self.W_up, self.W_down)
        x = x + h  # Clean residual connection
        return x
# Build a 4-block model and process a sequence
np.random.seed(42)
d_model, n_heads, d_ff = 64, 4, 172 # ~2.7x expansion (LLaMA ratio)
blocks = [LLaMABlock(d_model, n_heads, d_ff) for _ in range(4)]
x = np.random.randn(1, 8, d_model) # batch=1, seq_len=8, dim=64
print(f"Input: shape = {x.shape}, std = {x.std():.4f}")
for i, block in enumerate(blocks):
    x = block.forward(x)
    print(f"Block {i+1}: shape = {x.shape}, std = {x.std():.4f}, "
          f"mean = {x.mean():.4f}, max|x| = {np.abs(x).max():.4f}")
# Input: shape = (1, 8, 64), std = 0.9957
# Block 1: shape = (1, 8, 64), std = 0.9955, mean = -0.0258, max|x| = 3.5693
# Block 2: shape = (1, 8, 64), std = 0.9959, mean = -0.0268, max|x| = 3.5751
# Block 3: shape = (1, 8, 64), std = 0.9958, mean = -0.0263, max|x| = 3.5661
# Block 4: shape = (1, 8, 64), std = 0.9956, mean = -0.0268, max|x| = 3.5500
Look at those standard deviations: 0.9957 → 0.9955 → 0.9959 → 0.9958 → 0.9956. Across four blocks of attention and feed-forward transformations, the activations barely budge. No explosion, no vanishing — just rock-solid stability. This is the Pre-Norm RMSNorm architecture doing exactly what it was designed to do.
Notice the pattern in each block: RMSNorm normalizes the input before the sublayer reads it, then the sublayer's output is added directly to the unnormalized residual stream. The residual stream is the backbone — it carries the signal forward intact while each sublayer contributes a small, well-conditioned update.
The Full Picture
We started with a ticking time bomb — activations exploding to 293 billion — and ended with a LLaMA block holding steady at std ≈ 0.996 across four layers of attention and feed-forward transformations. The journey from BatchNorm to LayerNorm to RMSNorm mirrors the evolution of deep learning itself: each step stripped away unnecessary complexity while preserving what actually matters.
The modern transformer block is now settled on a simple recipe: Pre-Norm RMSNorm. Normalize before each sublayer, use the residual stream as a clean highway, and skip the mean subtraction that LayerNorm insisted on. It's fewer operations, fewer parameters, and empirically just as good. With this piece in place, our elementary pipeline is complete: tokenize → embed → position → normalize → attend → softmax → loss → optimize → decode → fine-tune → quantize.
What We Didn't Cover
Normalization is a surprisingly deep topic. Here's what we left on the table:
- GroupNorm (Wu & He, 2018) — a compromise between BatchNorm and LayerNorm that divides features into groups and normalizes within each group. Popular in computer vision when batch sizes are tiny.
- InstanceNorm — normalizes each sample and each channel independently. The secret ingredient behind neural style transfer.
- QK-Norm — normalizing the Query and Key vectors inside attention, used in some modern architectures to prevent attention logit growth in very deep models.
- DeepNorm (Wang et al., 2022) — a scaling trick that tames Post-Norm for extremely deep transformers (1,000+ layers) by carefully scaling residual connections.
References & Further Reading
- Ioffe & Szegedy — "Batch Normalization: Accelerating Deep Network Training" (2015) — the paper that started it all, introducing BatchNorm and its dramatic training speedups
- Ba, Kiros & Hinton — "Layer Normalization" (2016) — the batch-independent alternative that became the transformer default
- Zhang & Sennrich — "Root Mean Square Layer Normalization" (2019) — the elegant simplification now used in LLaMA, Mistral, and most modern LLMs
- Xiong et al. — "On Layer Normalization in the Transformer Architecture" (2020) — the mean-field theory proof that Pre-Norm enables stable gradients
- Touvron et al. — "LLaMA: Open and Efficient Foundation Language Models" (2023) — the architecture that codified Pre-Norm RMSNorm as the modern standard
- Elhage et al. — "A Mathematical Framework for Transformer Circuits" (Anthropic, 2021) — the "residual stream" mental model for understanding Pre-Norm transformers
- Previous DadOps elementary posts: Attention from Scratch (where LayerNorm appeared without explanation), Optimizers from Scratch (normalization affects optimizer dynamics), Loss Functions (gradient pathologies that normalization prevents), LoRA from Scratch (normalization layers are typically frozen during fine-tuning)