
Flash Attention from Scratch: Why the Fastest Attention Algorithm Never Materializes the Attention Matrix

The Memory Wall

Standard attention has a dirty secret. It's not slow because of math — it's slow because of memory.

In our attention from scratch post, we built dot-product attention: project inputs to Q, K, V, compute scores as QKᵀ, apply softmax, multiply by V. Clean and correct. But there's a problem hiding in plain sight. That QKᵀ matrix has N×N entries, where N is the sequence length. For a 32K-token prompt, that's over 1 billion entries per attention head, per layer. And the model writes every one of those scores to memory, reads them back for softmax, writes the resulting probabilities out again, and reads those back to multiply by V.
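As a quick check of that figure (assuming each score is stored in 2 bytes, as with fp16 or bf16; the precision is our assumption, not something the algorithm fixes):

```python
# Back-of-the-envelope size of one head's N x N score matrix at 32K tokens.
N = 32 * 1024                     # 32,768 tokens
entries = N * N                   # one score per (query, key) pair
bytes_fp16 = entries * 2          # assumption: 2 bytes per score (fp16/bf16)

print(f"entries:   {entries:,}")
print(f"fp16 size: {bytes_fp16 / 2**30:.1f} GiB per head, per layer")
```

This prints 1,073,741,824 entries and 2.0 GiB, and that is before counting the softmax output P, which is the same size again.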

The GPU isn't struggling with arithmetic. It's struggling with a traffic jam.

Flash Attention, introduced by Tri Dao and colleagues in 2022, removes this bottleneck. It computes exact attention (not an approximation: mathematically the same result, differing only by floating-point rounding) without ever materializing the full N×N matrix. The trick is restructuring the computation to respect the GPU's memory hierarchy: keep small blocks in fast on-chip memory (SRAM), and never write the giant intermediate matrix to slow off-chip memory (HBM).

In this post, we'll build Flash Attention from scratch in NumPy. We'll see why GPUs hit a memory wall, learn the online softmax trick that makes tiling possible, implement the full algorithm, and verify that it matches naive attention up to floating-point rounding. By the end, you'll understand one of the most important systems optimizations in modern deep learning, and one of the key reasons 100K+ token contexts became practical.

Why Attention Is Memory-Bound

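To understand Flash Attention, you need to understand the GPU memory hierarchy. A modern GPU like the NVIDIA H100 has two kinds of memory that matter here:

  1. HBM (high-bandwidth memory): large (tens of gigabytes) and off-chip. This is where tensors like Q, K, V and any intermediates live, but it is comparatively slow to access.
  2. SRAM: on-chip memory inside each streaming multiprocessor (SM). It is tiny (on the order of a few hundred kilobytes per SM) but roughly an order of magnitude faster than HBM.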

Every GPU operation works the same way: load data from HBM into SRAM, do computation, write results back to HBM. The computation itself is cheap — modern tensor cores blaze through matrix multiplies. The bottleneck is the data movement.

Now look at what standard attention does:

Step 1: Load Q, K from HBM → Compute S = QKᵀ → Write S (N×N) to HBM
Step 2: Load S from HBM → Compute P = softmax(S) → Write P (N×N) to HBM
Step 3: Load P, V from HBM → Compute O = PV → Write O (N×d) to HBM

Count the damage: Q and K are N×d each. S and P are N×N each. For a typical attention head with N=2048 and d=64, the Q and K matrices have 131K entries each, but S and P have 4.2 million entries each. The N×N intermediates dominate everything.

Let's make this concrete by instrumenting naive attention with a memory access counter:

import numpy as np

def naive_attention_counted(Q, K, V):
    """Standard attention with HBM access counting."""
    N, d = Q.shape
    hbm_reads = 0
    hbm_writes = 0

    # Step 1: Compute S = Q @ K^T
    # Read Q (N*d) and K (N*d) from HBM
    hbm_reads += 2 * N * d
    S = Q @ K.T / np.sqrt(d)
    # Write S (N*N) to HBM
    hbm_writes += N * N

    # Step 2: Softmax
    # Read S (N*N) from HBM
    hbm_reads += N * N
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)
    # Write P (N*N) to HBM
    hbm_writes += N * N

    # Step 3: Output = P @ V
    # Read P (N*N) and V (N*d) from HBM
    hbm_reads += N * N + N * d
    O = P @ V
    # Write O (N*d) to HBM
    hbm_writes += N * d

    total_accesses = hbm_reads + hbm_writes
    return O, {
        'hbm_reads': hbm_reads,
        'hbm_writes': hbm_writes,
        'total': total_accesses,
        'n_squared_terms': 4 * N * N,   # S write + S read + P write + P read
        'n_d_terms': 4 * N * d          # Q, K reads + V read + O write
    }

# Example: N=2048, d=64
np.random.seed(42)
N, d = 2048, 64
Q = np.random.randn(N, d).astype(np.float32)
K = np.random.randn(N, d).astype(np.float32)
V = np.random.randn(N, d).astype(np.float32)

O_naive, stats = naive_attention_counted(Q, K, V)
print(f"Sequence length: {N}, Head dim: {d}")
print(f"HBM reads:  {stats['hbm_reads']:>12,}")
print(f"HBM writes: {stats['hbm_writes']:>12,}")
print(f"Total HBM:  {stats['total']:>12,}")
print(f"N² terms:   {stats['n_squared_terms']:>12,}  ({100*stats['n_squared_terms']/stats['total']:.0f}%)")
print(f"Nd terms:   {stats['n_d_terms']:>12,}  ({100*stats['n_d_terms']/stats['total']:.0f}%)")

Output:

Sequence length: 2048, Head dim: 64
HBM reads:     8,781,824
HBM writes:    8,519,680
Total HBM:    17,301,504
N² terms:     16,777,216  (97%)
Nd terms:        524,288  (3%)

97% of all memory traffic comes from the N×N intermediate matrices. The actual input and output (the Nd terms) are just 3% of the traffic. As N grows, this gets dramatically worse — memory traffic grows quadratically while the useful input/output grows linearly.
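Dividing the counter's two terms gives a closed form for that share: 4N² / (4N² + 4Nd) = N / (N + d). A short sketch using the same counting model (d = 64 as above) shows how quickly it approaches 100%:

```python
def n_squared_share(N, d):
    """Fraction of naive-attention HBM traffic spent on the N x N intermediates.

    Uses the same counts as naive_attention_counted: 4*N*N for writing and
    reading S and P, plus 4*N*d for reading Q, K, V and writing O.
    """
    return (4 * N * N) / (4 * N * N + 4 * N * d)   # simplifies to N / (N + d)

d = 64
for N in [256, 1024, 2048, 8192, 32768]:
    print(f"N={N:>6}: {100 * n_squared_share(N, d):.1f}% of traffic is N x N")
```

At N = 2048 it reproduces the 97% from the counter; at N = 32768 the share is about 99.8%.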

The GPU's tensor cores can churn through the actual matrix multiplications at incredible speed. But they spend most of their time waiting for data to shuttle between HBM and SRAM. This is what it means to be memory-bound: the computation is fast, but the plumbing is slow.
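One common way to make "memory-bound" precise is arithmetic intensity: FLOPs performed per element moved to or from HBM. The sketch below reuses the traffic counts from the counter above; the figure of 5 FLOPs per softmax element (max, subtract, exp, sum, divide, with exp counted as one FLOP) is a rough assumption for illustration.

```python
def intensities(N, d, softmax_flops_per_elem=5):
    """FLOPs per HBM element moved, for each step of naive attention."""
    # Step 1: S = Q @ K^T costs 2*N*N*d FLOPs; reads Q and K, writes S.
    qk = (2 * N * N * d) / (2 * N * d + N * N)
    # Step 2: softmax does a handful of ops per score (assumed count);
    # it reads S and writes P, both N x N.
    sm = (softmax_flops_per_elem * N * N) / (2 * N * N)
    # Step 3: O = P @ V costs 2*N*N*d FLOPs; reads P and V, writes O.
    pv = (2 * N * N * d) / (N * N + 2 * N * d)
    return qk, sm, pv

qk, sm, pv = intensities(2048, 64)
print(f"QK^T:    {qk:6.1f} FLOPs/element")
print(f"softmax: {sm:6.1f} FLOPs/element")
print(f"PV:      {pv:6.1f} FLOPs/element")
```

With N = 2048 and d = 64, both matrix multiplies do about 120 FLOPs per element moved, while softmax does 2.5, roughly 48 times less. Softmax has almost nothing to compute while it waits on HBM, and its N×N inputs and outputs are exactly the traffic Flash Attention eliminates.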

The Online Softmax Trick

Flash Attention's strategy is to process the attention matrix in tiles — small blocks that fit entirely in SRAM. But there's an immediate problem: softmax is a global operation.

Standard softmax over a row of scores requires three passes:

  1. Find the maximum value across all N scores (for numerical stability)
  2. Subtract the max, exponentiate all N values, and sum them
  3. Divide each value by the sum

You can't compute the final softmax for a tile of keys without knowing if a later tile contains a larger value. Or can you?

The online softmax algorithm, described by Milakov and Gimelshein in 2018, solves this elegantly. Instead of waiting to see all values, it processes them one block at a time, maintaining a running maximum and a running sum. When a new block arrives with a larger value, it corrects the previous partial results.

Here's the idea. Suppose you've processed block 1 and found a local max m1 and unnormalized sum l1 = ∑ exp(x - m1). Now block 2 arrives with a local max m2. If m2 > m1, all your previous exponentials were computed with the wrong baseline. The fix: multiply every previous term by exp(m1 - m2) to shift them to the new baseline. This rescaling is exact — no approximation.
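The rescaling claim is easy to check numerically with the values from the example below: multiplying block 1's partial sum (computed against m1 = 3) by exp(m1 - m2) gives the same number you would get by using m2 = 6 from the start.

```python
import numpy as np

block1 = np.array([1.0, 2.0, 3.0])
m1 = block1.max()                       # local max of block 1 (3.0)
l1 = np.sum(np.exp(block1 - m1))        # partial sum against the old baseline

m2 = 6.0                                # larger max revealed by block 2
rescaled = l1 * np.exp(m1 - m2)         # shift the old sum to the new baseline
direct = np.sum(np.exp(block1 - m2))    # what we'd get using m2 from the start

print(rescaled, direct)                 # equal up to floating-point rounding
```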

def softmax_standard(x):
    """Standard 3-pass softmax."""
    m = np.max(x)                         # Pass 1: find max
    e = np.exp(x - m)                     # Pass 2: subtract max, exponentiate
    return e / np.sum(e)                  # Pass 3: normalize

def softmax_online(x, block_size=3):
    """Online softmax: processes values in blocks, maintaining running stats."""
    N = len(x)
    m = -np.inf   # running maximum
    l = 0.0       # running sum of exp(x_i - m)

    # First pass: compute global max and sum in blocks
    for start in range(0, N, block_size):
        block = x[start : start + block_size]
        m_block = np.max(block)

        # If this block has a new max, rescale previous sum
        m_new = max(m, m_block)
        l = l * np.exp(m - m_new) + np.sum(np.exp(block - m_new))
        m = m_new

    # Second pass: compute final softmax values
    return np.exp(x - m) / l

# Demonstrate equivalence
x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
standard = softmax_standard(x)
online = softmax_online(x, block_size=3)

print("Input:", x)
print("Standard softmax:", np.round(standard, 6))
print("Online softmax:  ", np.round(online, 6))
print("Max difference:  ", np.max(np.abs(standard - online)))

Output:

Input: [1. 2. 3. 6. 2. 1.]
Standard softmax: [0.006126 0.016652 0.045265 0.909178 0.016652 0.006126]
Online softmax:   [0.006126 0.016652 0.045265 0.909178 0.016652 0.006126]
Max difference:   0.0

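Let's trace through the example step by step. The input [1, 2, 3, 6, 2, 1] is split into two blocks of 3:

Block 1: [1, 2, 3]. The running max becomes m = 3, and the running sum is l = exp(1 - 3) + exp(2 - 3) + exp(3 - 3) = 0.1353 + 0.3679 + 1.0000 = 1.5032.

Block 2: [6, 2, 1]. The block max 6 beats the old max, so m becomes 6. The old sum is rescaled: 1.5032 × exp(3 - 6) = 1.5032 × exp(-3) ≈ 0.0748. The new block adds exp(0) + exp(-4) + exp(-5) ≈ 1.0251, giving l ≈ 1.0999. Dividing, the softmax of the 6 is exp(6 - 6) / 1.0999 ≈ 0.9092, matching the 0.909178 above.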

That exp(3 - 6) = exp(-3) = 0.0498 correction factor is the key. It rescales all previous exponentials to be relative to the new, larger maximum. No information is lost. The final result is identical to computing softmax over all 6 values at once.
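The same correction also works when blocks are summarized independently and combined afterwards, not only in left-to-right order. In the sketch below, the helpers `summarize` and `merge` are our own illustrative names: each block is reduced to a pair (m, l), and two pairs merge by rescaling both sums to the larger max. The merge is associative, so any grouping of blocks gives the same result, a property that matters once the work is split across parallel workers.

```python
import numpy as np
from functools import reduce

def summarize(block):
    """Reduce a block of scores to (max, sum of exp(x - max))."""
    m = np.max(block)
    return m, np.sum(np.exp(block - m))

def merge(a, b):
    """Combine two (m, l) summaries, rescaling each sum to the shared max."""
    m_a, l_a = a
    m_b, l_b = b
    m = max(m_a, m_b)
    return m, l_a * np.exp(m_a - m) + l_b * np.exp(m_b - m)

x = np.array([1.0, 2.0, 3.0, 6.0, 2.0, 1.0])
blocks = [x[0:2], x[2:4], x[4:6]]

# Left-to-right merging and a right-nested grouping give the same (m, l).
m_lr, l_lr = reduce(merge, map(summarize, blocks))
m_alt, l_alt = merge(summarize(blocks[0]),
                     merge(summarize(blocks[1]), summarize(blocks[2])))

softmax = np.exp(x - m_lr) / l_lr
print(m_lr, l_lr, m_alt, l_alt)
print(np.round(softmax, 6))
```

Here the six values are split into three blocks of two, a different blocking from the example above, and both merge orders still recover m = 6 and the same l.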

The online softmax trick is the mathematical heart of Flash Attention. It's what allows us to process the attention matrix tile by tile without ever assembling the full thing.

The Flash Attention Algorithm

Now we have all the pieces. Flash Attention divides the Q matrix into blocks of B_r rows and the K, V matrices into blocks of B_c rows. These block sizes are chosen so that a Q block (B_r × d), a K block (B_c × d), a V block (B_c × d), and the score tile (B_r × B_c) all fit in SRAM simultaneously.
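To get a feel for how block sizes interact with SRAM, here is a rough footprint calculator. It assumes every buffer is stored at 2 bytes per element and also counts the B_r × d output accumulator, which has to stay on-chip too; the 96 KiB budget is a made-up round number for illustration, not the specification of any real GPU.

```python
def tile_footprint_bytes(B_r, B_c, d, bytes_per_elem=2):
    """Approximate on-chip bytes for one Flash Attention tile step.

    Buffers: Q block (B_r x d), K and V blocks (B_c x d each),
    score tile (B_r x B_c), and the output accumulator (B_r x d).
    Assumes one element size for everything; real kernels mix precisions.
    """
    elems = B_r * d + 2 * B_c * d + B_r * B_c + B_r * d
    return elems * bytes_per_elem

SRAM_BUDGET = 96 * 1024   # hypothetical budget, for illustration only
d = 64
for B in [32, 64, 128, 256]:
    size = tile_footprint_bytes(B, B, d)
    verdict = "fits" if size <= SRAM_BUDGET else "too big"
    print(f"B_r = B_c = {B:>3}: {size / 1024:6.1f} KiB  ({verdict})")
```

The score tile grows as B_r × B_c while the other buffers grow linearly, so doubling both block sizes from 128 to 256 takes the footprint from 96 KiB to 256 KiB.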

The algorithm (using the Flash Attention 2 loop order — outer loop over Q blocks, inner loop over K/V blocks):

  1. For each Q block: load it into SRAM once
  2. Initialize an output accumulator O = 0, running max m = -∞, running sum l = 0
  3. For each K/V block: load into SRAM, compute the B_r × B_c tile of attention scores, update running stats and output accumulator
  4. After all K/V blocks: normalize the output by dividing by l, write the final result to HBM

The critical insight: the B_r × B_c score tile is computed in SRAM, used immediately to update the output accumulator, and then discarded. It never touches HBM. The full N × N attention matrix simply never exists.

def flash_attention(Q, K, V, B_r=32, B_c=32):
    """
    Flash Attention (FA-2 loop order): exact attention without
    materializing the N×N attention matrix.

    Q, K, V: (N, d) arrays
    B_r: block size for Q (number of query rows per tile)
    B_c: block size for K/V (number of key rows per tile)
    Returns: O (N, d) — same result as standard attention
    """
    N, d = Q.shape
    O = np.zeros((N, d), dtype=np.float32)     # output accumulator (in HBM)
    l = np.zeros((N, 1), dtype=np.float32)     # running denominator per row
    m = np.full((N, 1), -np.inf, dtype=np.float32)  # running max per row

    # Outer loop: iterate over Q blocks
    for i in range(0, N, B_r):
        i_end = min(i + B_r, N)

        # ── Load Q block into SRAM ──
        Q_block = Q[i:i_end]           # (B_r, d)
        O_block = O[i:i_end]           # (B_r, d) accumulator for this Q block
        l_block = l[i:i_end]           # (B_r, 1)
        m_block = m[i:i_end]           # (B_r, 1)

        # Inner loop: iterate over K/V blocks
        for j in range(0, N, B_c):
            j_end = min(j + B_c, N)

            # ── Load K, V block into SRAM ──
            K_block = K[j:j_end]       # (B_c, d)
            V_block = V[j:j_end]       # (B_c, d)

            # ── Compute attention scores for this tile (stays in SRAM) ──
            S_tile = (Q_block @ K_block.T) / np.sqrt(d)   # (B_r, B_c)

            # ── Online softmax update ──
            # Local max for this tile
            m_tile = S_tile.max(axis=-1, keepdims=True)    # (B_r, 1)

            # New running max
            m_new = np.maximum(m_block, m_tile)            # (B_r, 1)

            # Correction factor: rescale old accumulator to new max
            alpha = np.exp(m_block - m_new)            # (B_r, 1)

            # Exponentiated scores, shifted to the running max
            # Note: exp(S - m_tile) * exp(m_tile - m_new) = exp(S - m_new)
            # so we skip the intermediate step and compute directly:
            P_tile = np.exp(S_tile - m_new)            # (B_r, B_c)

            # Update running sum: rescale old sum + new tile sum
            l_new = alpha * l_block + P_tile.sum(axis=-1, keepdims=True)

            # Update output accumulator: rescale old output + new contribution
            O_block = alpha * O_block + P_tile @ V_block

            # Store updated stats
            m_block = m_new
            l_block = l_new

        # ── Normalize and write final output to HBM ──
        O[i:i_end] = O_block / l_block
        l[i:i_end] = l_block
        m[i:i_end] = m_block

    return O

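Let's walk through the key lines:

  1. m_new = np.maximum(m_block, m_tile) is the running max for each query row after seeing this tile.
  2. alpha = np.exp(m_block - m_new) is the correction factor. It equals 1 when the max didn't change and shrinks everything accumulated so far when a larger score appears. On the first tile, m_block is -∞, so alpha is 0 and the empty old state contributes nothing.
  3. P_tile = np.exp(S_tile - m_new) holds the unnormalized probabilities for this tile, already expressed relative to the new max.
  4. l_new = alpha * l_block + P_tile.sum(...) rescales the old denominator and adds this tile's contribution: exactly the online softmax update from the previous section.
  5. O_block = alpha * O_block + P_tile @ V_block applies the same rescaling to the unnormalized output, then adds this tile's weighted values.
  6. O_block / l_block runs once per Q block, after the last K/V tile, so normalization is deferred until the denominator is final.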
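A useful invariant for convincing yourself the accumulator logic is right: after the first j K/V blocks, O_block / l_block equals softmax attention over just those keys. The self-contained sketch below (one Q block, small arbitrary sizes, float64) repeats the inner-loop update and checks that invariant after every step.

```python
import numpy as np

rng = np.random.default_rng(0)
B_r, N, d, B_c = 4, 12, 8, 4
Q = rng.standard_normal((B_r, d))     # a single Q block
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

def attention(Q, K, V):
    """Reference softmax attention over all rows of K, V."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

O_acc = np.zeros((B_r, d))
l_acc = np.zeros((B_r, 1))
m_acc = np.full((B_r, 1), -np.inf)
max_diffs = []

for j in range(0, N, B_c):
    S_tile = Q @ K[j:j + B_c].T / np.sqrt(d)
    m_new = np.maximum(m_acc, S_tile.max(axis=-1, keepdims=True))
    alpha = np.exp(m_acc - m_new)               # rescale factor for old state
    P_tile = np.exp(S_tile - m_new)
    l_acc = alpha * l_acc + P_tile.sum(axis=-1, keepdims=True)
    O_acc = alpha * O_acc + P_tile @ V[j:j + B_c]
    m_acc = m_new

    # Invariant: normalized accumulator == attention over the keys seen so far.
    prefix = attention(Q, K[:j + B_c], V[:j + B_c])
    max_diffs.append(np.max(np.abs(O_acc / l_acc - prefix)))
    print(f"after keys [0, {j + B_c}): max diff = {max_diffs[-1]:.1e}")
```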

Proving It's Exact

This is the claim that surprises people: Flash Attention computes the exact same result as standard attention. No approximation. No quality loss. It's a different algorithm for computing the same function.

Let's prove it empirically:

np.random.seed(42)
N, d = 256, 64

Q = np.random.randn(N, d).astype(np.float32)
K = np.random.randn(N, d).astype(np.float32)
V = np.random.randn(N, d).astype(np.float32)

# Standard attention
def naive_attention(Q, K, V):
    d = Q.shape[1]   # head dimension, taken from Q instead of the global d
    S = Q @ K.T / np.sqrt(d)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P = P / P.sum(axis=-1, keepdims=True)
    return P @ V

O_naive = naive_attention(Q, K, V)
O_flash = flash_attention(Q, K, V, B_r=32, B_c=32)

max_diff = np.max(np.abs(O_naive - O_flash))
mean_diff = np.mean(np.abs(O_naive - O_flash))

print(f"Max absolute difference:  {max_diff:.2e}")
print(f"Mean absolute difference: {mean_diff:.2e}")
print(f"Are they equal (tol=1e-5)? {np.allclose(O_naive, O_flash, atol=1e-5)}")

# Try different block sizes
for B_r, B_c in [(16, 16), (32, 64), (64, 32), (128, 128)]:
    O_test = flash_attention(Q, K, V, B_r=B_r, B_c=B_c)
    diff = np.max(np.abs(O_naive - O_test))
    print(f"B_r={B_r:>3}, B_c={B_c:>3}: max diff = {diff:.2e}")

Output:

Max absolute difference:  2.27e-08
Mean absolute difference: 1.75e-09
Are they equal (tol=1e-5)? True
B_r= 16, B_c= 16: max diff = 2.27e-08
B_r= 32, B_c= 64: max diff = 2.27e-08
B_r= 64, B_c= 32: max diff = 2.27e-08
B_r=128, B_c=128: max diff = 2.27e-08

The maximum difference is on the order of 10⁻⁸; that's floating-point rounding, not algorithmic error. The block sizes don't change the result in any meaningful way because the online softmax correction is mathematically exact. Whether you process 16 keys at a time or 128, the final rescaled output is identical up to rounding.
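If the remaining gap is really rounding, it should shrink dramatically in higher precision. The sketch below re-implements the same tiled loop compactly (`tiled_attention` and `naive` are our own helper names, and we use NumPy's Generator API, so the numbers won't match the output above) and compares float32 against float64.

```python
import numpy as np

def naive(Q, K, V):
    """Reference attention, computed in Q's dtype."""
    S = (Q @ K.T) * (Q.shape[1] ** -0.5)      # Python float scale keeps the dtype
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, B_r=32, B_c=32):
    """Same FA-2-style loop as flash_attention, in whatever dtype Q has."""
    N, d = Q.shape
    scale = d ** -0.5
    O = np.empty_like(Q)
    for i in range(0, N, B_r):
        q = Q[i:i + B_r]
        o = np.zeros_like(q)
        l = np.zeros((len(q), 1), dtype=Q.dtype)
        m = np.full((len(q), 1), -np.inf, dtype=Q.dtype)
        for j in range(0, N, B_c):
            s = (q @ K[j:j + B_c].T) * scale
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            alpha = np.exp(m - m_new)
            p = np.exp(s - m_new)
            l = alpha * l + p.sum(axis=-1, keepdims=True)
            o = alpha * o + p @ V[j:j + B_c]
            m = m_new
        O[i:i + B_r] = o / l
    return O

rng = np.random.default_rng(42)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
diffs = {}
for dtype in (np.float32, np.float64):
    name = np.dtype(dtype).name
    q, k, v = (a.astype(dtype) for a in (Q, K, V))
    diffs[name] = np.max(np.abs(naive(q, k, v) - tiled_attention(q, k, v)))
    print(f"{name}: max diff = {diffs[name]:.1e}")
```

Expect the float64 gap to land many orders of magnitude below the float32 one. The scale factor is kept as a Python float so float32 inputs are not silently promoted to float64.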

This is what makes Flash Attention so powerful: it's a free lunch in terms of accuracy. You get faster execution and lower memory usage with no approximation error, only the ordinary floating-point rounding that any reordering of the same arithmetic produces.

Counting the Memory Savings

Let's quantify the memory advantage. We'll count simulated HBM accesses for both algorithms:

def count_hbm_accesses(N, d, B_r, B_c):
    """Count HBM accesses for naive vs Flash Attention."""
    # Naive attention:
    # Read Q, K (2*N*d), write S (N*N), read S (N*N),
    # write P (N*N), read P + V (N*N + N*d), write O (N*d)
    naive = 4*N*d + 4*N*N

    # Flash Attention:
    # Outer loop has ceil(N/B_r) iterations
    # Inner loop has ceil(N/B_c) iterations
    # Each inner iteration: read K_block (B_c*d) + V_block (B_c*d)
    # Each outer iteration: read Q_block (B_r*d), write O_block (B_r*d)
    n_outer = (N + B_r - 1) // B_r
    n_inner = (N + B_c - 1) // B_c

    # Q blocks: loaded once per outer iteration
    flash_q = n_outer * B_r * d
    # K,V blocks: loaded once per (outer, inner) pair
    flash_kv = n_outer * n_inner * 2 * B_c * d
    # O blocks: written once per outer iteration
    flash_o = n_outer * B_r * d
    flash = flash_q + flash_kv + flash_o

    return naive, flash

d = 64
B_r, B_c = 128, 128  # realistic SRAM tile size for modern GPUs
print(f"{'N':>6}  {'Naive':>14}  {'Flash':>14}  {'Ratio':>8}  {'Savings':>8}")
print("-" * 60)
for N in [256, 512, 1024, 2048, 4096]:
    naive, flash = count_hbm_accesses(N, d, B_r, B_c)
    ratio = naive / flash
    savings = (1 - flash / naive) * 100
    print(f"{N:>6}  {naive:>14,}  {flash:>14,}  {ratio:>7.1f}x  {savings:>6.1f}%")

Output:

     N           Naive           Flash     Ratio  Savings
------------------------------------------------------------
   256         327,680          98,304      3.3x    70.0%
   512       1,179,648         327,680      3.6x    72.2%
  1024       4,456,448       1,179,648      3.8x    73.5%
  2048      17,301,504       4,456,448      3.9x    74.2%
  4096      68,157,440      17,301,504      3.9x    74.6%

The pattern is clear: in this counting model, Flash Attention moves about 70–75% fewer elements through HBM than naive attention. At N=4096, the naive approach shuffles 68 million elements through HBM; Flash Attention manages the same computation with only 17 million. Naive attention's traffic is dominated by the 4N² term: writing and reading both S and P, four full N×N transfers. Flash Attention never materializes that matrix; its dominant cost is re-reading each K and V block once per Q block, 2N²d/B_c elements in total. That term is still quadratic in N, but with d = 64 and B_c = 128 it works out to N², a quarter of the naive 4N².

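Formally, standard attention performs O(N² + Nd) HBM accesses, while Flash Attention performs O(N²d / B_c), where B_c is the K/V block size. Since the largest block that fits in an SRAM of size M scales as M/d (up to a constant), this is the O(N²d²/M) bound in the table below. With the block sizes used here (B_c = 128, d = 64), the savings approach a factor of 4, and they grow with larger SRAM and better tiling.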
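When N is a multiple of both block sizes, the counter collapses to closed forms: naive traffic is 4N² + 4Nd, and Flash traffic is 2Nd (Q reads plus O writes) plus 2N²d/B_c (each K and V block re-read once per Q block). As N grows, the ratio tends to 4N² / (2N²d/B_c) = 2B_c/d, which is 4 for B_c = 128 and d = 64. A small sketch of those formulas:

```python
def traffic(N, d, B_c):
    """Closed-form HBM element counts (assumes N divisible by the block sizes)."""
    naive = 4 * N * N + 4 * N * d
    flash = 2 * N * d + 2 * N * N * d // B_c
    return naive, flash

def asymptotic_ratio(d, B_c):
    """Limit of naive / flash as N grows: 4N^2 / (2N^2 d / B_c)."""
    return 2 * B_c / d

d = 64
for B_c in [64, 128, 256, 512]:
    naive, flash = traffic(65536, d, B_c)
    print(f"B_c={B_c:>3}: ratio at N=65536 = {naive / flash:.2f}, "
          f"limit = {asymptotic_ratio(d, B_c):.1f}")
```

Doubling B_c, which requires more SRAM, doubles the limiting ratio: the O(N²d²/M) bound seen from the other side.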

What            Naive Attention               Flash Attention
Peak memory     O(N²) for attention matrix    O(N) for running stats
HBM accesses    O(N² + Nd)                    O(N²d²/M), M = SRAM size
FLOPs           O(N²d)                        O(N²d)
Result          Exact                         Exact (differs only by rounding)

The bottom two rows are the punch line. Same FLOPs, same result, just smarter data movement. Flash Attention doesn't do less work; it does essentially the same work with far fewer trips to slow memory.

Flash Attention 2 and Beyond

The algorithm we just implemented uses the Flash Attention 2 loop order (outer over Q, inner over K/V). The original Flash Attention paper reversed this — outer over K/V, inner over Q — which required writing partial results back to HBM and correcting them later. FA-2's loop swap was a key improvement: each Q block is loaded once, and its output is fully resolved before moving on. No partial writes, better parallelism.
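For comparison, here is a rough NumPy sketch of the FA-1 loop order. It is our own simplified rendering, not the paper's pseudocode: the outer loop walks K/V blocks and the inner loop walks Q blocks, so a row's statistics are not final until the last outer iteration, and the running O, l, m for each Q block must be read and written back (to HBM, on a GPU) on every inner step. For simplicity we keep O unnormalized and divide once at the end; the original algorithm renormalizes O after every step, which reaches the same final answer.

```python
import numpy as np

def flash_attention_fa1_order(Q, K, V, B_r=32, B_c=32):
    """FA-1-style loop order: outer over K/V blocks, inner over Q blocks.

    O, l, m stand in for HBM-resident state that every inner
    iteration must read, update, and write back.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    l = np.zeros((N, 1))
    m = np.full((N, 1), -np.inf)
    for j in range(0, N, B_c):                  # outer: K/V blocks
        K_j, V_j = K[j:j + B_c], V[j:j + B_c]
        for i in range(0, N, B_r):              # inner: Q blocks
            S = Q[i:i + B_r] @ K_j.T / np.sqrt(d)
            m_new = np.maximum(m[i:i + B_r], S.max(axis=-1, keepdims=True))
            alpha = np.exp(m[i:i + B_r] - m_new)
            P = np.exp(S - m_new)
            # Read-modify-write of this Q block's partial results.
            l[i:i + B_r] = alpha * l[i:i + B_r] + P.sum(axis=-1, keepdims=True)
            O[i:i + B_r] = alpha * O[i:i + B_r] + P @ V_j
            m[i:i + B_r] = m_new
    return O / l

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
S = Q @ K.T / np.sqrt(16)
P = np.exp(S - S.max(axis=-1, keepdims=True))
reference = (P / P.sum(axis=-1, keepdims=True)) @ V
print(np.max(np.abs(flash_attention_fa1_order(Q, K, V) - reference)))
```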

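FA-2 introduced two more optimizations:

  1. Fewer non-matmul FLOPs: it keeps the output unnormalized and divides by l only once at the end (as our implementation does), because GPUs execute matrix multiplies far faster than elementwise work like rescaling and exponentials.
  2. Better work partitioning between warps: within a thread block, warps split the Q block rather than K/V, which reduces synchronization and shared-memory traffic between warps.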

The combined effect: on the A100, FA-2 reaches up to about 73% of theoretical peak FLOP/s, versus 25–40% for FA-1. That's roughly a 2× real-world speedup on top of the already-significant bandwidth savings.

Flash Attention 3 and Flash-Decoding

Flash Attention 3 (2024) targets the NVIDIA Hopper architecture with hardware-specific tricks: warp specialization (dedicated producer and consumer warps for pipelining), GEMM-softmax overlap (compute the next tile's matrix multiply while applying softmax to the current tile), and FP8 support with block-level quantization. The result: 1.5–2× faster than FA-2, reaching up to 75% of H100 peak throughput in FP16.

Flash-Decoding solves a different problem. During autoregressive generation, you have a single query token attending to a potentially very long KV sequence. Standard Flash Attention parallelizes across batch, heads, and Q blocks, but with one query there is only a single Q block per head, so at small batch sizes there are too few thread blocks to keep the GPU busy. Flash-Decoding adds a new parallelism axis: it splits the KV sequence across thread blocks, each computing a partial output along with its softmax statistics, then rescales and reduces the partial results at the end. This keeps the GPU saturated even for single-query decode.
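The split-and-reduce idea fits in a few lines of NumPy. This is a simplified single-query sketch under our own assumptions: `partial_attention`, `split_kv_decode`, and the chunk count are hypothetical names and choices, and the chunks run sequentially here where Flash-Decoding would run them as parallel thread blocks. Each chunk returns its unnormalized partial output with its (m, l) statistics, and the reduction rescales every chunk to the global max, just like the online softmax.

```python
import numpy as np

def partial_attention(q, K_chunk, V_chunk):
    """One worker: scores for one KV chunk, kept unnormalized."""
    s = K_chunk @ q / np.sqrt(q.shape[0])       # (chunk_len,)
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V_chunk              # (max, sum, unnormalized output)

def split_kv_decode(q, K, V, n_chunks=4):
    """Split the KV sequence into chunks, then combine the partial results."""
    parts = [partial_attention(q, Kc, Vc)
             for Kc, Vc in zip(np.array_split(K, n_chunks),
                               np.array_split(V, n_chunks))]
    m_global = max(m for m, _, _ in parts)
    l_total = sum(l * np.exp(m - m_global) for m, l, _ in parts)
    o_total = sum(o * np.exp(m - m_global) for m, _, o in parts)
    return o_total / l_total

rng = np.random.default_rng(7)
d, N = 64, 10_000
q = rng.standard_normal(d)                      # a single decoding query
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))

s = K @ q / np.sqrt(d)
w = np.exp(s - s.max())
reference = (w / w.sum()) @ V
print(np.max(np.abs(split_kv_decode(q, K, V) - reference)))
```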

These are all GPU-level implementation details, but the core algorithm — tiled computation with online softmax — is the same one we just built in NumPy.


IO-Awareness Changes Everything

Flash Attention teaches a lesson that applies far beyond attention mechanisms: algorithmic complexity (FLOPs) isn't the whole story. Two algorithms can have identical FLOP counts, but if one respects the memory hierarchy and the other doesn't, the performance difference is enormous.

Standard attention treats memory as flat. Flash Attention treats it as what it is — a hierarchy with 10× bandwidth differences between levels. The same math, computed with the same precision, just reorganized to keep the fast memory busy and the slow memory quiet.

Today, Flash Attention is everywhere. PyTorch's scaled_dot_product_attention can dispatch to a Flash Attention kernel automatically when the inputs and hardware support it. HuggingFace Transformers enables it with a single flag. vLLM and llama.cpp use Flash Attention variants for serving. Most modern LLMs you interact with, whether through a chatbot, a coding assistant, or an API, are likely running Flash Attention or a close relative under the hood.

Our elementary series pipeline now looks like this: tokenize → embed → position (RoPE) → attend (Flash Attention) → KV cache → softmax → loss → optimize → decode

The attention matrix is the single largest intermediate tensor in a transformer. Flash Attention makes it vanish. That's not a minor optimization — it's the reason 100K+ token context windows are possible without quadratic memory. The next time a model reads an entire codebase or a full novel in one pass, remember: it's tiling all the way down.

References & Further Reading