Flash Attention from Scratch: Why the Fastest Attention Algorithm Never Materializes the Attention Matrix
The Memory Wall
Standard attention has a dirty secret. It's not slow because of math — it's slow because of memory.
In our attention from scratch post, we built dot-product attention: project inputs to Q, K, V, compute scores as QKᵀ, apply softmax, multiply by V. Clean and correct. But there's a problem hiding in plain sight. That QKᵀ matrix has N×N entries, where N is the sequence length. For a 32K-token prompt, that's over 1 billion entries per attention head, per layer. And the model writes every single one to memory, reads them all back for softmax, then reads them again to multiply by V.
The GPU isn't struggling with arithmetic. It's struggling with a traffic jam.
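To put a number on that, here is a quick back-of-the-envelope calculation. The 2-byte element size (fp16/bf16) is an assumption for illustration; the argument above only counts entries.

```python
# Size of one N x N score matrix for a single head in a single layer.
N = 32 * 1024            # 32K-token prompt
bytes_per_entry = 2      # assumption: scores stored as fp16/bf16

entries = N * N
gib = entries * bytes_per_entry / 2**30

print(f"entries: {entries:,}")                      # 1,073,741,824
print(f"size:    {gib:.1f} GiB per head, per layer")  # 2.0 GiB
```

Multiply that by the number of heads and layers and it becomes clear why the intermediate matrix, not the inputs, is the problem.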
Flash Attention, introduced by Tri Dao and collaborators in 2022, eliminates this bottleneck. It computes exact attention: not an approximation, but the same mathematical result as the standard algorithm, equal up to floating-point rounding, without ever materializing the full N×N matrix. The trick is restructuring the computation to respect the GPU's memory hierarchy: keep small blocks in fast on-chip memory (SRAM), never write the giant intermediate matrix to slow off-chip memory (HBM).
In this post, we'll build Flash Attention from scratch in NumPy. We'll understand why GPUs have a memory wall, learn the online softmax trick that makes tiling possible, implement the full algorithm, and verify that it matches naive attention to within floating-point rounding. By the end, you'll understand the single most important systems optimization in modern deep learning: the algorithm that made 100K+ token contexts possible.
Why Attention Is Memory-Bound
To understand Flash Attention, you need to understand the GPU memory hierarchy. A modern GPU like the NVIDIA H100 has two kinds of memory:
- HBM (High Bandwidth Memory) — 80 GB, ~3.35 TB/s bandwidth. This is where your tensors live. It's "high bandwidth" compared to CPU RAM, but it's the slow memory from the GPU's perspective.
- SRAM (on-chip memory) — ~50 MB total across all streaming multiprocessors, but with ~33 TB/s bandwidth. That's roughly 10× faster than HBM, but 1600× smaller.
Every GPU operation works the same way: load data from HBM into SRAM, do computation, write results back to HBM. The computation itself is cheap — modern tensor cores blaze through matrix multiplies. The bottleneck is the data movement.
Now look at what standard attention does:
Step 1: Load Q, K from HBM → Compute S = QKᵀ → Write S (N×N) to HBM
Step 2: Load S from HBM → Compute P = softmax(S) → Write P (N×N) to HBM
Step 3: Load P, V from HBM → Compute O = PV → Write O (N×d) to HBM
Count the damage: Q and K are N×d each. S and P are N×N each. For a typical attention head with N=2048 and d=64, the Q and K matrices have 131K entries each, but S and P have 4.2 million entries each. The N×N intermediates dominate everything.
Let's make this concrete by instrumenting naive attention with a memory access counter:
import numpy as np

def naive_attention_counted(Q, K, V):
    """Standard attention with HBM access counting."""
    N, d = Q.shape
    hbm_reads = 0
    hbm_writes = 0

    # Step 1: Compute S = Q @ K^T
    # Read Q (N*d) and K (N*d) from HBM
    hbm_reads += 2 * N * d
    S = Q @ K.T / np.sqrt(d)
    # Write S (N*N) to HBM
    hbm_writes += N * N

    # Step 2: Softmax
    # Read S (N*N) from HBM
    hbm_reads += N * N
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)
    # Write P (N*N) to HBM
    hbm_writes += N * N

    # Step 3: Output = P @ V
    # Read P (N*N) and V (N*d) from HBM
    hbm_reads += N * N + N * d
    O = P @ V
    # Write O (N*d) to HBM
    hbm_writes += N * d

    total_accesses = hbm_reads + hbm_writes
    return O, {
        'hbm_reads': hbm_reads,
        'hbm_writes': hbm_writes,
        'total': total_accesses,
        'n_squared_terms': 4 * N * N,  # S write + S read + P write + P read
        'n_d_terms': 4 * N * d         # Q, K reads + V read + O write
    }

# Example: N=2048, d=64
np.random.seed(42)
N, d = 2048, 64
Q = np.random.randn(N, d).astype(np.float32)
K = np.random.randn(N, d).astype(np.float32)
V = np.random.randn(N, d).astype(np.float32)

O_naive, stats = naive_attention_counted(Q, K, V)
print(f"Sequence length: {N}, Head dim: {d}")
print(f"HBM reads: {stats['hbm_reads']:>12,}")
print(f"HBM writes: {stats['hbm_writes']:>12,}")
print(f"Total HBM: {stats['total']:>12,}")
print(f"N² terms: {stats['n_squared_terms']:>12,} ({100*stats['n_squared_terms']/stats['total']:.0f}%)")
print(f"Nd terms: {stats['n_d_terms']:>12,} ({100*stats['n_d_terms']/stats['total']:.0f}%)")
Output:
Sequence length: 2048, Head dim: 64
HBM reads: 8,781,824
HBM writes: 8,519,680
Total HBM: 17,301,504
N² terms: 16,777,216 (97%)
Nd terms: 524,288 (3%)
97% of all memory traffic comes from the N×N intermediate matrices. The actual input and output (the Nd terms) are just 3% of the traffic. As N grows, this gets dramatically worse — memory traffic grows quadratically while the useful input/output grows linearly.
The GPU's tensor cores can churn through the actual matrix multiplications at incredible speed. But they spend most of their time waiting for data to shuttle between HBM and SRAM. This is what it means to be memory-bound: the computation is fast, but the plumbing is slow.
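A rough roofline-style estimate makes the imbalance visible. This is only a sketch under stated assumptions: 2-byte elements, the ~3.35 TB/s HBM figure quoted earlier, and an assumed peak of about 989 TFLOP/s for dense BF16 on an H100 SXM. It counts only the two matrix multiplies and ignores softmax FLOPs, kernel launch overhead, and caching.

```python
# Back-of-the-envelope: is naive attention limited by compute or by HBM?
N, d = 2048, 64
bytes_per_el = 2           # assumption: fp16/bf16 tensors
hbm_bw = 3.35e12           # bytes/s, HBM bandwidth figure quoted above
peak_flops = 989e12        # assumption: approximate dense BF16 peak, H100 SXM

flops = 4 * N * N * d                  # QK^T and PV cost 2*N*N*d each
hbm_elements = 4 * N * d + 4 * N * N   # same count as naive_attention_counted
hbm_bytes = hbm_elements * bytes_per_el

t_compute = flops / peak_flops         # time if only arithmetic mattered
t_memory = hbm_bytes / hbm_bw          # time if only HBM traffic mattered
intensity = flops / hbm_bytes          # FLOPs performed per byte moved
balance = peak_flops / hbm_bw          # FLOPs/byte needed to saturate compute

print(f"compute-only time: {t_compute * 1e6:.2f} us")
print(f"memory-only time:  {t_memory * 1e6:.2f} us")
print(f"arithmetic intensity: {intensity:.0f} FLOPs/byte "
      f"(machine balance: {balance:.0f})")
```

Under these assumptions the memory-only time comes out roughly an order of magnitude larger than the compute-only time, and the arithmetic intensity sits far below the machine balance point, which is the signature of a memory-bound kernel.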
The Online Softmax Trick
Flash Attention's strategy is to process the attention matrix in tiles — small blocks that fit entirely in SRAM. But there's an immediate problem: softmax is a global operation.
Standard softmax over a row of scores requires three passes:
- Find the maximum value across all N scores (for numerical stability)
- Subtract the max, exponentiate all N values, and sum them
- Divide each value by the sum
You can't compute the final softmax for a tile of keys without knowing if a later tile contains a larger value. Or can you?
The online softmax algorithm, described by Milakov and Gimelshein in 2018, solves this elegantly. Instead of waiting to see all values, it processes them one block at a time, maintaining a running maximum and a running sum. When a new block arrives with a larger value, it corrects the previous partial results.
Here's the idea. Suppose you've processed block 1 and found a local max m₁ and unnormalized sum l₁ = ∑ exp(x - m₁). Now block 2 arrives with a local max m₂. If m₂ > m₁, all your previous exponentials were computed with the wrong baseline. The fix: multiply every previous term by exp(m₁ - m₂) to shift them to the new baseline, since exp(x - m₁) · exp(m₁ - m₂) = exp(x - m₂). This rescaling is exact, with no approximation.
def softmax_standard(x):
    """Standard 3-pass softmax."""
    m = np.max(x)          # Pass 1: find max
    e = np.exp(x - m)      # Pass 2: subtract max, exponentiate
    return e / np.sum(e)   # Pass 3: normalize

def softmax_online(x, block_size=3):
    """Online softmax: processes values in blocks, maintaining running stats."""
    N = len(x)
    m = -np.inf   # running maximum
    l = 0.0       # running sum of exp(x_i - m)

    # First pass: compute global max and sum in blocks
    for start in range(0, N, block_size):
        block = x[start : start + block_size]
        m_block = np.max(block)
        # If this block has a new max, rescale previous sum
        m_new = max(m, m_block)
        l = l * np.exp(m - m_new) + np.sum(np.exp(block - m_new))
        m = m_new

    # Second pass: compute final softmax values
    return np.exp(x - m) / l

# Demonstrate equivalence
x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
standard = softmax_standard(x)
online = softmax_online(x, block_size=3)
print("Input:", x)
print("Standard softmax:", np.round(standard, 6))
print("Online softmax: ", np.round(online, 6))
print("Max difference: ", np.max(np.abs(standard - online)))
Output:
Input: [1. 2. 3. 6. 2. 1.]
Standard softmax: [0.006126 0.016652 0.045265 0.909178 0.016652 0.006126]
Online softmax: [0.006126 0.016652 0.045265 0.909178 0.016652 0.006126]
Max difference: 0.0
Let's trace through the example step by step. The input [1, 2, 3, 6, 2, 1] is split into two blocks of 3:
Block 1: [1, 2, 3]
- Block max: m_block = 3
- Running max was -∞, now m = 3
- Running sum: l = 0 × exp(-∞ - 3) + exp(1-3) + exp(2-3) + exp(3-3) = 0.135 + 0.368 + 1.0 = 1.503
Block 2: [6, 2, 1]
- Block max: m_block = 6
- New max: m_new = max(3, 6) = 6
- Rescale old sum: l = 1.503 × exp(3 - 6) + exp(6-6) + exp(2-6) + exp(1-6) = 1.503 × 0.0498 + 1.0 + 0.0183 + 0.00674 ≈ 0.0748 + 1.0251 = 1.0999
That exp(3 - 6) = exp(-3) = 0.0498 correction factor is the key. It rescales all previous exponentials to be relative to the new, larger maximum. No information is lost. The final result is identical to computing softmax over all 6 values at once.
The online softmax trick is the mathematical heart of Flash Attention. It's what allows us to process the attention matrix tile by tile without ever assembling the full thing.
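The exactness claim can also be checked algebraically. The following is a short induction sketch in the notation of the code above (m for the running max, l for the running sum); the block index sets B_k are notation introduced here for illustration.

```latex
\begin{aligned}
m^{(k)} &= \max\Bigl(m^{(k-1)},\ \max_{i \in B_k} x_i\Bigr), \\
l^{(k)} &= l^{(k-1)}\, e^{\,m^{(k-1)} - m^{(k)}} + \sum_{i \in B_k} e^{\,x_i - m^{(k)}}. \\[6pt]
\text{If } l^{(k-1)} &= \sum_{i \in B_1 \cup \dots \cup B_{k-1}} e^{\,x_i - m^{(k-1)}}, \text{ then} \\
l^{(k-1)}\, e^{\,m^{(k-1)} - m^{(k)}} &= \sum_{i \in B_1 \cup \dots \cup B_{k-1}} e^{\,x_i - m^{(k)}}, \\
\text{so}\quad l^{(k)} &= \sum_{i \in B_1 \cup \dots \cup B_k} e^{\,x_i - m^{(k)}}.
\end{aligned}
```

After the last block, m is the global maximum and l is exactly the denominator the three-pass softmax would have computed.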
The Flash Attention Algorithm
Now we have all the pieces. Flash Attention divides the Q matrix into blocks of B_r rows and the K, V matrices into blocks of B_c rows. These block sizes are chosen so that a Q block (B_r × d), a K block (B_c × d), a V block (B_c × d), and the score tile (B_r × B_c) all fit in SRAM simultaneously.
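As a rough feel for what "fits in SRAM" means, here is a minimal sketch that adds up the bytes one tile set would occupy. Everything in it is an assumption for illustration: the helper name `tile_sram_bytes`, 2-byte inputs, 4-byte score tile and accumulators, and the 192 KiB budget. Real kernels split this state between shared memory and registers, so treat the numbers as order-of-magnitude only.

```python
def tile_sram_bytes(B_r, B_c, d, in_bytes=2, acc_bytes=4):
    """Approximate on-chip footprint of one (Q block, K/V block) step.

    Hypothetical helper: assumes fp16 inputs and fp32 accumulation.
    """
    q_block = B_r * d * in_bytes          # Q block
    kv_blocks = 2 * B_c * d * in_bytes    # K block + V block
    score_tile = B_r * B_c * acc_bytes    # S / P tile
    out_acc = B_r * d * acc_bytes         # output accumulator
    stats = 2 * B_r * acc_bytes           # running max and running sum
    return q_block + kv_blocks + score_tile + out_acc + stats

budget = 192 * 1024   # hypothetical per-block SRAM budget in bytes

for B in (32, 64, 128, 256):
    used = tile_sram_bytes(B, B, d=64)
    print(f"B_r=B_c={B:>3}: {used / 1024:6.1f} KiB  fits={used <= budget}")
```

The score tile grows with B_r × B_c while everything else grows linearly in the block size, which is why block sizes cannot simply be made as large as you like.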
The algorithm (using the Flash Attention 2 loop order — outer loop over Q blocks, inner loop over K/V blocks):
- For each Q block: load it into SRAM once
- Initialize an output accumulator O = 0, running max m = -∞, running sum l = 0
- For each K/V block: load into SRAM, compute the B_r × B_c tile of attention scores, update running stats and output accumulator
- After all K/V blocks: normalize the output by dividing by l, write the final result to HBM
The critical insight: the B_r × B_c score tile is computed in SRAM, used immediately to update the output accumulator, and then discarded. It never touches HBM. The full N × N attention matrix simply never exists.
def flash_attention(Q, K, V, B_r=32, B_c=32):
    """
    Flash Attention (FA-2 loop order): exact attention without
    materializing the N×N attention matrix.

    Q, K, V: (N, d) arrays
    B_r: block size for Q (number of query rows per tile)
    B_c: block size for K/V (number of key rows per tile)

    Returns: O (N, d), same result as standard attention
    """
    N, d = Q.shape
    O = np.zeros((N, d), dtype=np.float32)           # output accumulator (in HBM)
    l = np.zeros((N, 1), dtype=np.float32)           # running denominator per row
    m = np.full((N, 1), -np.inf, dtype=np.float32)   # running max per row

    # Outer loop: iterate over Q blocks
    for i in range(0, N, B_r):
        i_end = min(i + B_r, N)

        # ── Load Q block into SRAM ──
        Q_block = Q[i:i_end]   # (B_r, d)
        O_block = O[i:i_end]   # (B_r, d) accumulator for this Q block
        l_block = l[i:i_end]   # (B_r, 1)
        m_block = m[i:i_end]   # (B_r, 1)

        # Inner loop: iterate over K/V blocks
        for j in range(0, N, B_c):
            j_end = min(j + B_c, N)

            # ── Load K, V block into SRAM ──
            K_block = K[j:j_end]   # (B_c, d)
            V_block = V[j:j_end]   # (B_c, d)

            # ── Compute attention scores for this tile (stays in SRAM) ──
            S_tile = (Q_block @ K_block.T) / np.sqrt(d)   # (B_r, B_c)

            # ── Online softmax update ──
            # Local max for this tile
            m_tile = S_tile.max(axis=-1, keepdims=True)   # (B_r, 1)
            # New running max
            m_new = np.maximum(m_block, m_tile)           # (B_r, 1)
            # Correction factor: rescale old accumulator to new max
            alpha = np.exp(m_block - m_new)               # (B_r, 1)
            # Exponentiated scores, shifted to the running max
            # Note: exp(S - m_tile) * exp(m_tile - m_new) = exp(S - m_new)
            # so we skip the intermediate step and compute directly:
            P_tile = np.exp(S_tile - m_new)               # (B_r, B_c)

            # Update running sum: rescale old sum + new tile sum
            l_new = alpha * l_block + P_tile.sum(axis=-1, keepdims=True)
            # Update output accumulator: rescale old output + new contribution
            O_block = alpha * O_block + P_tile @ V_block

            # Store updated stats
            m_block = m_new
            l_block = l_new

        # ── Normalize and write final output to HBM ──
        O[i:i_end] = O_block / l_block
        l[i:i_end] = l_block
        m[i:i_end] = m_block

    return O
Let's walk through the key lines:
- S_tile = Q_block @ K_block.T: Summed over all tiles, this still computes all N×N scores, but each tile is just B_r × B_c. It lives in SRAM and is consumed immediately.
- m_new = np.maximum(m_block, m_tile): Update the running max. If this tile has a score larger than anything we've seen before, we need to correct previous work.
- alpha = np.exp(m_block - m_new): The correction factor. If the max didn't change, alpha = exp(0) = 1 (no correction needed). If the new tile has a bigger max, alpha shrinks the old accumulator down.
- O_block = alpha * O_block + P_tile @ V_block: The running output. Old output is rescaled, then the new tile's contribution (attention-weighted values) is added.
- O_block / l_block: Final normalization. After processing all K/V blocks, divide by the total sum to complete the softmax.
Proving It's Exact
This is the claim that surprises people: Flash Attention computes the exact same result as standard attention. No approximation. No quality loss. It's a different algorithm for computing the same function.
Let's prove it empirically:
np.random.seed(42)
N, d = 256, 64
Q = np.random.randn(N, d).astype(np.float32)
K = np.random.randn(N, d).astype(np.float32)
V = np.random.randn(N, d).astype(np.float32)

# Standard attention
def naive_attention(Q, K, V):
    # Take the head dimension from Q rather than relying on a global d
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V

O_naive = naive_attention(Q, K, V)
O_flash = flash_attention(Q, K, V, B_r=32, B_c=32)

max_diff = np.max(np.abs(O_naive - O_flash))
mean_diff = np.mean(np.abs(O_naive - O_flash))
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Mean absolute difference: {mean_diff:.2e}")
print(f"Are they equal (tol=1e-5)? {np.allclose(O_naive, O_flash, atol=1e-5)}")

# Try different block sizes
for B_r, B_c in [(16, 16), (32, 64), (64, 32), (128, 128)]:
    O_test = flash_attention(Q, K, V, B_r=B_r, B_c=B_c)
    diff = np.max(np.abs(O_naive - O_test))
    print(f"B_r={B_r:>3}, B_c={B_c:>3}: max diff = {diff:.2e}")
Output:
Max absolute difference: 2.27e-08
Mean absolute difference: 1.75e-09
Are they equal (tol=1e-5)? True
B_r= 16, B_c= 16: max diff = 2.27e-08
B_r= 32, B_c= 64: max diff = 2.27e-08
B_r= 64, B_c= 32: max diff = 2.27e-08
B_r=128, B_c=128: max diff = 2.27e-08
The maximum difference is on the order of 10⁻⁸. That's floating-point rounding, not algorithmic error. The block sizes don't change the result beyond that rounding because the online softmax correction is mathematically exact. Whether you process 16 keys at a time or 128, the final rescaled output is the same function of the inputs.
This is what makes Flash Attention so powerful: it's a free lunch in terms of accuracy. You get faster execution and lower memory usage without giving up a single digit of precision.
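One way to convince yourself that the residual difference is rounding rather than something algorithmic is to rerun the comparison at higher precision, with a sequence length that the block sizes do not divide evenly. The sketch below restates the tiled algorithm compactly so it runs on its own; the names `naive` and `tiled` and the specific sizes are chosen here for illustration.

```python
import numpy as np

def naive(Q, K, V):
    scale = Q.shape[1] ** -0.5            # Python float keeps the input dtype
    S = (Q @ K.T) * scale
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled(Q, K, V, B_r, B_c):
    """Compact restatement of the tiled algorithm; dtype follows the inputs."""
    N, d = Q.shape
    scale = d ** -0.5
    O = np.empty_like(Q)
    for i in range(0, N, B_r):
        Qb = Q[i:i + B_r]                 # last block may be shorter
        m = np.full((Qb.shape[0], 1), -np.inf, dtype=Q.dtype)
        l = np.zeros((Qb.shape[0], 1), dtype=Q.dtype)
        acc = np.zeros((Qb.shape[0], d), dtype=Q.dtype)
        for j in range(0, N, B_c):
            S = (Qb @ K[j:j + B_c].T) * scale
            m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
            alpha = np.exp(m - m_new)     # rescale old state to the new max
            P = np.exp(S - m_new)
            l = alpha * l + P.sum(axis=-1, keepdims=True)
            acc = alpha * acc + P @ V[j:j + B_c]
            m = m_new
        O[i:i + B_r] = acc / l
    return O

rng = np.random.default_rng(0)
N, d = 100, 16                            # 100 is not a multiple of 32 or 24
diffs = {}
for dtype in (np.float32, np.float64):
    Q, K, V = (rng.standard_normal((N, d)).astype(dtype) for _ in range(3))
    diffs[dtype.__name__] = np.max(np.abs(naive(Q, K, V) - tiled(Q, K, V, 32, 24)))
    print(f"{dtype.__name__}: max diff = {diffs[dtype.__name__]:.2e}")
```

If the gap were algorithmic it would stay put when the precision changes; instead it should shrink along with the machine epsilon of the dtype.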
Counting the Memory Savings
Let's quantify the memory advantage. We'll count simulated HBM accesses for both algorithms:
def count_hbm_accesses(N, d, B_r, B_c):
    """Count HBM accesses for naive vs Flash Attention."""
    # Naive attention:
    # Read Q, K (2*N*d), write S (N*N), read S (N*N),
    # write P (N*N), read P + V (N*N + N*d), write O (N*d)
    naive = 4*N*d + 4*N*N

    # Flash Attention:
    # Outer loop has ceil(N/B_r) iterations
    # Inner loop has ceil(N/B_c) iterations
    # Each inner iteration: read K_block (B_c*d) + V_block (B_c*d)
    # Each outer iteration: read Q_block (B_r*d), write O_block (B_r*d)
    n_outer = (N + B_r - 1) // B_r
    n_inner = (N + B_c - 1) // B_c

    # Q blocks: loaded once per outer iteration
    flash_q = n_outer * B_r * d
    # K,V blocks: loaded once per (outer, inner) pair
    flash_kv = n_outer * n_inner * 2 * B_c * d
    # O blocks: written once per outer iteration
    flash_o = n_outer * B_r * d

    flash = flash_q + flash_kv + flash_o
    return naive, flash

d = 64
B_r, B_c = 128, 128  # realistic SRAM tile size for modern GPUs

print(f"{'N':>6} {'Naive':>14} {'Flash':>14} {'Ratio':>8} {'Savings':>8}")
print("-" * 60)
for N in [256, 512, 1024, 2048, 4096]:
    naive, flash = count_hbm_accesses(N, d, B_r, B_c)
    ratio = naive / flash
    savings = (1 - flash / naive) * 100
    print(f"{N:>6} {naive:>14,} {flash:>14,} {ratio:>7.1f}x {savings:>6.1f}%")
Output:
N Naive Flash Ratio Savings
------------------------------------------------------------
256 327,680 98,304 3.3x 70.0%
512 1,179,648 327,680 3.6x 72.2%
1024 4,456,448 1,179,648 3.8x 73.5%
2048 17,301,504 4,456,448 3.9x 74.2%
4096 68,157,440 17,301,504 3.9x 74.6%
The pattern is clear: Flash Attention consistently moves about 70–75% fewer elements through HBM than naive attention. At N=4096, the naive approach shuffles 68 million elements through HBM; Flash Attention manages the same computation with only 17 million. Naive attention's traffic is dominated by the 4N² term (writing and then reading back both S and P), while Flash Attention never materializes either matrix. Its only repeated traffic is re-reading the K and V blocks once per Q block.
Formally, standard attention performs O(N² + Nd) HBM accesses, while Flash Attention performs O(N²d / B), where B is the block size. Because the largest block that fits in an SRAM of size M scales as B ≈ M/d, this is the same bound as the O(N²d² / M) form. With the block size used above (B=128, d=64), Flash Attention saves a factor of roughly 4, and the savings improve with larger SRAM and better tiling.
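The factor of roughly 4 can be read directly off the counts in count_hbm_accesses. Assuming the block sizes divide N evenly, the totals simplify as follows:

```latex
\begin{aligned}
\text{naive} &= 4Nd + 4N^2, \\
\text{flash} &= \underbrace{Nd}_{\text{read } Q}
              + \underbrace{\frac{N}{B_r}\cdot 2Nd}_{\text{read } K,V \text{ per } Q \text{ block}}
              + \underbrace{Nd}_{\text{write } O}
              = 2Nd + \frac{2N^2 d}{B_r}, \\
\frac{\text{naive}}{\text{flash}} &\xrightarrow{\;N \to \infty\;}
              \frac{4N^2}{2N^2 d / B_r} = \frac{2B_r}{d}.
\end{aligned}
```

With B_r = 128 and d = 64 the limit is 4, which matches the ratios in the table approaching 3.9. Note that B_c cancels out of this particular count: only the Q block size determines how often K and V are re-read.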
| What | Naive Attention | Flash Attention |
|---|---|---|
| Peak memory | O(N²) for attention matrix | O(N) for running stats |
| HBM accesses | O(N² + Nd) | O(N²d² / M), with M the SRAM size |
| FLOPs | O(N²d) | O(N²d) |
| Result | Exact | Exact (identical) |
The last two rows are the punch line. Same FLOPs, same result, just smarter data movement. Flash Attention doesn't do less work; it does the same work with fewer trips to slow memory.
Flash Attention 2 and Beyond
The algorithm we just implemented uses the Flash Attention 2 loop order (outer over Q, inner over K/V). The original Flash Attention paper reversed this — outer over K/V, inner over Q — which required writing partial results back to HBM and correcting them later. FA-2's loop swap was a key improvement: each Q block is loaded once, and its output is fully resolved before moving on. No partial writes, better parallelism.
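To see what the reversed loop order costs, here is a minimal NumPy sketch with the K/V loop outermost. It is a simplification rather than a reproduction of the original kernel: it keeps the output unnormalized and divides once at the end, and the function name and write counter are introduced here for illustration. The point is only that the running state for every Q block has to round-trip through the arrays standing in for HBM on each outer step.

```python
import numpy as np

def flash_attention_kv_outer(Q, K, V, B_r=32, B_c=32):
    """Tiled attention with the K/V loop outermost (FA-1 style ordering)."""
    N, d = Q.shape
    scale = d ** -0.5
    # These three arrays stand in for HBM-resident state.
    O = np.zeros((N, V.shape[1]))
    l = np.zeros((N, 1))
    m = np.full((N, 1), -np.inf)
    o_block_writes = 0

    for j in range(0, N, B_c):                  # outer: K/V blocks
        K_blk, V_blk = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):              # inner: Q blocks
            rows = slice(i, min(i + B_r, N))
            S = (Q[rows] @ K_blk.T) * scale
            m_new = np.maximum(m[rows], S.max(axis=-1, keepdims=True))
            alpha = np.exp(m[rows] - m_new)
            P = np.exp(S - m_new)
            # Partial results are read, corrected, and written back each time.
            O[rows] = alpha * O[rows] + P @ V_blk
            l[rows] = alpha * l[rows] + P.sum(axis=-1, keepdims=True)
            m[rows] = m_new
            o_block_writes += 1

    return O / l, o_block_writes

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
O_kv_outer, writes = flash_attention_kv_outer(Q, K, V)
print("matches naive:", np.allclose(O_kv_outer, naive_attention(Q, K, V)))
print("O-block writes:", writes, "vs", 128 // 32, "with the Q-outer order")
```

Both orderings compute the same function; the difference is how many times each output block has to leave fast memory before it is final.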
FA-2 introduced two more optimizations:
- Deferred normalization: instead of normalizing by l inside the inner loop, accumulate the unnormalized output and divide once at the end. This eliminates repeated divisions and reduces non-matmul FLOPs.
- Split-Q warp partitioning: within a thread block, different warps work on different parts of the Q block, each processing all K/V blocks independently. This eliminates inter-warp communication during the inner loop, keeping occupancy high.
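Deferred normalization is easy to illustrate for a single query row. The sketch below compares two bookkeeping styles on precomputed score blocks: one keeps the output normalized after every block, the other divides once at the end. The function names and the division counter are hypothetical, introduced only to make the difference countable.

```python
import numpy as np

def attend_row_eager(s_blocks, v_blocks):
    """Keep the output normalized after every block (one division per block)."""
    m, l = -np.inf, 0.0
    o = np.zeros(v_blocks[0].shape[1])
    divisions = 0
    for s, v in zip(s_blocks, v_blocks):
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)
        p = np.exp(s - m_new)
        l_new = alpha * l + p.sum()
        # o is normalized: undo the old normalization, rescale, add, renormalize.
        o = (alpha * l * o + p @ v) / l_new
        divisions += 1
        m, l = m_new, l_new
    return o, divisions

def attend_row_deferred(s_blocks, v_blocks):
    """Accumulate the unnormalized output and divide once at the end."""
    m, l = -np.inf, 0.0
    o = np.zeros(v_blocks[0].shape[1])
    for s, v in zip(s_blocks, v_blocks):
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)
        p = np.exp(s - m_new)
        l = alpha * l + p.sum()
        o = alpha * o + p @ v
        m = m_new
    return o / l, 1

rng = np.random.default_rng(0)
scores = rng.standard_normal(12)          # one row of attention scores
values = rng.standard_normal((12, 4))
s_blocks, v_blocks = np.split(scores, 3), np.split(values, 3)

p = np.exp(scores - scores.max())
reference = (p / p.sum()) @ values

o_eager, div_eager = attend_row_eager(s_blocks, v_blocks)
o_deferred, div_deferred = attend_row_deferred(s_blocks, v_blocks)
print(np.allclose(o_eager, reference), np.allclose(o_deferred, reference))
print("divisions:", div_eager, "vs", div_deferred)
```

The NumPy flash_attention function in this post already follows the deferred style: the only division by l_block happens after the inner loop finishes.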
The combined effect: the FlashAttention-2 paper reports roughly 50–73% of the A100's theoretical peak FLOP/s, versus 25–40% for FA-1. That's around a 2× speedup on top of the already-significant bandwidth savings.
Flash Attention 3 and Flash-Decoding
Flash Attention 3 (2024) targets the NVIDIA Hopper architecture with hardware-specific tricks: warp specialization (dedicated producer and consumer warps for pipelining), GEMM-softmax overlap (compute the next tile's matrix multiply while applying softmax to the current tile), and FP8 support with block-level quantization. The result: 1.5–2× faster than FA-2, reaching 75% of H100 peak throughput.
Flash-Decoding solves a different problem. During autoregressive generation, you have a single query token attending to a potentially very long KV sequence. The standard Flash Attention parallelizes across Q blocks — but with one query, there's nothing to parallelize. Flash-Decoding flips the parallelism axis: it splits the KV sequence across thread blocks, each computing a partial output, then reduces the results at the end. This keeps the GPU saturated even for single-query decode.
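The split-then-reduce idea can be sketched in NumPy with the same rescaling used by online softmax. This is an illustration of the reduction step only, not the actual kernel: the function names are made up here, and the chunks are processed in a Python list comprehension where a GPU would run them on separate thread blocks.

```python
import numpy as np

def partial_attention(q, K_chunk, V_chunk):
    """Attention of one query over one KV chunk, left unnormalized.

    Returns the chunk's max score, its sum of exponentials, and the
    exp-weighted sum of values, all relative to the chunk-local max.
    """
    s = K_chunk @ q / np.sqrt(q.shape[0])      # (chunk_len,)
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V_chunk

def split_kv_decode(q, K, V, n_splits=4):
    """Combine per-chunk partials with the online softmax rescaling."""
    parts = [partial_attention(q, Kc, Vc)
             for Kc, Vc in zip(np.array_split(K, n_splits),
                               np.array_split(V, n_splits))]
    m_global = max(m for m, _, _ in parts)
    # Shift every chunk's statistics to the global max before summing.
    l = sum(np.exp(m - m_global) * l_i for m, l_i, _ in parts)
    o = sum(np.exp(m - m_global) * o_i for m, _, o_i in parts)
    return o / l

rng = np.random.default_rng(0)
N, d = 1000, 64
q = rng.standard_normal(d)                     # a single decode-step query
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

s = K @ q / np.sqrt(d)
p = np.exp(s - s.max())
reference = (p / p.sum()) @ V

print(np.allclose(split_kv_decode(q, K, V), reference))
```

Because each chunk only needs to report three small quantities, the final reduction is cheap compared with the per-chunk work.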
These are all GPU-level implementation details, but the core algorithm — tiled computation with online softmax — is the same one we just built in NumPy.
See It in Action
Try It: Memory Traffic Visualizer
Watch how standard attention bounces the full N×N matrix through HBM, while Flash Attention processes small tiles in SRAM. The counter shows total bytes transferred — the gap grows dramatically with sequence length.
Try It: Online Softmax Explorer
Watch the online softmax algorithm process values block by block. The running max (m) and sum (l) update with each block. When a new block has a higher max, the rescaling correction kicks in. The final result matches standard softmax exactly.
IO-Awareness Changes Everything
Flash Attention teaches a lesson that applies far beyond attention mechanisms: algorithmic complexity (FLOPs) isn't the whole story. Two algorithms can have identical FLOP counts, but if one respects the memory hierarchy and the other doesn't, the performance difference is enormous.
Standard attention treats memory as flat. Flash Attention treats it as what it is — a hierarchy with 10× bandwidth differences between levels. The same math, computed with the same precision, just reorganized to keep the fast memory busy and the slow memory quiet.
Today, Flash Attention is everywhere. PyTorch's scaled_dot_product_attention can dispatch to a Flash Attention kernel when the hardware and inputs support it. HuggingFace Transformers enables it with a single flag. vLLM and llama.cpp use Flash Attention variants for serving. Most modern LLMs you interact with, whether through a chatbot, a coding assistant, or an API, run Flash Attention or a close variant under the hood.
Our elementary series pipeline now looks like this: tokenize → embed → position (RoPE) → attend (Flash Attention) → KV cache → softmax → loss → optimize → decode
The attention matrix is the single largest intermediate tensor in a transformer. Flash Attention makes it vanish. That's not a minor optimization — it's the reason 100K+ token context windows are possible without quadratic memory. The next time a model reads an entire codebase or a full novel in one pass, remember: it's tiling all the way down.
References & Further Reading
- Tri Dao et al. — "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022) — The original paper that introduced IO-aware tiled attention
- Tri Dao — "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023) — Swapped loop order, deferred normalization, 2× speedup over FA-1
- Tri Dao et al. — "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024) — Hopper-specific optimizations, FP8, warp specialization
- Milakov & Gimelshein — "Online normalizer calculation for softmax" (2018) — The online softmax algorithm that enables tiled attention
- Stanford CRFM — "Flash-Decoding for long-context inference" (2023) — Parallelizing attention over the KV sequence for single-query decode
- DadOps — Attention from Scratch — Where we built the attention mechanism that Flash Attention accelerates
- DadOps — KV Cache from Scratch — Flash Attention during prefill, KV cache during decode — they solve different bottlenecks
- DadOps — RoPE from Scratch — Rotary embeddings are applied before the tiling loop begins
- DadOps — Transformer from Scratch — Flash Attention drops into the attention head as an implementation detail