kvcache.cobanov.dev

Built by Mert Cobanov

KV Cache & Flash Attention.

LLM inference optimization is one of the parts of this field I keep coming back to. New techniques land at a steady clip (a smarter cache, a faster kernel, a better scheduler), and each is a careful answer to a specific bottleneck. I built this page to work through the most important of them, as much for my own understanding as for whoever else reads it, visualized as clearly as I could manage. Eleven sections, top to bottom, each pairing a short explanation with Python and an animated diagram.

These started as my own notes while working through Rohit Ghumare’s excellent ai-engineering-from-scratch repo. I kept rewriting them with my own visualizations until they made sense, and this page is the result.

The baseline

01

Naive autoregressive decoding

A decoder-only transformer produces text one token at a time. During training, the entire sequence is fed in parallel and attention is computed once over the whole input. At inference time this is no longer possible: token t+1 cannot be sampled until token t has been emitted.

The simplest implementation handles this serial constraint by running the model from scratch at every step. Each new token triggers a fresh forward pass over the entire sequence so far: every position’s Q, K, and V vectors are recomputed, the full attention matrix is materialized, and only the last row of the output is actually used. The work done for positions 0 through t−1 is identical to the work done in the previous step. It is thrown away and recomputed.

The cost scales as the triangular sum: emitting N tokens costs N(N+1)/2 attention rows, i.e. O(N²) in the sequence length. A 100-token completion takes about 5,050 attention operations; a 4,096-token completion takes more than eight million. Most of that work is redundant.
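The triangular sum is easy to check directly. The sketch below counts one attention row per position per step, the same unit used in the paragraph above; `naive_rows` is a name introduced here for illustration only.

```python
def naive_rows(n_tokens: int) -> int:
    """Attention rows computed when every step re-runs the whole sequence."""
    # Step t (1-based) recomputes t rows, so the total is 1 + 2 + ... + n.
    return sum(range(1, n_tokens + 1))


for n in (8, 100, 4096):
    # The closed form N(N+1)/2 matches the explicit sum.
    assert naive_rows(n) == n * (n + 1) // 2
    print(f"N = {n:5d}: {naive_rows(n):>9,d} attention rows")
```

For N = 100 this gives the 5,050 quoted above, and for N = 4,096 it gives 8,390,656.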

naive_decode.py (python)
def naive_decode(prompt_tokens, n_new):
    # project_q / project_k / project_v, softmax and sample are stand-ins
    # for the model's real operations.
    tokens = list(prompt_tokens)
    ops = 0
    for step in range(n_new):
        # recompute Q, K, V for the entire sequence every step
        Q = project_q(tokens)            # (N, d)
        K = project_k(tokens)            # (N, d)
        V = project_v(tokens)            # (N, d)

        attn = softmax(Q @ K.T) @ V      # full N x N score matrix
        next_tok = sample(attn[-1])      # only the last row is used

        ops += len(tokens)               # N attention rows recomputed this step
        tokens.append(next_tok)
    return tokens, ops
[Interactive diagram: naive decoder, per-step work. Example state: 8 tokens generated, 36 attention ops.]

Each row is one decode step. Filled cells along the row are the attention operations performed for that step. The highlighted row is the current step; the work in dimmer rows above it was already done and then thrown away. Total filled cells = N(N+1)/2.

Drag the slider to grow N. Watch the triangular area: doubling N roughly quadruples the filled cells.

Two observations follow from this. First, the K and V vectors for any prefix token are a deterministic function of that token and the model weights; they do not change between steps. Recomputing them is wasted compute. Second, of the full N×N attention output, only one row (attn[-1]) is consumed at each step. The other rows are computed and discarded.

These two observations motivate the KV cache, which is the subject of the next section.

First optimization

02

The KV cache

The K and V projections for any prefix token are pure functions of that token and the model weights. Their value at step t+1 is identical to their value at step t. If we keep the K and V vectors around after computing them once, the next decode step can simply read them and skip the projection entirely.

The fix is straightforward: keep a per-layer, per-head buffer and append to it every time we produce a new token. Each decode step projects a single new token’s q, k, and v, appends k and v to the cache, and runs attention as a 1×N row against the cached keys. The per-step cost drops from N per-position projections to one. Across a generation of N tokens, the number of projections and attention rows falls from O(N²) to O(N); each row still reads up to N cached keys.

For N = 4,096, for example, that means the naive decoder does about 8.4 million attention rows while the cached decoder does 4,096. The same gap applies to the projections.
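A small check that caching changes the amount of work but not the answer. This is a toy sketch in plain Python with one head and random weights; `matvec`, `attend`, and the 4-dimensional embeddings are all made up for the illustration and are not part of any real model.

```python
import math
import random

random.seed(0)
D = 4  # toy head dimension


def rand_matrix():
    return [[random.gauss(0, 1) for _ in range(D)] for _ in range(D)]


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, K, V):
    """One attention row: softmax(q . K^T / sqrt(D)) @ V."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in K]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, V)) / total for c in range(D)]


Wq, Wk, Wv = rand_matrix(), rand_matrix(), rand_matrix()
xs = [[random.gauss(0, 1) for _ in range(D)] for _ in range(6)]  # toy embeddings

K_cache, V_cache = [], []
naive_proj = cached_proj = 0
max_diff = 0.0

for t, x in enumerate(xs):
    # Cached path: project only the newest token, then read the rest.
    K_cache.append(matvec(Wk, x))
    V_cache.append(matvec(Wv, x))
    cached_proj += 1
    cached_row = attend(matvec(Wq, x), K_cache, V_cache)

    # Naive path: re-project every prefix token before attending.
    prefix = xs[: t + 1]
    K_full = [matvec(Wk, p) for p in prefix]
    V_full = [matvec(Wv, p) for p in prefix]
    naive_proj += len(prefix)
    naive_row = attend(matvec(Wq, x), K_full, V_full)

    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(cached_row, naive_row)))

print(f"K/V projections: naive {naive_proj}, cached {cached_proj}")
print(f"largest difference between the two outputs: {max_diff:.1e}")
```

Over six tokens the naive path performs 21 K/V projections and the cached path performs 6, while the attention rows agree.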

kv_cache.py (python)
class KVCache:
    def __init__(self, n_layers, n_heads, d_head):
        self.d_head = d_head  # width of each cached k / v vector
        self.K = [[[] for _ in range(n_heads)] for _ in range(n_layers)]
        self.V = [[[] for _ in range(n_heads)] for _ in range(n_layers)]

    def append(self, layer, head, k, v):
        self.K[layer][head].append(k)
        self.V[layer][head].append(v)

    def read(self, layer, head):
        return self.K[layer][head], self.V[layer][head]


def cached_decode(prompt, n_new, cache):
    # Single layer, single head for brevity; project_kv, project_q, stack,
    # softmax and sample are stand-ins for the model's real operations.

    # prefill: encode the prompt once and fill the cache
    # (looped here; real engines do this in one parallel pass)
    for tok in prompt:
        k, v = project_kv(tok)
        cache.append(layer=0, head=0, k=k, v=v)

    # decode: only the newest token is projected each step
    out = []
    for _ in range(n_new):
        q = project_q(out[-1] if out else prompt[-1])
        K, V = cache.read(0, 0)

        attn = softmax(q @ stack(K).T) @ stack(V)   # one output row, (1, d_head)
        next_tok = sample(attn)

        k, v = project_kv(next_tok)
        cache.append(0, 0, k, v)
        out.append(next_tok)
    return out
[Interactive diagram: cached decoder, per-step work. Example state: 8 tokens, 8 cached ops, 22% of the naive count.]

Orange cells are computed this step. Green cells are read from the cache without recomputation. Dimmer rows above are previous steps; their orange cells became green entries in the cache for later steps.

Orange = newly computed this step. Green = served from the cache.

Two practical consequences. First, generation now has two distinct phases. The prefill phase processes the entire prompt in one parallel forward pass to populate the cache. The decode phase runs one token at a time against the cache. These have very different arithmetic intensity profiles (prefill is compute-bound, decode is limited by memory bandwidth), and production engines schedule them differently.

Second, the savings in compute are paid for in memory. Every cached K and V vector lives in GPU HBM until the sequence finishes. For short sequences this is irrelevant. For 32K- or 128K-token contexts on large models it becomes the dominant constraint, which is what we look at next.

The tradeoff

03

The cost of keeping K and V around

The KV cache turns compute savings into a memory tax. Per layer, per token, the cache stores one K vector and one V vector, each of width d_head times the number of KV heads. With fp16, that works out to:

bytes per token = 2 (K and V) × n_layers × n_kv_heads × d_head × 2 bytes

For Llama 3 8B (32 layers, 8 KV heads under GQA, d_head 128, fp16) the cache costs 128 KB per token, which is about 4 GB for a 32K context. Llama 3 70B has the same per-head width but 80 layers, so each token costs roughly 320 KB and a 32K context needs about 10 GB. At 128K context, the cache alone for 70B is around 40 GB, half of an 80 GB A100 before any model weights are loaded.

Two architectural decisions make this manageable. Grouped Query Attention (GQA) decouples the number of K/V heads from the number of Q heads: Llama 3 70B has 64 query heads sharing only 8 KV heads, cutting the cache by 8x relative to multi-head attention. Multi-head Latent Attention (MLA), used in DeepSeek V2/V3, goes further by projecting K and V into a smaller shared latent space and decompressing on demand.

kv_cache_size.py (python)
def gqa_kv_bytes(n_tokens, n_layers, n_kv_heads, d_head, dtype_b=2, batch=1):
    """Cache size for GQA or MHA: store one K and one V per kv head per token."""
    return 2 * batch * n_tokens * n_layers * n_kv_heads * d_head * dtype_b


def mla_kv_bytes(n_tokens, n_layers, kv_lora_rank, qk_rope_dim, dtype_b=2, batch=1):
    """Cache size for MLA: store the compressed latent + RoPE channel only."""
    return batch * n_tokens * n_layers * (kv_lora_rank + qk_rope_dim) * dtype_b


gqa = {
    "Llama 3.1 8B":        dict(n_layers=32, n_kv_heads=8,  d_head=128),
    "Llama 3.1 70B":       dict(n_layers=80, n_kv_heads=8,  d_head=128),
    "Llama 3.1 405B":      dict(n_layers=126,n_kv_heads=8,  d_head=128),
    "Qwen 2.5 72B":        dict(n_layers=80, n_kv_heads=8,  d_head=128),
    "Mistral Large 2 123B":dict(n_layers=88, n_kv_heads=8,  d_head=128),
}
mla = {
    "DeepSeek V3 671B":    dict(n_layers=61, kv_lora_rank=512, qk_rope_dim=64),
    "Kimi K2 1T":          dict(n_layers=61, kv_lora_rank=512, qk_rope_dim=64),
}

# "32K" below is 32,000 tokens, fp16, batch 1.
for name, cfg in gqa.items():
    gb = gqa_kv_bytes(32_000, **cfg) / 1024**3
    print(f"{name:24s} {gb:6.2f} GB @ 32K context")

for name, cfg in mla.items():
    gb = mla_kv_bytes(32_000, **cfg) / 1024**3
    print(f"{name:24s} {gb:6.2f} GB @ 32K context")

# Llama 3.1 8B               3.91 GB
# Llama 3.1 70B              9.77 GB
# Llama 3.1 405B            15.38 GB
# Qwen 2.5 72B               9.77 GB
# Mistral Large 2 123B      10.74 GB
# DeepSeek V3 671B           2.09 GB   <- MLA wins, despite being bigger
# Kimi K2 1T                 2.09 GB
[Interactive calculator: KV cache vs HBM ceiling. Example state, Llama 3 70B at 32K context: KV cache 9.77 GB against an H100's 80 GB of HBM (12.2%); 4096 bytes per token per layer; 320.0 KB per token.]

Scope: this calculator shows the KV cache only, per replica, for one batch. Model weights are separate; for large models most of the GPU is already taken by weights before any cache is allocated. Each preset includes the typical deployment shape. For this preset the weights are about 140 GB at bf16, and a typical deployment is 2 to 4 H100s with tensor parallelism.

Try Llama 3 70B at 128K context, then switch to DeepSeek V3. MLA is the reason the 671B model has a smaller cache than the 70B one.

The KV cache is also why throughput-oriented servers quantize the cache (fp8, int4) before they quantize anything else: a single byte cut per element is multiplied by every layer, every head, every token, every concurrent sequence. At 70B and 32K context, switching from fp16 to fp8 halves the cache from roughly 10 GB to 5 GB per sequence, so the same HBM budget holds twice as many concurrent sequences.
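To make the "one byte per element" point concrete, here is a rough sketch of how many 32K-token sequences fit in a fixed cache budget at fp16 and at fp8. The 12 GiB budget is a made-up number chosen for the illustration, the helper names are hypothetical, and the per-block scale factors that real fp8 caches store are ignored.

```python
def kv_bytes_per_sequence(n_tokens, n_layers, n_kv_heads, d_head, dtype_bytes):
    """K and V for every layer, KV head, and token of one sequence."""
    return 2 * n_tokens * n_layers * n_kv_heads * d_head * dtype_bytes


def sequences_that_fit(budget_bytes, **shape):
    return budget_bytes // kv_bytes_per_sequence(**shape)


# Llama 3 70B shape from the text; the cache budget is an assumption.
shape = dict(n_tokens=32_000, n_layers=80, n_kv_heads=8, d_head=128)
budget = 12 * 1024**3

for label, dtype_bytes in (("fp16", 2), ("fp8", 1)):
    per_seq = kv_bytes_per_sequence(dtype_bytes=dtype_bytes, **shape)
    fit = sequences_that_fit(budget, dtype_bytes=dtype_bytes, **shape)
    print(f"{label}: {per_seq / 1024**3:5.2f} GiB per sequence, {fit} fit in the budget")
```

Under these assumptions one sequence fits at fp16 and two fit at fp8.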

Why kernels matter

04

The memory bandwidth bottleneck

The KV cache solved a compute problem. There is still a data-movement problem underneath it. At every decode step a GPU runs roughly softmax(q Kᵀ / √d) V: a query against a length-N key tensor, a softmax, and a length-N value tensor. The arithmetic is modest. The data movement is not.

Modern GPUs have two memory tiers. HBM is the large off-chip pool: 80 GB on an H100, 192 GB on a B200, around 3 TB/s of bandwidth. SRAM is the on-chip scratchpad inside each streaming multiprocessor: roughly 256 KB per SM on H100, but around 30 TB/s, ten times faster than HBM. Every time a computation goes out to HBM and back, the kernel pays that 10x factor.

The naive attention kernel computes S = QKᵀ, writes the N×N score matrix to HBM, reads it back to compute softmax, writes P back, then reads it once more to multiply by V. For N = 4096 in fp16, that is a 32 MB matrix crossing the HBM boundary four times, per layer, per head. As N grows, attention becomes memory-bound; the matrix multiply units run at a small fraction of their peak FLOPs because the kernel is waiting on bytes.

standard_attention.py (python)
def standard_attention(Q, K, V):
    # Q, K, V live in HBM.
    S = Q @ K.T            # materialize N x N in HBM
    P = softmax(S)         # read S from HBM, write P to HBM
    O = P @ V              # read P from HBM, read V, write O
    return O

# For N = 4096, fp16:
#   S, P each occupy   2 * 4096 * 4096  = 32 MB
#   S and P are each written once and read once: ~128 MB shuffled
#   per attention call, per layer, per head. The compute is cheap;
#   the traffic is not.
[Interactive diagram: standard attention through the memory hierarchy, with a slider for sequence length N. HBM: 80 GB at 3 TB/s. SRAM: 256 KB at 30 TB/s. The animation has 11 stages; stage 1 loads Q (N×d, a few hundred KB to a few MB) from HBM into SRAM, 1.0 MB at N = 4,096, and a counter tracks cumulative HBM traffic.]

The full kernel pushes 132.0 MB through HBM at N = 4,096. Four of those legs carry an N×N intermediate (write S, read S, write P, read P), each one 32.0 MB. The actual answer (the O tensor) is only 1.0 MB.

Legend: load Q, K, V · read/write S, P · write O
Press run. Watch the eleven stages: each red bar is an N×N tensor being shuffled between HBM and SRAM. Try changing N.
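The traffic total can be reproduced with a few lines of arithmetic. This is a back-of-the-envelope model for one head of unfused attention; d = 128 is an assumption inferred from the 1.0 MB size of Q at N = 4,096, and `standard_attention_hbm_bytes` is a name made up for this sketch.

```python
def standard_attention_hbm_bytes(n, d, dtype_bytes=2):
    """Bytes crossing the HBM boundary for one head of unfused attention."""
    vec = n * d * dtype_bytes   # one of Q, K, V, O
    mat = n * n * dtype_bytes   # one of S, P
    return {
        "load Q, K, V": 3 * vec,
        "write S, read S": 2 * mat,
        "write P, read P": 2 * mat,
        "write O": vec,
    }


MIB = 1024**2
for n in (1024, 4096, 16384):
    legs = standard_attention_hbm_bytes(n, d=128)
    total = sum(legs.values())
    nxn = legs["write S, read S"] + legs["write P, read P"]
    print(f"N = {n:6d}: {total / MIB:8.1f} MiB total, "
          f"{nxn / total:.0%} of it N x N intermediates")
```

At N = 4,096 the total is 132 MiB, and the share taken by the N×N intermediates keeps growing with N because they scale quadratically while Q, K, V, and O scale linearly.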

The fix is not to add more FLOPs but to rearrange where they happen. If the full matrix never has to live in HBM, if we can do softmax and the value multiply while the relevant slice still lives in SRAM, the bandwidth bill collapses. That is the observation Flash Attention turns into an algorithm, which is the subject of the next section.

Rearranging the work

05

Flash Attention: tiling the N×N matrix out of HBM

The full attention output never depends on having the entire N×N score matrix in one place. Each output row only needs the corresponding row of Q and all of K, V. We can split Q into row blocks and K, V into column blocks, then loop: load one Q block and one K-V block into SRAM, compute the partial scores, partially accumulate the output, repeat. The intermediate scores never leave the chip.

The catch is softmax. It is a global operation: the normalization constant depends on the max and sum across all N scores. If we process K in pieces, the max and sum we see in any one tile are partial. The trick (usually attributed to Milakov and Gimelshein, and put to work in this kernel by Dao et al.) is to keep a running pair: the largest score seen so far and the rescaled sum so far. When a new tile arrives we adjust both, rescale the partial output accordingly, and add the new tile’s contribution. The final division gives the same result as naive softmax, up to floating-point rounding from the reordered additions.

The end result is that HBM traffic for attention drops from Θ(Nd + N²) to Θ(N²d²/M), where d is the head dimension and M is the SRAM size in elements (the bound given in the Flash Attention paper). The N×N term shrinks by a factor of about M/d², so the size of the cut depends on the head dimension and on how much SRAM the kernel can use. The actual wall-clock speedup depends on the kernel and the GPU but is typically 2x to 4x on A100 and 5x to 10x on H100 with FP8.

flash_attention.py (python)
import numpy as np

def flash_attention(Q, K, V, tile=64):
    # 1/sqrt(d) score scaling is omitted, as in standard_attention above.
    N, d = Q.shape
    O = np.zeros_like(Q)

    for i in range(0, N, tile):                   # outer loop over Q tiles
        Q_i = Q[i:i + tile]                       # load to SRAM
        rows = Q_i.shape[0]                       # last tile may be short

        m_i = np.full(rows, -np.inf)              # running max
        l_i = np.zeros(rows)                      # running sum
        O_i = np.zeros((rows, d))                 # running output

        for j in range(0, N, tile):               # inner loop over K, V tiles
            K_j = K[j:j + tile]
            V_j = V[j:j + tile]

            S_ij  = Q_i @ K_j.T                   # SRAM only, never to HBM
            m_new = np.maximum(m_i, S_ij.max(-1))
            P_ij  = np.exp(S_ij - m_new[:, None])
            alpha = np.exp(m_i - m_new)           # rescale factor, 0 on first tile
            l_i   = alpha * l_i + P_ij.sum(-1)
            O_i   = alpha[:, None] * O_i + P_ij @ V_j
            m_i   = m_new

        O[i:i + tile] = O_i / l_i[:, None]        # one HBM write per Q tile

    return O
[Interactive diagram: tiled traversal of the attention matrix, K columns across and Q rows down. The SRAM working set holds one Q tile, one K tile, one V tile, the running (m, ℓ) for 4 rows, and the partial O for 4 rows. Counters show the outer index i, the inner index j, the per-row accumulators (starting at m = −∞, ℓ = 0.00), and HBM traffic so far for flash vs naive. Toy values: N=16, d=8, fp16. Legend: active tile · accumulated into current Q · finalized Q tile.]

Outer loop walks the Q row blocks. Inside, an inner loop walks every K and V column block. Each (Qi, Kj, Vj) triple is loaded into SRAM, used to refresh the running (m, ℓ) and partial O for the rows in Qi, and discarded. When the inner loop ends, the finalized O tile is written out and we move to the next Q. The score matrix never exists as a single object; it is consumed tile by tile in place.

Each block is loaded into SRAM, used, and discarded. The full N×N matrix is never materialized. Step through manually or press run to watch the inner loop scan K for each Q row block.

Three things matter for the rest of this page. First, Flash Attention is exact, not approximate. Output and gradients match the unfused kernel. Second, the algorithmic core is the online softmax; the kernel-level work is choosing tile sizes and scheduling to fit the GPU’s SRAM and warp structure. Third, every major version bump (v1 to v2 to v3 to v4) keeps the algorithm and rewrites the scheduling to match a new GPU architecture. We look at the online softmax in isolation next, then at the version history.
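The exactness claim can be checked on a toy case without any GPU. The sketch below is plain Python, one query row at a time, with the same unscaled scores as the listings above; the helper names and the random inputs are made up for the illustration. N is deliberately not a multiple of the tile size.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q, K, V):
    """Row-by-row softmax(Q K^T) V with the full score row in hand."""
    out = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) / total
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, tile=4):
    """Same result, but K and V are only ever seen one tile at a time."""
    d = len(V[0])
    out = []
    for q in Q:  # real kernels handle a block of query rows at once
        m, l, acc = -math.inf, 0.0, [0.0] * d
        for j in range(0, len(K), tile):
            K_j, V_j = K[j:j + tile], V[j:j + tile]
            scores = [dot(q, k) for k in K_j]
            m_new = max(m, max(scores))
            alpha = math.exp(m - m_new)  # 0.0 on the first tile
            p = [math.exp(s - m_new) for s in scores]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pi * v[c] for pi, v in zip(p, V_j))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


random.seed(0)
N, D = 10, 3
Q, K, V = ([[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
           for _ in range(3))

ref = naive_attention(Q, K, V)
got = tiled_attention(Q, K, V, tile=4)
max_diff = max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g))
print(f"largest elementwise difference: {max_diff:.1e}")
```

The two outputs differ only by floating-point rounding, which is what "exact, not approximate" means in practice.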

The numerical heart of Flash Attention

06

Online softmax

Softmax is defined globally: each element of the output depends on the sum of exponentials over the full input. In a streaming setting we only see one tile of scores at a time, so the obvious approach (compute m = max x_i, then sum exp(x_i − m), then divide) would require two full passes over the scores. The online variant collapses this into a single pass by maintaining two running quantities: m, the largest score seen so far, and ℓ, the sum of exp(x_i − m) over everything seen so far.

When a new tile arrives, two adjustments happen. The running max may have to grow, in which case the running sum needs to be rescaled by exp(m_old − m_new) so it expresses everything relative to the new offset. The new tile’s contribution is then added in. After all tiles have been seen, dividing each exp(x_i − m) by the final ℓ gives the exact softmax.

Flash Attention fuses this with the matrix multiply: the running rescale factor is applied not just to the running sum but also to the running output tile, which is itself an accumulator over P·V contributions. So the algorithm never needs to revisit prior tiles. One linear sweep, one final division, and a result that matches the unfused computation up to floating-point rounding.

online_softmax.py (python)
import math

def online_softmax(scores, tile=4):
    """Numerically stable softmax whose statistics (m, l) are learned in a
    single streaming pass over tiles; a second pass emits the probabilities.
    Output matches the naive version up to floating-point rounding.
    """
    m = -math.inf      # running max
    l = 0.0            # running sum of exponentials

    # one pass to learn m and l
    for s in range(0, len(scores), tile):
        block = scores[s:s + tile]
        m_new = max(m, *block)
        scale = math.exp(m - m_new) if m != -math.inf else 0.0
        l     = l * scale + sum(math.exp(x - m_new) for x in block)
        m     = m_new

    # second pass to emit normalized probabilities
    return [math.exp(x - m) / l for x in scores]


# In Flash Attention the second pass is fused into the same tile loop:
# each tile's output contribution is rescaled by exp(m_old - m_new) when
# the running max advances, so no scores need to be revisited.
[Interactive diagram: tile-by-tile softmax over 12 scores, tile size 4, showing the partial softmax exp(x_i − m) / ℓ for each position. Legend: current · absorbed · naive final.]

tile 0 (positions 0 to 3): 1.2, -0.4, 0.8, 2.1
tile 1 (positions 4 to 7): 0.3, -1.5, 1.8, 0.7
tile 2 (positions 8 to 11): -0.2, 2.4, 1.1, 0.5

How (m, ℓ) updates on tile 0:
block max = max(1.2, -0.4, 0.8, 2.1) = 2.100
m_new = max(m_old, block_max) = max(−∞, 2.100) = 2.100
scale: not needed on the first tile (ℓ_old = 0)
ℓ_new = Σ exp(x − m_new) over the tile = 1.761

First tile sets the baseline. There is no previous sum to rescale, so ℓ is just the sum of this tile's exponentials.

Step through each tile to watch m and ℓ evolve. The key move is the rescale when a new max arrives. The dashed line on the bar chart is the final naive softmax for reference.
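The same twelve scores can be run through a fully fused single pass, where the rescale is applied to an output accumulator as well as to ℓ. This is a sketch rather than kernel code: the `values` being aggregated are made up (a scalar per position instead of a V row), and the function names are introduced only for this example.

```python
import math

scores = [1.2, -0.4, 0.8, 2.1, 0.3, -1.5, 1.8, 0.7, -0.2, 2.4, 1.1, 0.5]
values = [float(i) for i in range(len(scores))]  # made-up values to aggregate


def online_weighted_sum(scores, values, tile=4):
    """sum_i softmax(scores)_i * values_i in one pass over tiles."""
    m, l, acc = -math.inf, 0.0, 0.0
    history = []
    for s in range(0, len(scores), tile):
        block_s = scores[s:s + tile]
        block_v = values[s:s + tile]
        m_new = max(m, max(block_s))
        scale = math.exp(m - m_new)          # exp(-inf) is 0.0 for the first tile
        p = [math.exp(x - m_new) for x in block_s]
        l = scale * l + sum(p)               # rescale the old sum, add the new tile
        acc = scale * acc + sum(pi * vi for pi, vi in zip(p, block_v))
        m = m_new
        history.append((m, l))
    return acc / l, history


def naive_weighted_sum(scores, values):
    top = max(scores)
    e = [math.exp(x - top) for x in scores]
    return sum(ei * vi for ei, vi in zip(e, values)) / sum(e)


result, history = online_weighted_sum(scores, values)
for i, (m, l) in enumerate(history):
    print(f"after tile {i}: m = {m:.3f}, l = {l:.3f}")
print("matches naive:", abs(result - naive_weighted_sum(scores, values)) < 1e-12)
```

After tile 0 the pair is m = 2.100, ℓ = 1.761, as in the diagram. The max only changes on tile 2, when the score 2.4 arrives, and that is the one step where the accumulated sum and output are actually scaled down.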

Two things to internalize. First, this is not a compression or approximation. Every probability is the same, to floating-point rounding, as the unfused version. Second, the trick is general: any operation of the form “normalize across N then aggregate” can be made streaming the same way. Layer norm, RMS norm, and the backward pass through softmax all admit similar treatments, and Flash Attention’s newer kernels exploit this on the gradient path too.

The kernel keeps moving

07

Four versions, four GPU generations

Flash Attention has been rewritten three times since 2022. The algorithm itself (tiling plus online softmax) is unchanged. What changes is the schedule: how the tiles map to warps, how memory copies overlap with math, which numeric formats are used, which special instructions on which generation of tensor core the kernel relies on.

v1 (2022) was the first end-to-end fused implementation on A100. v2 (2023) reworked the parallelization to keep the warps busy on causal workloads, where the upper-right of the matrix is masked out. v3 (2024) was designed around H100 features: warp specialization, asynchronous TMA copies, FP8 with per-block scaling. v4 (2026) targets Blackwell’s deeper tensor memory hierarchy with a five-stage compute/copy pipeline and replaces the hardware exp by a software exp2 sequence that maps better onto the new tensor cores.

From a deployment perspective, the choice is determined by the GPU and whether the workload is forward-only or also needs gradients. v4 launched forward-only, so training on Blackwell falls back to an earlier version.

pick_flash.py (python)
def pick_flash(gpu: str, training: bool) -> str:
    """Pick the right FlashAttention kernel for the deployment."""
    if gpu in ("B200", "GB200") and not training:
        return "FlashAttention-4"
    if gpu in ("H100", "H200"):
        return "FlashAttention-3"
    if gpu in ("A100", "A6000", "L40S"):
        return "FlashAttention-2"
    return "FlashAttention-2 (safe default)"


# Practical defaults (mid-2026):
#   - Training on Hopper:     FA3 (forward + backward).
#   - Inference on Blackwell: FA4 (forward only at launch).
#   - Anything older:         FA2.
[Interactive timeline: one node per version. Selected node: v3, "Hopper async, FP8" (H100, 2024). Warp specialization with async memory copies. FP8 path with per-block scaling. Designed around Hopper's WGMMA. Approximate fp16 forward utilization: ~740 TFLOPs/s fp16, ~75% of H100 SOL; 1.2 PFLOPs/s FP8.]
Click a node. Each version is matched to its GPU generation; performance figures are approximate fp16 utilization.

The reason this matters in practice: the same model, with the same weights, running the same logical attention, can vary by a factor of three to five in wall-clock latency depending on which kernel is wired in. The difference between a workload feeling responsive and feeling sluggish is almost never the model. It is the kernel below the model.

Allocation, not arithmetic

08

PagedAttention: KV cache as virtual memory

The default way to allocate KV cache is to give each sequence a contiguous slab of HBM. That works for one sequence. For a server handling many sequences at once it breaks down. Sequence lengths vary wildly. A one-word reply and a thousand-token explanation can arrive in the same batch. None of the obvious allocation strategies does well across that spread.

Reserve the maximum context length for everyone and most of the cache sits idle. Allocate the current length and grow it as the sequence extends, and each grow event needs a reallocation and copy under HBM bandwidth pressure. Guess the expected length up front and you waste memory when you guess high or stall when you guess low. Each choice is bad in a different way.

The harder problem arrives when sequences finish. A sequence that ends early leaves a hole in HBM. The next request, of a different size, generally cannot use that hole because contiguous memory has to be, well, contiguous. Production traces from vLLM in 2023 showed 60-80% of the KV pool sitting idle inside such fragments.

The PagedAttention proposal, introduced by the vLLM team, ports an idea from operating-system virtual memory to KV cache. Carve HBM into fixed-size physical blocks (16 tokens by default). Give each sequence a small page table mapping its logical token positions to the physical blocks that hold them. New tokens grow the page table by appending another block. A finished sequence returns all of its blocks to a free list. The attention kernel takes the page table as an extra input and gathers K, V from the physical blocks at lookup time.

Three consequences follow. External fragmentation drops to zero because every block is either owned by a sequence or free; the only waste left is the unfilled tail of each sequence’s last block. The same HBM serves substantially more concurrent sequences, 2x to 4x throughput at fixed memory in the original paper. And the page-table indirection makes prefix sharing trivial: two sequences that share a prefix can point to the same physical blocks, no copy, no recomputation. For agent workloads with repeated tool prompts, the prefix sharing alone is worth 5x to 10x in throughput.

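paged_kv_cache.py (python)
class PagedKVCache:
    """KV cache as virtual memory: fixed blocks + per-sequence page table."""
    BLOCK = 16  # tokens per physical block (vLLM default)

    def __init__(self, total_blocks: int):
        self.free: list[int] = list(range(total_blocks))
        self.page_table: dict[int, list[int]] = {}
        self.refcount: dict[int, int] = {}   # block -> sequences using it

    def _take_block(self) -> int:
        if not self.free:
            raise MemoryError("KV pool exhausted; caller must queue or preempt")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def allocate(self, seq_id: int, n_tokens: int) -> None:
        n_blocks = -(-n_tokens // self.BLOCK)        # ceil-div
        self.page_table[seq_id] = [self._take_block() for _ in range(n_blocks)]

    def append_token(self, seq_id: int, current_len: int) -> None:
        # current_len = tokens already stored. Only grow when the last
        # block is full, i.e. when crossing a block boundary.
        if current_len % self.BLOCK == 0:
            self.page_table[seq_id].append(self._take_block())

    def release(self, seq_id: int) -> None:
        # A shared block returns to the free list only when its last user leaves.
        for block in self.page_table.pop(seq_id):
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                del self.refcount[block]
                self.free.append(block)

    def physical(self, seq_id: int, logical_pos: int) -> tuple[int, int]:
        block_idx = self.page_table[seq_id][logical_pos // self.BLOCK]
        return block_idx, logical_pos % self.BLOCK

    def share_prefix(self, src_id: int, dst_id: int, prefix_len: int) -> None:
        """Two sequences point to the same physical blocks for a shared prefix.

        Only blocks fully covered by the prefix are shared; a partially
        filled block would need copy-on-write, which this sketch omits.
        """
        n = prefix_len // self.BLOCK
        shared = self.page_table[src_id][:n]
        for block in shared:
            self.refcount[block] += 1
        self.page_table[dst_id] = list(shared)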
[Interactive diagram: contiguous vs paged, walked through step by step (6 steps).]

Step 1, empty pool: both allocators start empty. The pool is 64 token slots, drawn here as 16 physical blocks of 4 tokens each. vLLM uses 16-token blocks in production; the smaller value here is only to keep the diagram readable. The contiguous allocator uses a bump pointer and leaves holes; the paged allocator uses 4-token blocks, a page table, and a free list that starts as blocks 0 through 15. Both panels track live and wasted slots out of 64.
Press auto or step through. Watch how the contiguous side accumulates a hole it cannot reuse, while the paged side stays usable through the same sequence of events.
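The hole problem can be reproduced in a few lines. This is a toy sketch with a 64-slot pool and 4-token blocks, matching the diagram's sizes; the sequence names, lengths, and the `first_fit` helper are made up for the illustration, and none of this is vLLM's actual allocator.

```python
def first_fit(allocs, size, pool):
    """Return a start offset for a contiguous slab, or None if no hole is big enough."""
    cursor = 0
    for start, length in sorted(allocs.values()):
        if start - cursor >= size:
            return cursor
        cursor = start + length
    return cursor if pool - cursor >= size else None


POOL, BLOCK = 64, 4

# Contiguous: A, B, C packed back to back, then B finishes.
contiguous = {"A": (0, 20), "B": (20, 16), "C": (36, 20)}
del contiguous["B"]
free_slots = POOL - sum(length for _, length in contiguous.values())
print("free slots:", free_slots)
print("contiguous start for D (20 tokens):", first_fit(contiguous, 20, POOL))

# Paged: same events, but memory is handed out in 4-token blocks.
free_blocks = list(range(POOL // BLOCK))
pages = {}
for name, n_tokens in (("A", 20), ("B", 16), ("C", 20)):
    pages[name] = [free_blocks.pop(0) for _ in range(-(-n_tokens // BLOCK))]
free_blocks.extend(pages.pop("B"))
pages["D"] = [free_blocks.pop(0) for _ in range(-(-20 // BLOCK))]
print("paged blocks for D:", pages["D"])
```

Both allocators have 24 free slots when D arrives. The contiguous one cannot place a 20-token sequence because the free space is split into a 16-slot hole and an 8-slot tail; the paged one serves D from five scattered blocks.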

What makes PagedAttention interesting is that it is not a new attention. The math is identical. What changed is how the kernel addresses memory. That is also why every modern inference engine ships some variant of it: vLLM’s original 16-token blocks, SGLang’s RadixAttention with a trie of shared prefixes, TensorRT-LLM’s paged KV. The layouts differ; the underlying principle is the same one operating systems have used for sixty years.

Scheduler over kernel

09

Continuous batching

Single-sequence speedups stop mattering past a point. Production servers run many sequences in parallel, and the dominant inefficiency is no longer kernel time but slot idleness. Static batching builds a batch of N requests, runs it to completion, and only then starts the next. Because sequences have wildly different lengths (a one-word answer and a thousand-token explanation can arrive in the same batch), everyone waits for the slowest. Slot utilization on chat-style workloads under static batching is often below 30%.

Continuous batching, first shipped in Orca and now standard in vLLM, TensorRT-LLM, and SGLang, treats the batch as a fluid pool. At every decode step, any sequence that finished is removed and a queued sequence is admitted into its slot. The kernel still runs one token per active sequence per step, but the slots stay full. The result is five to ten times the throughput on the same hardware for typical chat traffic, with no change to the model.

The trick that makes this practical is that the kernel does not need sequences to be aligned in length. The KV cache for each sequence is addressed through its own page table; the attention kernel reads them with a gather. A new sequence’s prefill can be folded into the same iteration as ongoing decodes (“chunked prefill”), so admitting a request does not stall the batch.

continuous_batching.py (python)
def continuous_batch_step(active, waiting, max_batch):
    """One iteration of a continuous-batching scheduler."""
    # 1) drop sequences that finished during the previous step
    active = [s for s in active if not s.is_done()]

    # 2) admit new sequences into any free slots
    while waiting and len(active) < max_batch:
        new_seq = waiting.pop(0)
        new_seq.prefill()           # chunked-prefill is fine here too
        active.append(new_seq)

    # 3) run one decode step for every active sequence in parallel
    for seq in active:
        seq.advance_one_token()

    return active, waiting
[Interactive diagram: static vs continuous batching on the same workload, stepping through 60 decode steps with a running count of tokens emitted by each scheduler.]

Each row is a GPU batch slot. Each column is a decode step. Coloured cells emit a token; empty cells in the static version are wasted capacity while the batch waits for the longest sequence to finish.

Press play. The gaps in the static panel are the cost of waiting for the longest sequence to finish.
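The gap between the two schedulers can be counted with a small simulation. This sketch only tracks decode steps: prefill cost is ignored, every step is assumed to take the same time, and the output lengths are made up for the example.

```python
from collections import deque


def static_steps(lengths, max_batch):
    """Each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(lengths), max_batch):
        steps += max(lengths[i:i + max_batch])
    return steps


def continuous_steps(lengths, max_batch):
    """Freed slots are refilled from the queue before every decode step."""
    waiting = deque(lengths)
    active = []  # tokens still to generate, one entry per busy slot
    steps = 0
    while waiting or active:
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())
        active = [r - 1 for r in active]      # one decode step for every slot
        active = [r for r in active if r > 0]  # finished sequences leave
        steps += 1
    return steps


lengths = [2, 30, 4, 6, 28, 3, 5, 8]  # hypothetical output lengths, arrival order
slots = 4
tokens = sum(lengths)

for name, fn in (("static", static_steps), ("continuous", continuous_steps)):
    steps = fn(lengths, slots)
    print(f"{name:10s}: {steps:2d} steps, slot utilization {tokens / (steps * slots):.0%}")
```

For this workload the static scheduler needs 58 steps and the continuous one 30, because the second long sequence starts as soon as the first short one frees a slot instead of waiting for the whole first batch to drain.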

Continuous batching pairs naturally with paged KV. The scheduler can admit and evict sequences at block granularity without any data copy. Together with prefix sharing, the combination is what makes modern high-QPS LLM serving possible at all: most of the 2024–2026 throughput improvements at API providers come from scheduler work, not from new kernels.

More tokens per forward pass

10

Speculative decoding

Until this point every optimization has been about making one decode step cheaper. Speculative decoding asks a different question: can a single decode step produce more than one token? The serial constraint is real (token i conditions on token i-1), but it can be amortized by guessing.

The setup uses two models. A cheap draft model proposes k tokens by sampling autoregressively from the current prefix. The capable target model then evaluates the prefix followed by those k proposals in a single forward pass and checks each proposal in turn. If a proposal is accepted under the rejection-sampling rule, the chain advances. The first rejected token is replaced by a fresh sample from the target’s own distribution at that position, and the chain stops there. The critical property of the acceptance rule is that the output distribution is exactly the target’s. Quality is unchanged.
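The "exactly the target's distribution" property can be verified for a single position with toy numbers. The sketch assumes the standard rule (accept a drafted token x with probability min(1, p(x)/q(x)), otherwise resample from the normalized positive part of p − q); the three-token vocabulary and both distributions are invented for the check.

```python
def speculative_output_distribution(p, q):
    """Distribution of the token emitted at one position.

    p: target probabilities, q: draft probabilities (same vocabulary).
    """
    # Probability that the draft proposes x AND it is accepted: min(p(x), q(x)).
    accept = [qx * min(1.0, px / qx) if qx > 0 else 0.0 for px, qx in zip(p, q)]
    reject_mass = 1.0 - sum(accept)
    # On rejection, resample from the normalized positive part of p - q.
    residual = [max(px - qx, 0.0) for px, qx in zip(p, q)]
    norm = sum(residual)
    if norm == 0.0:  # p == q: nothing is ever rejected
        return accept
    return [a + reject_mass * r / norm for a, r in zip(accept, residual)]


p = [0.6, 0.3, 0.1]  # toy target distribution
q = [0.2, 0.5, 0.3]  # toy draft distribution
out = speculative_output_distribution(p, q)
print([round(x, 6) for x in out])
```

The emitted distribution matches p up to rounding even though the draft's q is quite different. A poor draft lowers the acceptance rate, not the quality.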

The expected number of tokens emitted per target forward pass depends on how often the draft and target agree. With a well-matched draft, acceptance rates of 60-80% are common on code and structured prose. If each proposal is accepted independently at that rate and k = 5, the expected yield is roughly 2.4 to 3.7 tokens per target call. That translates into a similar reduction in wall-clock latency since target forward passes dominate the cost. EAGLE-2 and Medusa go further by folding the draft into the target model itself, reusing its hidden states so the draft has almost no overhead and stays well-calibrated to the target.

speculative_decode.py (python)
def speculative_step(draft, target, prefix, k=5):
    """One round of draft-and-verify. Returns the tokens to append."""
    # 1) cheap model proposes k tokens autoregressively
    proposals = draft.sample(prefix, k)

    # 2) target model scores prefix + proposals in a SINGLE forward pass
    logits = target.forward(prefix + proposals)

    # 3) walk left to right, accept by rejection-sampling rule
    accepted = []
    for i, tok in enumerate(proposals):
        p_target = softmax(logits[len(prefix) + i - 1])
        # full draft distribution at this position (draft.probs is a
        # stand-in, like the other helpers in this sketch)
        p_draft  = draft.probs(prefix + proposals[:i])
        if random() < min(1.0, p_target[tok] / p_draft[tok]):
            accepted.append(tok)
        else:
            # resample one token from the adjusted residual distribution
            residual = (p_target - p_draft).clip(min=0)
            accepted.append(sample(residual / residual.sum()))
            break
    else:
        # every proposal survived: the same pass also yields one bonus token
        accepted.append(sample(softmax(logits[-1])))

    return accepted
# Each call to speculative_step does ONE target forward pass and returns
# between 1 and k+1 tokens depending on the acceptance rate.
[Interactive diagram: draft proposes, target verifies. A small draft model (~1B) proposes 5 tokens per round; a large target model (~70B) verifies them in one forward pass. Counters track rounds, tokens emitted, target passes, and tokens per call, and a bar compares target forward passes for speculative vs vanilla decoding. Legend: prefix · accepted · resampled.]

Target passes dominate wall clock. Draft cost is treated as negligible (EAGLE/Medusa reuse target hidden states).

Watch the small draft model run k forward passes in sequence, then the large target model verify all k tokens in a single pass. The tokens-per-call ratio at the top tells you how many tokens each expensive target call emits.
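The tokens-per-call ratio has a simple closed form if each proposal is assumed to be accepted independently with probability alpha, which is a simplification (real acceptance depends on context). Under that assumption the expected yield per target pass is (1 − alpha^(k+1)) / (1 − alpha). The sketch below compares the formula with a small Monte Carlo run; both function names are introduced for this example.

```python
import random


def expected_tokens(alpha: float, k: int) -> float:
    """Expected tokens per target pass with per-token acceptance rate alpha < 1."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def simulate(alpha: float, k: int, rounds: int = 100_000, seed: int = 0) -> float:
    rng = random.Random(seed)
    total = 0
    for _ in range(rounds):
        accepted = 0
        while accepted < k and rng.random() < alpha:
            accepted += 1
        # +1 for the resampled token after a rejection, or the bonus token
        # when all k proposals are accepted.
        total += accepted + 1
    return total / rounds


for alpha in (0.6, 0.7, 0.8):
    print(f"alpha = {alpha}: formula {expected_tokens(alpha, 5):.2f}, "
          f"simulated {simulate(alpha, 5):.2f}")
```

The yield saturates as k grows: with alpha = 0.8 it can never exceed 1 / (1 − 0.8) = 5 tokens per call, so very long drafts mostly add draft cost.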

Speculative decoding shines where latency dominates: chat completions, code assistants, anywhere the user is waiting on the first few hundred tokens. It composes cleanly with everything above. The draft and target each keep a KV cache for the verified prefix and discard the entries for rejected tokens. The target’s forward pass uses Flash Attention. Continuous batching admits new sequences alongside speculative ones. The technique does not save FLOPs in absolute terms (the target still has to evaluate every position), but it cuts wall-clock by trading compute for serial steps.

Closing

11

Where this leaves us

Across these eleven sections we walked one thread. A naive decoder does too much work, and a sequence of optimizations each removes a different kind of waste. The KV cache cut compute but cost memory. GQA and MLA brought it back down. Flash Attention rearranged HBM traffic so the kernel could run close to peak. PagedAttention turned the cache into something a multi-sequence server could actually manage. Continuous batching kept the slots full. Speculative decoding amortized the serial constraint. None of these is the main optimization, which is why vLLM, TensorRT-LLM, and SGLang ship them all at once. The 50x to 500x gap between naive PyTorch and a tuned deployment comes from the composition, not from any one trick.

Writing this page was, honestly, an adventure for me. I built it as a way of arranging my own notes so they would still make sense to me six months from now, and I shared it because the visual form happens to be easier to follow than the way I keep these ideas in my head. LLM inference optimization is one of the corners of this field that genuinely excites me. The pace is wild, and the work is being done by people who think carefully about hardware, numerics, and scheduling all at once. Reading their papers, then trying to redraw what they did in a form I could explain to a friend, that is the part I enjoy most.

I do not have a strong prediction about where this goes next. Kernels will keep approaching peak. Attention will keep being rewritten for each new GPU generation. The boundary between “model” and “serving system” will keep blurring as drafts get folded into targets, attention gets fused with norms, and KV management moves into the model itself. Long contexts will make the cache problem worse before it gets better. I will keep updating this page as new pieces land, because the pace does not seem to be slowing.