When I read “the cat sat on the mat,” my eyes don’t weight every token equally. “Cat” and “mat” pop; “the” and “on” fade. Whatever pre-attentive process the visual cortex runs, it’s content-addressable: I’m looking for nouns, and the nouns light up. Modern transformers do the same trick, but as a differentiable matrix multiply, on every layer, in parallel, on hardware that loves matrix multiplies. This is the long-form theory tour: the original derivation, the variants that mattered, and the modern frontier as of late 2025.
1. Attention as a soft lookup
Forget the formula for a second. Imagine a Python dictionary. The keys are “what’s findable”, the values are payloads, the query is what you hand the dict. A regular lookup is hard: exactly one key matches, exactly one value comes back. Now soften it. Instead of one match, give every key a similarity score against the query, normalize the scores into weights, and return the weighted average of values. That’s attention.
Concretely, with $N$ keys $k_1, \ldots, k_N \in \mathbb{R}^d$, matching values $v_1, \ldots, v_N$, and a query $q \in \mathbb{R}^d$:
$$\text{output} = \sum_{i=1}^N \alpha_i \, v_i, \qquad \alpha_i = \frac{\exp(\text{score}(q, k_i))}{\sum_j \exp(\text{score}(q, k_j))}$$
Softmax forces the $\alpha_i$ to be non-negative and sum to one. The output is a convex combination of values, biased toward the values whose keys most resemble the query. If one key dominates, the softmax saturates and the output approaches that single value. If all keys are similar, you get a near-uniform mean. The whole spectrum is differentiable. In a transformer, $K$, $V$, and $Q$ are all learned linear projections of the same hidden states (self-attention) or of two different sequences (cross-attention).
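The soft lookup above fits in a few lines of NumPy. This is a minimal sketch, assuming a plain dot product as the score function; the name `soft_lookup` and the toy keys and values are made up for illustration.

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Return the softmax-weighted average of values, scored by dot product."""
    scores = keys @ query                  # one similarity score per key
    scores = scores - scores.max()         # subtract the max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()      # non-negative, sums to one
    return weights @ values, weights

keys = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
values = np.array([[10.0], [20.0], [30.0]])
out, w = soft_lookup(np.array([5.0, 0.0]), keys, values)
```

Keys 0 and 2 score equally against this query, so the output sits halfway between their values; a query that strongly favours one key pushes the output toward that key's value alone.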
The soft version buys three things over a hard argmax: two practical, one structural.
First, gradients. Argmax is a step function: tiny query changes either flip the choice or do nothing. The gradient is zero almost everywhere and undefined where it isn’t. Softmax is smooth: every key contributes signal proportional to its weight, so the network learns where to look. Gumbel-softmax and sparsemax give sparser variants, but you still need a soft path during training.
Second, information bandwidth. Real signals live in superposition. The token “bank” needs to attend to “river” and “interest rate” simultaneously to disambiguate. Argmax forces a single winner; soft lookup lets the head pull evidence from multiple sources at once. That’s how a single attention layer simultaneously does retrieval, coreference, and syntactic agreement.
Third, expressivity per layer. With $H$ heads each averaging over $N$ tokens, an attention layer expresses a huge family of small computation graphs over the context. Hard top-1 per head would collapse most of that.
💡Why use a weighted sum rather than picking the single best-matching key? What does the soft version buy us that argmax does not?
2. Scaled dot-product attention
The similarity function is dot product: $\text{score}(q, k) = q \cdot k$. Why dot product, not cosine or L2? It’s fast — one matmul $QK^\top$ scores every pair. It’s expressive — bilinear, capturing both direction (which features matter) and magnitude (how confident this comparison is). And it composes well with softmax: small linear gaps become exponentially-large probability gaps, exactly the contrast you want.
Now the famous fix. If $q, k \in \mathbb{R}^d$ have zero-mean unit-variance components, $q \cdot k = \sum_{i=1}^d q_i k_i$ is a sum of $d$ approximately-independent unit-variance terms, so $\text{Var}(q \cdot k) \approx d$ and the standard deviation grows as $\sqrt{d}$. As $d$ grows, dot products spread out, the softmax saturates onto whichever key has the largest score, and gradients through the rest vanish. The fix:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
where $d_k$ is the per-head key dimension, typically 64 or 128 in modern models (LLaMA-3 70B has $d_\text{model} = 8192$ with 64 heads, so $d_k = 128$). Dividing by $\sqrt{d_k}$ keeps the score variance $\approx 1$ regardless of head dimension.
With $d_k = 128$ and unit-variance $q, k$, the raw dot product has standard deviation $\sqrt{128} \approx 11.3$. The maximum of $N$ such Gaussians (Fisher-Tippett) is around $\sqrt{2 \ln N} \cdot \sigma$; for $N = 4096$ that’s roughly $46$. The gap to the runner-up is on the order of $46 - 40 = 6$, which after softmax means an $e^6 \approx 400\times$ probability ratio: the distribution collapses onto effectively one key.
Now the gradient. Softmax saturates: $\partial \text{softmax}_i / \partial \text{score}_j = \alpha_i (\delta_{ij} - \alpha_j)$. When one $\alpha_i \to 1$, every off-diagonal gradient vanishes. Worse, $\alpha_i (1 - \alpha_i) \to 0$ as $\alpha_i \to 1$, so even the gradient through the winner dies. The whole layer is dead.
Dividing by $\sqrt{d_k}$ pulls score std back to $1$, the largest score sits near 4, the gap is around 0.5, the softmax is in the soft regime, and gradients flow.
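A quick simulation makes the saturation argument concrete. This is a sketch under the same assumptions as the derivation (independent unit-variance Gaussian components); the sizes, the seed, and the helper `mean_max_weight` are illustrative choices, not taken from any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_keys, n_queries = 128, 4096, 256
Q = rng.standard_normal((n_queries, d_k))
K = rng.standard_normal((n_keys, d_k))

def mean_max_weight(scores):
    """Average, over queries, of the largest softmax weight in each row."""
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights.max(axis=1).mean()

raw = Q @ K.T                      # unscaled dot products
scaled = raw / np.sqrt(d_k)        # scaled dot products

raw_std, scaled_std = raw.std(), scaled.std()
raw_peak, scaled_peak = mean_max_weight(raw), mean_max_weight(scaled)
```

`raw_std` should land near $\sqrt{128}$ and `scaled_std` near 1. The peak weights tell the rest of the story: without scaling, a single key typically takes most of the probability mass, while with scaling the largest weight stays small and many keys contribute.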
💡What would happen if we used $QK^\top$ without the $\sqrt{d}$ scaling? Walk through the consequences in concrete numbers for a 128-dim head.
A worked example, $N = 3$, $d_k = 2$
Concrete numbers anchor the algebra. Take three tokens with
$$ Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}, \quad K = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}, \quad V = \begin{bmatrix} 2 & 0 \\ 0 & 2 \\ 1 & 1 \end{bmatrix}. $$
Compute $S = QK^\top$ row by row. Row 0 (query $[1, 0]$) dotted with each key gives $[1, 0, 1]$. Row 1 ($[0, 1]$) gives $[0, 1, 1]$. Row 2 ($[1, 1]$) gives $[1, 1, 2]$. So
$$ S = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \\ 1 & 1 & 2 \end{bmatrix}. $$
Take row 0 the rest of the way. Scale by $1/\sqrt{d_k} = 1/\sqrt{2}$ and exponentiate: $\exp([1, 0, 1] / \sqrt{2}) \approx [2.03, 1.00, 2.03]$. Sum $\approx 5.06$. Softmax weights $\approx [0.40, 0.20, 0.40]$. Multiply by $V$:
$$ \text{out}_0 = 0.40 \cdot [2,0] + 0.20 \cdot [0,2] + 0.40 \cdot [1,1] = [1.20, 0.80]. $$
Notice the head pulled mostly from rows 0 and 2 (the keys aligned with query $[1, 0]$) and partly from row 1, blending the values into $[1.20, 0.80]$. This single row is the entire mechanism: every transformer is doing this in parallel across batch, head, and query position.
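The arithmetic can be checked in a few lines of NumPy. Note that the weights $[0.40, 0.20, 0.40]$ come out when the raw scores are divided by $\sqrt{d_k} = \sqrt{2}$ before the softmax, which is what this sketch does.

```python
import numpy as np

Q = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
K = Q.copy()
V = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])

S = Q @ K.T                                   # raw scores, one row per query
scaled = S / np.sqrt(Q.shape[1])              # divide by sqrt(d_k) = sqrt(2)
weights = np.exp(scaled)
weights /= weights.sum(axis=1, keepdims=True) # row-wise softmax
out = weights @ V                             # one output row per query
```

Row 0 of `weights` rounds to `[0.40, 0.20, 0.40]` and row 0 of `out` to `[1.20, 0.80]`, matching the hand calculation.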
The softmax is the part that turns “scores” into “weights that sum to 1,” and its temperature controls whether attention is sharp (one key dominates) or diffuse (uniform mean). Drag the slider:
[Interactive visualization: softmax temperature slider]
QK-norm and cosine attention
The variance argument fixes the expected score magnitude at init. It does not fix what happens during training when $\|Q\|$ and $\|K\|$ drift. With sufficient training, query and key norms can grow unboundedly — a small fraction of heads end up producing scores in the hundreds, softmax saturates again, and FP8 / INT8 inference quantizes those heads to garbage.
The fix: L2-normalize $Q$ and $K$ per-head before the dot product, so the score is exactly cosine similarity bounded in $[-1, 1]$, then multiply by a learned per-head temperature $\tau$. Swin Transformer V2 introduced this for vision; recent Gemini variants and many production language models use it. The cost is one extra normalize per head per layer; the win is unconditional numerical stability across training and across precisions. If you’re planning FP8 attention from day one, QK-norm is non-optional.
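Here is a minimal sketch of cosine attention scores for a single head. The temperature is called `tau` and fixed at 10 purely for illustration; in a real model it would be a learned per-head parameter, and some implementations normalize with RMSNorm rather than a plain L2 norm.

```python
import numpy as np

def qk_norm_scores(Q, K, tau):
    """Cosine-similarity attention scores scaled by a temperature tau."""
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)   # unit-length queries
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)   # unit-length keys
    return tau * (Qn @ Kn.T)

rng = np.random.default_rng(0)
Q = 100.0 * rng.standard_normal((4, 64))   # deliberately huge norms
K = 100.0 * rng.standard_normal((6, 64))
scores = qk_norm_scores(Q, K, tau=10.0)
```

However large the raw norms get, every score stays inside $[-\tau, \tau]$, and rescaling $Q$ or $K$ by any positive constant leaves the scores unchanged.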
3. Multi-head attention
Project $D$ hidden dimensions down to $H$ groups of $d_k$ each, run attention independently in every group, concatenate, project back up:
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H)\, W^O, \qquad \text{head}_h = \text{Attention}(X W^Q_h,\, X W^K_h,\, X W^V_h)$$
with $d_k = D/H$. Why not one big head with the full $D$? Because one head expresses one attention pattern per token. With $H$ heads, each token gets $H$ independent reads, attending to different things in different subspaces. The split is also free in FLOPs: the score matmul costs $H \cdot N^2 \cdot d_k = N^2 D$ no matter how $D$ is divided among heads.
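The project, split, attend, concatenate recipe can be sketched as follows. This is a bare-bones version under simplifying assumptions: no masking, no biases, random weights, and hypothetical names (`multi_head_attention`, `split_heads`).

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, H):
    N, D = X.shape
    d_k = D // H

    def split_heads(M):
        # Split the D features into H contiguous blocks of d_k: (H, N, d_k).
        return M.reshape(N, H, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ Wq), split_heads(X @ Wk), split_heads(X @ Wv)
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))   # (H, N, N)
    heads = weights @ V                                           # (H, N, d_k)
    concat = heads.transpose(1, 0, 2).reshape(N, D)               # back to (N, D)
    return concat @ Wo, weights

rng = np.random.default_rng(0)
N, D, H = 5, 16, 4
X = rng.standard_normal((N, D))
Wq, Wk, Wv, Wo = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))
out, weights = multi_head_attention(X, Wq, Wk, Wv, Wo, H)
```

Each of the `H` slices of `weights` is an independent $N \times N$ attention pattern, which is the "independent reads" point in code form.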
Visually, the $D$ feature dimensions get partitioned into $H$ contiguous blocks of $d_k = D/H$ each. Try increasing $H$: same total parameters, but each head gets less to work with:
[Interactive visualization: partitioning $D$ features into $H$ heads]
Induction heads — what a head actually does
Olsson et al. (Anthropic, 2022) gave the cleanest demonstration that heads specialise to nameable circuits. Take a 2-layer attention-only model. Train it on language. Inside, you find a recurring two-head circuit:
- A previous-token head in layer 1 whose attention pattern is “attend to position $i-1$.” Its output gets written into the residual stream of position $i$, effectively copying token $i-1$’s embedding into position $i$ as a side-channel.
- An induction head in layer 2 whose query at position $i$ attends to positions where the layer-1 side-channel matches the current token. So if the current token is $A$, and earlier in the sequence the bigram $AB$ appeared, the layer-2 head finds that earlier $A$ via the side-channel and copies $B$, predicting the next token.
This is the in-context-learning copy primitive, mechanically. The same circuit shows up in real transformers and explains why even small models can learn novel made-up vocabularies just from context: they implement “if I saw $A$ followed by $B$ before, predict $B$ after $A$ now” in literal weight space. Once you’ve internalised the induction-head story, “attention as soft retrieval” stops being a metaphor.
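The behaviour the two-head circuit implements can be written as a plain lookup. This is a hard-attention caricature for intuition only (a real induction head does this softly, through learned query and key projections), and `induction_predict` is a made-up name.

```python
def induction_predict(tokens):
    """Predict the next token by copying what followed the most recent
    earlier occurrence of the current token (None if there is none)."""
    current = tokens[-1]
    # Scan earlier positions from most recent to oldest.
    for j in range(len(tokens) - 2, -1, -1):
        if tokens[j] == current:
            return tokens[j + 1]    # the token that came after the match
    return None

context = ["the", "blorp", "zat", "sat", "the", "blorp"]
prediction = induction_predict(context)
```

Here the made-up word "blorp" was followed by "zat" earlier in the context, so "zat" is the prediction: the rule is "if I saw A followed by B before, predict B after A now", applied to tokens that never appeared in training.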
Other named circuits the literature has found: name-mover heads (move proper-noun representations forward), successor heads (output the next item in an enumerated list), copy-suppression heads (downweight tokens that already appeared). The catalogue keeps growing; the point is that headcount is not a fudge factor — each head is potentially a distinct piece of mechanism.
More heads are not always better at fixed $D$, and the reason is the variance argument from section 2 plus a capacity argument.
As $d_k$ shrinks, each head’s attention pattern gets noisier — fewer dimensions to compare with, less discriminative similarity. At $d_k = 32$ you’re computing similarity from only 32 features per pair, roughly equivalent to a 5-bit hash. Empirically, models below $d_k = 64$ start losing quality unless aggressively dropout-regularized.
The other direction matters too: at $d_k = 256$ you’re spending capacity on within-head precision rather than across-head diversity. The classic compromise — and what almost every modern model converges to — is $d_k \in \{64, 128\}$. LLaMA-3 70B, Mixtral 8x22B, GPT-3: all $d_k = 128$. The convergence is not coincidence.
💡Is more heads always better? You're scaling a 7B-param decoder with fixed $D = 4096$. Should you go from 32 heads ($d_k = 128$) to 64 ($d_k = 64$) to 128 ($d_k = 32$)?
4. Position encoding
Self-attention is permutation-equivariant: shuffle the input, the output shuffles. Language is not a bag of words, so position must be injected.
The original Vaswani trick was sinusoidal: $\text{PE}(p, 2i) = \sin(p / 10000^{2i/d})$ and similar with cosine, added to the embedding. For any offset $\Delta p$, $\text{PE}(p + \Delta p)$ is a linear function of $\text{PE}(p)$, so relative positions are recoverable. In practice, networks still learn to use absolute positions, so extrapolation rarely works.
Learned absolute (BERT, GPT-2): just an nn.Embedding keyed on position. Slightly more expressive, completely incapable of going beyond training context — position 4097 was never optimized.
RoPE won. Instead of adding a position vector, rotate the query and key by a position-dependent angle. Pair up dimensions $(2i, 2i+1)$ and rotate each pair by $\theta_i \cdot p$ where $\theta_i = 10000^{-2i/d}$. The dot product $q_p^\top k_{p'}$ then depends only on $p - p'$, not absolute positions. No extra parameters, naturally relative, decent extrapolation with NTK or YaRN scaling tricks.
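The relative-position property is easy to verify numerically. The sketch below assumes consecutive features $(2i, 2i+1)$ form each rotating pair; some implementations pair feature $i$ with $i + d/2$ instead, which changes the layout but not the property. The function name `rope` is illustrative.

```python
import numpy as np

def rope(x, p, base=10000.0):
    """Rotate consecutive feature pairs of x by angles theta_i * p."""
    d = x.shape[-1]
    theta = base ** (-2.0 * np.arange(d // 2) / d)   # theta_i = base^(-2i/d)
    angle = theta * p
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * np.cos(angle) - x_odd * np.sin(angle)
    out[1::2] = x_even * np.sin(angle) + x_odd * np.cos(angle)
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)
score_a = rope(q, 5) @ rope(k, 2)        # positions 5 and 2, offset 3
score_b = rope(q, 105) @ rope(k, 102)    # same offset, shifted by 100
```

`score_a` and `score_b` agree to floating-point precision, because only the offset between the two positions enters the dot product. Rotations also leave vector norms untouched.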
Each feature pair lives on its own little 2D plane and rotates at its own frequency. Drag the position slider — the leftmost pair (small $i$, fast frequency) sweeps quickly; the rightmost (large $i$, slow frequency) barely moves. That’s how RoPE encodes both fine-grained adjacent-token order and coarse “where in the document am I” with the same set of basis frequencies:
[Interactive visualization: RoPE feature pairs rotating with position]
ALiBi drops position encoding entirely and adds a linear distance penalty $-m \cdot |i-j|$ to attention scores. Extrapolates extremely well — train at 2K, run at 16K — but bakes in a strong locality prior that may not match the task.
Modern frontier LLMs almost universally use RoPE. ALiBi survives in niche models. Sinusoidal is dead.
RoPE’s relative-position invariance is exact in the math but lossy in practice. The frequency basis is finite: $\theta_i = b^{-2i/d}$ for $b = 10000$, and with $d = 128$ the lowest-frequency pair has a wavelength of roughly 54,000 positions, so inside a 4K or 8K training context it covers only part of one rotation. Run inference at 8x the training length and the slow pairs reach phase angles the model has never experienced. OOD phase angles break attention. There’s also an absolute-position leak: even though the dot product theoretically depends only on $p - p'$, the model picks up signals tied to absolute position during training (the BOS-is-always-first kind of cue), and stretching shifts those signals.
Context extension, in detail
This is worth pinning down because it shows up in every model card released after 2023. Five techniques, all operating on RoPE.
Position Interpolation (Chen et al., 2023). The crudest: rescale every position by $1/s$, where $s = L_\text{new} / L_\text{train}$ is the extension factor, before applying RoPE. So position 16000 in a 4K-trained model becomes “position 4000” mathematically. Every frequency is uniformly compressed. Works after a few hundred steps of fine-tuning at the longer context. The cost is brutal on local order: high-frequency components, which encode “is this token adjacent to this other token,” are also compressed, so the model’s local syntax gets noisier.
NTK-aware scaling (bloc97 on Reddit, 2023, later analysed by Su). The clean fix: don’t rescale positions, rescale the base of the RoPE frequencies. If the original base is $b = 10000$ and you want to extend by factor $s$, set $b' = b \cdot s^{d/(d-2)}$. This leaves the highest-frequency dimension nearly unchanged (so adjacent-token discrimination survives) and stretches the lowest-frequency dimensions the most (so the model’s never-seen-this-phase-angle problem is concentrated where the model is least sensitive). Ships zero-shot on many models.
YaRN (Peng et al., 2023). The current state of the art in pure RoPE extension. It splits the RoPE frequencies into three bands by wavelength relative to training context: high-frequency (wavelength much shorter than $L_\text{train}$) is left alone, low-frequency (longer than $L_\text{train}$) is interpolated, mid-frequency is interpolated partially with a smooth ramp. On top, it adds a single temperature term to the attention logits, necessary because longer context means more keys for the softmax to spread over, and without compensating the entropy you get systematically softer attention.
LongRoPE (Microsoft, 2024). Treats the per-frequency rescale factors as $d/2$ free parameters and searches for them via evolutionary optimisation on a calibration set. Empirically beats hand-derived YaRN at extreme extensions (256K+) because the optimal rescale isn’t a simple closed form when the model’s training data has structural quirks.
Frequency basis change (newer models like Llama-3.1, Qwen-2.5). Rather than fix at $b = 10000$, train at $b = 500{,}000$ or higher from the start. Higher base means lower-frequency basis, which means longer wavelengths fit inside training context naturally, which means less extrapolation pressure. This is the cleanest move: pay for it once at pretraining instead of papering over it at inference.
Position Interpolation maps every position $p \to p / s$ uniformly. So the dimension-pair with frequency $\theta_i$ now sees rotation angle $\theta_i p / s$ where it used to see $\theta_i p$. That’s identical to keeping positions but scaling frequency by $1/s$.
Now look at the highest-frequency pair. With $\theta_i = b^{-2i/d}$ that is $i = 0$, where $\theta_0 = 1$ rad per position regardless of $b$ and $d$. After PI with $s = 8$, it’s $1/8 = 0.125$ rad per position. To rotate this pair through one radian (meaningfully change its dot-product contribution) you now need $8$ positions, where you used to need $1$. By contrast the slowest pair, $\theta_{d/2-1} = b^{-(d-2)/d} \approx 1.15 \times 10^{-4}$ rad per position for $d = 128$ and $b = 10000$, barely moves between neighbours either way.
That’s the breakage. The high-frequency component’s job is encoding immediate order: “is this the very next token, or the one after, or three back?” Adjacent-position differences in raw RoPE are around $1$ rad in this pair, easily distinguishable ($\cos 1 \approx 0.54$). After PI, they’re $0.125$ rad ($\cos 0.125 \approx 0.992$), a change of under 1% in that pair’s contribution, which is hard to resolve once you’re in BF16. The model can no longer cleanly distinguish “two tokens apart” from “three tokens apart” from “four tokens apart.”
NTK-aware scaling is the principled fix because the high-frequency pair stays at its original $\theta_0$. You only stretch the low frequencies, which were under-resolved for long context anyway and so have headroom to be stretched without degradation. The asymmetry is the entire point: context extension is a high-frequency-preservation problem, not a position-rescaling problem.
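The asymmetry shows up directly in the per-pair frequencies. The sketch below uses the convention $\theta_i = b^{-2i/d}$, under which $i = 0$ is the fastest pair, and the commonly cited NTK-aware base $b' = b \cdot s^{d/(d-2)}$; `rope_freqs` is a hypothetical helper.

```python
def rope_freqs(d, base=10000.0):
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1 (i = 0 is the fastest pair)."""
    return [base ** (-2.0 * i / d) for i in range(d // 2)]

d, s = 128, 8.0
original = rope_freqs(d)

# Position Interpolation: every pair slows down by the same factor s.
pi = [theta / s for theta in original]

# NTK-aware scaling: enlarge the base instead of rescaling positions.
ntk = rope_freqs(d, base=10000.0 * s ** (d / (d - 2)))

fastest = (original[0], pi[0], ntk[0])        # (1.0, 0.125, 1.0)
slowest = (original[-1], pi[-1], ntk[-1])
```

PI slows the fastest pair from 1 to 0.125 rad per position, while NTK-aware scaling leaves it at 1. For the slowest pair the two methods agree: both stretch it by the full factor $s$.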
💡Why does naïve linear interpolation of RoPE positions degrade local-order modeling? Walk through what it does to high-frequency components.
5. Pre-LN vs Post-LN
The original transformer was Post-LN: $x_{l+1} = \text{LN}(x_l + \text{Sublayer}(x_l))$. Modern transformers (every major LLM since GPT-2) are Pre-LN: $x_{l+1} = x_l + \text{Sublayer}(\text{LN}(x_l))$.
In Post-LN the residual stream travels through a LayerNorm before being passed to the next layer. In Pre-LN the residual is untouched; only the sublayer’s input is normalized. This matters because of gradients. Pre-LN’s residual path is the identity, Jacobian $I$, so gradients flow back unchanged regardless of depth. Post-LN’s path goes through LayerNorm on every layer; stack 96 of those and Jacobians compound into exploding or vanishing gradients.
Practically: Pre-LN trains stably from random init with no warmup or trivial warmup. Post-LN needs warmup — the original used 4000 steps for a base model. Without it, gradients blow up in the first hundred steps and the model diverges. Late in training Post-LN actually has slightly higher expressivity (it forces the network to actively decide what to keep, rather than passively accumulating) showing up as ~0.05–0.15 PPL advantage — real but invisible compared to the 5%+ gap a bad warmup schedule causes. Pre-LN is strictly better unless you have unlimited compute and tuning patience.
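The two wirings differ by where one function call sits. This is a sketch assuming a LayerNorm without learned gain or bias and an arbitrary `sublayer` callable; the "silent" sublayer below is a contrived probe that isolates the residual path.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def post_ln_block(x, sublayer):
    # Original transformer: normalize after the residual add.
    return layer_norm(x + sublayer(x))

def pre_ln_block(x, sublayer):
    # Modern default: normalize only the sublayer input; the residual is untouched.
    return x + sublayer(layer_norm(x))

x = np.array([[1.0, 2.0, 3.0, 10.0]])
silent = lambda h: np.zeros_like(h)      # a sublayer that contributes nothing
pre_out = pre_ln_block(x, silent)        # equals x: the residual path is the identity
post_out = post_ln_block(x, silent)      # equals layer_norm(x): the stream was rescaled
```

With the sublayer silenced, Pre-LN returns its input exactly, while Post-LN still transforms the stream. Stack many layers and that is the difference between an identity path and a chain of LayerNorm Jacobians.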
Stabilising attention training at scale
Pre-LN gets you most of the way; the rest is a list of known failure modes and the patches that work.
Logit explosions. Even with Pre-LN, attention logits can drift to $\pm 100$ during training, especially in late layers and especially with mixed precision. Once a logit hits BF16’s saturation regime, the gradient through that head is functionally zero, the head dies, and you can lose double-digit perplexity. The fix in recent Gemini and PaLM-2 is logit soft-cap: replace the raw logit $s$ with $c \cdot \tanh(s / c)$ for a fixed cap $c$ in the tens (Gemma 2, for example, uses $c = 50$ on attention logits). Bounded, smooth, costless, and preserves the soft regime when you most need it.
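Soft-capping is a one-liner. The sketch assumes the usual $c \cdot \tanh(s / c)$ form; the default `cap=50.0` is an illustrative choice, not a recommendation.

```python
import numpy as np

def soft_cap(logits, cap=50.0):
    """Smoothly bound logits to [-cap, cap]; near-identity for |logit| << cap."""
    return cap * np.tanh(logits / cap)

small = soft_cap(np.array([-1.0, 0.0, 1.0]))        # almost unchanged
huge = soft_cap(np.array([-1e4, 300.0, 1e4]))       # squashed to at most +/- cap
```

Small logits pass through nearly untouched, so the soft regime is preserved; runaway logits are squashed smoothly instead of clipped, so the gradient shrinks gradually rather than hitting a hard zero.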
Query/key norm growth. Even bounded logits don’t help if $\|Q\|, \|K\|$ themselves grow unbounded — the per-token magnitude carries through to downstream layers. QK-norm (section 2) plus a learned temperature is the structural answer. Most production models past 2024 use it.
Mu-Transfer / muP (Yang et al., 2022). Tangentially related but worth knowing. Standard Adam-style training has hyperparameters (learning rate, init scale) that depend on width, so the LR you tuned at 100M parameters is wrong at 100B. muP reparameterises so the optimal hyperparameters are width-invariant: tune at 100M, ship at 100B unchanged. The transformation involves scaling attention logits by $1/d$ rather than $1/\sqrt{d}$ (which seems to contradict our variance derivation, but is right under muP’s specific init regime). Production-relevant because it makes large-model HP search affordable.
Z-loss on output logits. Borrowed from PaLM. Add a small penalty $10^{-4} \cdot \log^2 Z$ where $Z$ is the softmax normaliser of the output projection. Prevents the output projection logits from drifting to extreme values during long pretraining runs. Cheap insurance.
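As a sketch, the auxiliary term looks like this, assuming the penalty is the squared log-normaliser averaged over positions with coefficient $10^{-4}$. The function name `z_loss` is illustrative, and the term would be added to the usual cross-entropy loss.

```python
import numpy as np

def z_loss(logits, coeff=1e-4):
    """coeff * (log Z)^2 averaged over rows, with Z = sum_j exp(logit_j)."""
    m = logits.max(axis=-1, keepdims=True)
    # Stable log-sum-exp: log Z = m + log(sum(exp(logits - m))).
    log_z = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
    return coeff * np.mean(log_z ** 2)

centered = np.zeros((2, 8))          # log Z = log 8 for every row
drifted = centered + 20.0            # same softmax, much larger log Z
penalty_centered = z_loss(centered)
penalty_drifted = z_loss(drifted)
```

Both inputs produce identical softmax probabilities, but the drifted logits pay a far larger penalty. That is the mechanism: the term pulls $\log Z$ toward zero without changing what the model predicts.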
These are individually small. Compound over a 1T-token run and any one of them can be the difference between 7.2 and 6.9 final loss.
6. Causal, encoder, cross
The same mechanism wears three masks. Bidirectional self-attention (BERT, T5 encoder, ViT): every token attends everywhere; mask is all-ones. Causal self-attention (GPT, LLaMA, every decoder-only LLM): token $i$ only attends to positions $\leq i$; mask is lower-triangular, upper triangle gets $-\infty$. Cross-attention (encoder-decoder, T5, Whisper): queries from the decoder, keys and values from the encoder, no mask — the encoder represents fully-observed input.
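The masks differ only in which score cells survive before the softmax. Here is a minimal sketch with uniform scores so the effect of the mask is easy to read off; `causal_mask` and `masked_softmax` are illustrative names.

```python
import numpy as np

def causal_mask(n):
    """True where query i may attend to key j (j <= i)."""
    return np.tril(np.ones((n, n), dtype=bool))

def masked_softmax(scores, mask):
    scores = np.where(mask, scores, -np.inf)          # blocked cells get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                          # exp(-inf) is exactly 0
    return weights / weights.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))
causal = masked_softmax(scores, causal_mask(4))
bidirectional = masked_softmax(scores, np.ones((4, 4), dtype=bool))
```

With the causal mask, row $i$ spreads its weight evenly over positions $0..i$ and puts exactly zero on the future. With the all-ones mask, every row is uniform over all four positions.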
7. The KV cache
This is the most consequential implementation detail in modern LLM serving.
Naively, generating token $t$ requires running attention over positions $1, \ldots, t$, recomputing every previous $K$ and $V$. That’s $O(L^2)$ over a generation. But $K$ and $V$ for past tokens never change: they depend only on past hidden states, which are fixed. So cache them. After generating token $t$, store the new $(k_t, v_t)$ in a per-request buffer; for token $t+1$, project just the new token’s query, key, and value, append the key and value to the cache, and attend over the whole cache. Inference becomes $O(L)$ per token.
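A single-layer, single-head sketch shows that caching changes the cost but not the answer. The hidden states here are random stand-ins that are assumed fixed once computed, which is what causal masking guarantees in a real decoder.

```python
import numpy as np

def attend(q, K, V):
    """Single-query attention over all cached keys and values."""
    s = K @ q / np.sqrt(q.shape[0])
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
hidden = rng.standard_normal((6, d))      # hidden states for 6 tokens

K_cache, V_cache, cached_out = [], [], []
for h in hidden:                          # one decoding step per token
    K_cache.append(h @ Wk)                # project only the new token
    V_cache.append(h @ Wv)
    cached_out.append(attend(h @ Wq, np.array(K_cache), np.array(V_cache)))

# Naive baseline: re-project the whole prefix at the last step.
naive_last = attend(hidden[-1] @ Wq, hidden @ Wk, hidden @ Wv)
```

The last cached output matches the naive recomputation, but each step performed one key and one value projection instead of one per prefix token.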
Memory cost per request, per layer: $2 \cdot L \cdot \text{num\_kv\_heads} \cdot d_k$ floats. For LLaMA-3 70B (80 layers, 8 KV heads, $d_k = 128$, FP16), per-token cost is $2 \cdot 80 \cdot 8 \cdot 128 \cdot 2 \text{ bytes} = 327{,}680 \text{ bytes} \approx 320$ KB. At 32K context, 10.5 GB per request. At 128K, 42 GB. H100 HBM is 80 GB, so weights alone (140 GB FP16, 70 GB FP8) plus a couple long-context users saturates the chip. This is why long-context serving is hard. Every architectural variant after this is, in some sense, a strategy for dealing with the KV cache.
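The arithmetic is worth having as a function you can point at other configurations. This sketch treats "32K" as 32,000 tokens and a GB as $10^9$ bytes, which is the rounding that reproduces the figures above; `kv_bytes_per_token` is a hypothetical helper.

```python
def kv_bytes_per_token(layers, kv_heads, d_k, bytes_per_value=2):
    """K and V (the factor 2) for every layer and KV head, for one token."""
    return 2 * layers * kv_heads * d_k * bytes_per_value

per_token = kv_bytes_per_token(layers=80, kv_heads=8, d_k=128)   # 327_680 bytes
at_32k = per_token * 32_000 / 1e9      # about 10.5 GB
at_128k = per_token * 128_000 / 1e9    # about 41.9 GB
```

Swapping `kv_heads=8` for `kv_heads=64` (full multi-head attention at the same width) multiplies every figure by 8, which is the motivation for the variants in the next section.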
Paged attention
The naïve serving allocator gives each request a contiguous KV buffer of size $L_\text{max}$ — say 128K. A request that uses 2K wastes 126K. Across hundreds of concurrent requests of varying length, this fragmentation eats most of HBM and serving throughput collapses to a fraction of the model’s potential.
vLLM’s PagedAttention (Kwon et al., 2023) lifts the OS virtual-memory trick. Allocate KV in fixed-size pages, typically 16 tokens × all layers × all heads. Each request maintains a small page table: “my logical token positions $[0..15], [16..31], [32..47], \ldots$ live in physical pages $[\text{p\_17}, \text{p\_42}, \text{p\_3}, ...]$.” Attention reads gather across pages via an indirection, costing one extra index lookup per page boundary. The win: zero fragmentation, so HBM utilisation goes from ~30% to ~95% in mixed-length workloads. Throughput goes up by the same factor.
PagedAttention also unlocks prefix sharing. If 10 requests share the same 4K-token system prompt, they can all point their first 256 page-table entries at the same physical pages, paying once for that prefix’s KV instead of 10×. Production deployments of long system prompts (think agent harnesses with tool definitions) routinely save 50%+ KV bandwidth this way.
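The bookkeeping behind paging and prefix sharing fits in a small class. This is a toy sketch, not vLLM's implementation: the class `PagedKVAllocator` and its methods are invented for illustration, the KV tensors themselves are omitted, and there is no reference counting or copy-on-write.

```python
PAGE_SIZE = 16  # tokens per page, as in the text

class PagedKVAllocator:
    """Toy page-table bookkeeping; real KV tensors are omitted."""

    def __init__(self, num_pages):
        self.free = list(range(num_pages))
        self.tables = {}    # request id -> list of physical page ids
        self.lengths = {}   # request id -> number of tokens stored

    def new_request(self, rid, shared_prefix_pages=()):
        # Prefix sharing: point at existing physical pages instead of copying.
        self.tables[rid] = list(shared_prefix_pages)
        self.lengths[rid] = len(shared_prefix_pages) * PAGE_SIZE

    def append_token(self, rid):
        if self.lengths[rid] % PAGE_SIZE == 0:      # current page is full
            self.tables[rid].append(self.free.pop())
        self.lengths[rid] += 1

    def locate(self, rid, pos):
        """Map a logical token position to (physical page, offset)."""
        return self.tables[rid][pos // PAGE_SIZE], pos % PAGE_SIZE

alloc = PagedKVAllocator(num_pages=8)
alloc.new_request("a")
for _ in range(20):                                  # 20 tokens -> 2 pages
    alloc.append_token("a")
alloc.new_request("b", shared_prefix_pages=alloc.tables["a"][:1])
alloc.append_token("b")                              # first private page for "b"
```

Request "a" holds two pages for its 20 tokens instead of a maximum-length buffer. Request "b" reuses the first page of "a" for the shared prefix and only allocates a fresh page for its own continuation.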
Speculative decoding and the cache
Speculative decoding uses a small fast “draft” model to propose $k$ tokens at once, then runs the big “target” model in a single forward pass to verify all $k$ in parallel. The win is that the target model can score $k$ candidate next-tokens for the cost of one forward, rather than $k$ sequential forwards.
The reason this works on the KV-cache side: during verification the target model attends over its existing cache plus the $k$ proposed tokens. If positions $1..t$ are already cached, the verification forward computes $K, V$ for the $k$ new tokens, appends them to the cache, and computes attention with $k$ queries against $t + k$ keys. After verification, the longest accepted prefix (say $j \leq k$ tokens) is committed; the rest of the cache (positions $t+j+1..t+k$) is rolled back. Roll-back is cheap: just decrement the page-table length, no actual writes.
The cache is what makes speculative decoding pay off: a rejection costs only the discarded draft positions, because the $K, V$ entries for the accepted prefix stay cached and are reused on the next round. Without a cache the math doesn’t work, since every rejection would mean recomputing prefix attention.
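The commit-and-roll-back step can be sketched with plain integers standing in for tokens and a counter standing in for the cache. This uses greedy verification (accept until the first mismatch) as a simplification; production systems typically use a probabilistic acceptance rule, and `verify_draft` is a made-up name.

```python
def verify_draft(draft_tokens, target_choices):
    """Greedy verification: accept draft tokens until the first mismatch.

    target_choices[i] is the target model's own pick at draft position i,
    all obtained from one parallel forward pass.
    """
    accepted = 0
    for drafted, wanted in zip(draft_tokens, target_choices):
        if drafted != wanted:
            break
        accepted += 1
    return accepted

cache_len = 100                      # positions already in the KV cache
draft = [11, 42, 7, 99]              # k = 4 proposed tokens
target = [11, 42, 8, 99]             # target disagrees at the third position
cache_len_after_verify = cache_len + len(draft)     # K, V appended for all k
j = verify_draft(draft, target)
cache_len_committed = cache_len + j                 # roll back by truncating
```

Two of the four drafts are accepted, so the cache is truncated from 104 entries back to 102. Nothing is rewritten; the length counter simply moves.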
8. GQA, MQA, and MLA
KV cache scales with $\text{num\_kv\_heads}$, not $\text{num\_heads}$. There’s no architectural reason these have to be equal. MQA (Multi-Query Attention, Shazeer 2019, later used in PaLM) takes the extreme: $\text{num\_kv\_heads} = 1$. All query heads share a single $(K, V)$ pair. Cache shrinks by $\text{num\_heads}\times$. The problem is quality: heads can no longer specialize in different content subspaces, only in different query projections. MQA-from-scratch loses ~0.5 PPL on language modeling and more on harder tasks.
GQA (Grouped-Query Attention, 2023, used in LLaMA-2 70B onward) is the compromise. Pick a middle group size: LLaMA-2 70B uses 64 query heads and 8 KV heads, each KV head shared by 8 queries. Cache shrinks by $8\times$, quality loss under 0.1 PPL. GQA is the default for every modern frontier model. Group size 8 is the de facto sweet spot.
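In code, grouped-query attention is ordinary attention with the key heads repeated across each group. The sketch covers only the score computation, with 64 query heads and 8 KV heads as in the example above; `gqa_scores` is an illustrative name, and values would be shared the same way.

```python
import numpy as np

def gqa_scores(Q, K):
    """Q: (num_heads, N, d_k); K: (num_kv_heads, N, d_k) -> (num_heads, N, N)."""
    group = Q.shape[0] // K.shape[0]
    K_shared = np.repeat(K, group, axis=0)   # each KV head serves `group` query heads
    return Q @ K_shared.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])

rng = np.random.default_rng(0)
Q = rng.standard_normal((64, 5, 128))
K = rng.standard_normal((8, 5, 128))     # only these 8 heads need caching
scores = gqa_scores(Q, K)
cache_shrink = Q.shape[0] // K.shape[0]
```

Query heads 0 to 7 all read from KV head 0, heads 8 to 15 from KV head 1, and so on. The repeat happens at compute time, so only the 8 KV heads ever live in the cache.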
Three reasons MQA is rarely the right call, even when you don’t care about pure quality.
First, the perplexity hit is not uniform. MQA loses more on tasks that benefit from rich attention patterns — multi-hop reasoning, code, long-context retrieval. Old MQA papers compared on language modeling and the gap looked small. On modern benchmarks (MMLU, HumanEval, RAG), the gap is closer to 2–3% absolute, the difference between competitive and not. GQA-8 closes almost all of that for almost no extra cache.
Second, the cache savings flatten out. MHA → GQA-8 is $8\times$. GQA-8 → MQA-1 is another $8\times$. But the absolute second-step savings are tiny because the cache is already small. You also lose the parallelism benefit of group sharding across 8 GPUs in distributed serving.
Third, MLA (Multi-head Latent Attention, DeepSeek-V2/V3) offers a better trade: it compresses the cache well below GQA-8 by projecting K, V into a low-rank latent space and reconstructing on demand, while DeepSeek reports quality on par with MHA. If you’re cache-constrained today, MLA is the move.
💡Why not always go to MQA? In a serving environment where throughput is everything and 0.5 PPL is "fine," wouldn't $\text{num\_kv\_heads} = 1$ pay off?
MLA: factorising rather than reducing
GQA and MQA reduce the number of KV heads — fewer copies, less cache. Multi-head Latent Attention (DeepSeek-V2, refined in V3) takes a structurally different route: keep all KV heads conceptually, but compress what you store.
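Standard MHA caches $(K, V) \in \mathbb{R}^{2 \cdot \text{num\_heads} \cdot d_k}$ per token. MLA introduces a learned low-rank projection $c^{KV} = W^{KV}_\text{down} h \in \mathbb{R}^{d_c}$ where $d_c$ is small: DeepSeek-V3 uses $d_c = 512$ across all heads. At cache time, store only the latent $c^{KV}$ per token. At attention time, project back up to per-head $K, V$ via $W^K_\text{up}$ and $W^V_\text{up}$.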
The cache savings are significant. Compare DeepSeek-V3 (latent dim 512 per token per layer, FP16, 61 layers) versus a 70B-class GQA-8 ($2 \cdot 8 \cdot 128 = 2048$ values, i.e. 4 KB in FP16, per layer per token, 80 layers). Counting only the latent, that is about 61 KB per token for MLA against 320 KB for GQA-8, so MLA stores roughly 20% of GQA’s cache, and dramatically less than MHA.
The subtle bit: at inference, $W^K_\text{up}$ can be merged into $W^Q$ (compute $Q' = W^Q \cdot W^K_\text{up}$ once at load time), and $W^V_\text{up}$ folds into $W^O$. So the back-projection isn’t a runtime cost; it’s absorbed into the existing query and output matrices. You run MLA as if it were vanilla attention against a low-rank cache.
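The absorption trick is just associativity of matrix multiplication, which a toy example confirms. This sketch uses a single head, tiny made-up dimensions, a row-vector convention (so the folded matrix is written `W_q @ W_k_up.T`), and ignores RoPE entirely; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d_c, d_k = 32, 8, 16                    # toy sizes: model dim, latent dim, head dim
W_q = rng.standard_normal((D, d_k))        # query projection for one head
W_down = rng.standard_normal((D, d_c))     # shared KV down-projection
W_k_up = rng.standard_normal((d_c, d_k))   # key up-projection for the same head

h_query = rng.standard_normal(D)           # hidden state of the querying token
h_past = rng.standard_normal((10, D))      # hidden states of 10 earlier tokens

latent_cache = h_past @ W_down             # (10, d_c): all that is stored

# Explicit route: reconstruct keys from the latent, then score.
explicit = (latent_cache @ W_k_up) @ (h_query @ W_q)

# Absorbed route: fold W_k_up into the query projection once.
W_q_absorbed = W_q @ W_k_up.T              # (D, d_c), computed at load time
absorbed = latent_cache @ (h_query @ W_q_absorbed)
```

Both routes give the same ten scores, but the absorbed route never materialises per-head keys: the query is projected straight into the latent space and dotted against the cache.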
Quality-wise: DeepSeek’s ablations show MLA matching or slightly beating MHA at equal parameter count, while paying a fraction of GQA’s cache. The catch is non-trivial implementation — RoPE has to be split between a “rotated” part of $K$ that bypasses the latent compression and an “un-rotated” part that goes through it, because rotating before projecting doesn’t commute with rotating after. Engineering effort is real.
Direct comparison: GQA reduces KV head count, MQA reduces it to one, MLA factorises the KV computation. MLA reports MHA-level quality at a cache smaller than GQA-8’s, though not as small as MQA’s, and with better quality than MQA. If you’re designing a new model in 2025 and have the engineering budget, MLA is the front of the pack.
The back-projection is real but it’s not paid at attention time. $W^K_\text{up}$ folds into $W^Q$ (you pre-multiply once at load time) and $W^V_\text{up}$ folds into $W^O$. Per-token attention compute is identical to a small-cache GQA, no extra matmul.
Where the cost shows up is storage: $W^{KV}_\text{down}, W^K_\text{up}, W^V_\text{up}$ have to live in HBM. In DeepSeek-V3 these add roughly 1–2% to total parameter count. You spend a tiny bit more weight memory and save a lot of cache memory.
The cache win materialises whenever (cache memory saved) × (decoding tokens-per-second uplift from less cache pressure) exceeds (small static weight overhead) × (any additional compute for fold-in). For decoder-only LLM inference at production batch sizes, this is essentially always favourable past a few thousand context tokens.
Where MLA doesn’t help: very short contexts where weights dominate cache anyway, or training-time compute, where the fold-in trick doesn’t apply because the up-projections need their own gradients during the backward pass. In practice, you see MLA in inference-optimised models, not pure-research baselines.
💡MLA stores 512 dims of latent per token per layer. GQA-8 stores $K$ and $V$ for 8 KV heads: 2 × 8 × $d_k$ = 2 × 8 × 128 = 2048 dims per token per layer. MLA looks like a 4× cache win, but the back-projection is a matmul. When does MLA actually save total inference cost vs GQA?
9. Sliding window and local attention
Full attention is $O(N^2)$ in compute. Sliding window (Longformer 2020, Mistral 7B 2023) drops this to $O(N \cdot W)$ by restricting each token to a window of size $W$ to its left. At first glance catastrophic (token 10000 can’t attend to token 0!) but layer compounding saves it. With $W = 4096$ over $L = 32$ layers, the effective receptive field is $W \cdot L = 128K$, analogous to a CNN with kernel $W$ and depth $L$.
Whether the network can use that propagation is the empirical question. Mostly yes for local-pattern tasks (next-token prediction, code, math); mostly no for sharp long-range retrieval (needle-in-a-haystack, multi-document QA).
Click between the four masks below to see what each one actually computes — same grid, four different rules for which (query, key) cells are visible. The “block-sparse” pattern shown here is local + always-attend on the first two tokens, which is roughly what StreamingLLM’s attention-sink fix produces:
[Interactive visualization: four attention mask patterns on the same grid]
Attention sinks
Xiao et al. (StreamingLLM, 2023) noticed something unsettling. Take any pretrained LLM, plot per-head attention weights across the full context, and you find that the first 1–4 tokens absorb a wildly disproportionate share of mass — sometimes 30–50% of total softmax weight on a head — regardless of what those tokens actually are. It’s not that BOS or “the” is uniquely informative. It’s that the position attracts mass.
The mechanistic explanation is structural. Softmax forces $\sum_i \alpha_i = 1$ on every row, even when the head genuinely has nothing useful to attend to. The model has to put that mass somewhere. During training, the early tokens are visible to every later position — they’re always inside any reasonable receptive field — so the gradient pressure to use them as a “garbage dump” for unused softmax mass is consistent across the whole sequence. The model learns to route excess mass there because it’s the one stable target.
Practical consequence: in a sliding-window streaming setting, the moment your window slides past the original first tokens, every head loses its dump. Attention mass redistributes onto whichever tokens are visible, including content tokens that don’t deserve it, and generation quality collapses within a few tokens. The fix from StreamingLLM is to keep the first $k \approx 4$ tokens permanently in the window even as the rest slides. Cheap, reliable, ships in essentially every long-context serving system now.
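A sketch of that retention rule, assuming a cache indexed by absolute position. `visible_positions`, `window`, and `num_sinks` are hypothetical names for illustration; real implementations also have to decide how positions are encoded inside the cache, which this ignores.

```python
def visible_positions(t, window, num_sinks):
    """Positions kept in the KV cache when decoding token t (0-indexed).

    Keeps the first `num_sinks` positions forever, plus the most recent
    `window` positions ending at t. Overlap is not double counted.
    """
    sinks = list(range(min(num_sinks, t + 1)))
    recent_start = max(t - window + 1, 0)
    recent = [p for p in range(recent_start, t + 1) if p >= num_sinks]
    return sinks + recent


# Early on, the window still covers everything.
print(visible_positions(t=3, window=4, num_sinks=2))
# Later, the middle is evicted but the sinks stay.
print(visible_positions(t=10, window=4, num_sinks=2))
```

The cache size is bounded by `num_sinks + window` no matter how long generation runs, which is the whole point.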
More recent work has tried to address sinks structurally. Adding a learned “register” key/value pair per head that doesn’t correspond to any input token gives the head an explicit place to dump unused mass, and reduces sink concentration on real tokens. Variants like “softmax with a leaky residual” allow weights to sum to less than 1, in principle eliminating the dump-pressure entirely — but they don’t quite work in practice (see the question below). Sinks remain a structural property of the standard softmax that production systems work around rather than fix.
💡 Attention sinks are an emergent property of softmax requiring weights to sum to 1. Why couldn’t a softmax variant that allows weights to sum to less than 1 plausibly fix this?
You can write one easily: append an extra “null logit” $z_0 = 0$ to each row before softmax, then drop the resulting first weight. Now the real token weights sum to less than 1 and the model has an explicit “attend to nothing” option. So why isn’t this the default?
Two reasons. First, the variants don’t actually eliminate sink behaviour, they just relocate it. The model still wants somewhere to dump unused mass, but instead of a stable absorbed-by-first-tokens dump, you get a moving dump that shifts across content tokens, which is worse for interpretability and quality. The pressure to concentrate excess weight is structural to attention as a mechanism, not specifically to softmax-summing-to-1.
Second, once you go further and drop the normalisation constraint altogether, you lose the bound that softmax gives you on output magnitude. With non-negative weights that sum to at most 1 and bounded $V$, the output norm is bounded by $\max_i \|v_i\|$. Without that constraint you have to add explicit regularisation to keep activations from drifting unboundedly during training, which historically causes its own instability. The cure is worse than the disease.
The pragmatic answer is: sinks are real but they’re a known, well-located phenomenon. Keeping a few absolute-position tokens permanently visible is cheaper than re-architecting the softmax. Future work will probably eliminate sinks structurally, but it’ll come bundled with other changes (linear attention, register tokens, adaptive temperature) rather than a single drop-in softmax replacement.
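For reference, the null-logit variant really is only a few lines. This is a sketch of the idea as described above, not any particular paper’s implementation; `softmax_with_null` is a made-up name.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_with_null(scores):
    """Softmax over [0, *scores], then drop the weight of the null slot."""
    return softmax([0.0] + list(scores))[1:]


# A head with nothing useful to attend to: all logits well below zero.
bored = softmax_with_null([-5.0, -6.0, -4.0])
# A head with one confident match.
confident = softmax_with_null([10.0, 0.0, 0.0])

print(sum(bored), sum(confident))
```

When every real logit is far below the null logit, almost all the mass goes to the null slot and the real weights nearly vanish. With a confident match the null slot barely matters and the result is close to an ordinary softmax.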
10. Sparse attention
Sliding window is a fixed pattern. The natural next step: let the model learn the pattern. The lineage:
BigBird (2020): mix sliding window (local), random attention (a few random tokens per query), and global tokens (special tokens everyone attends to and that attend to everyone). Provably as expressive as full attention given enough random tokens.
Reformer (2020): locality-sensitive hashing buckets queries with similar keys; only attend within bucket. Content-based sparsity, theoretically beautiful, in practice training-unstable and outperformed by simpler patterns.
Native Sparse Attention (NSA, DeepSeek 2025) is the current player. Three parallel branches per query: a local window, a coarse-grained compressed summary of the full context, and the top-$k$ tokens selected by a learned scorer. A learned gate mixes the three. NSA is trainable end-to-end (unlike LSH), GPU-friendly because top-$k$ selection happens block-wise, and shows competitive-or-better quality at long context with substantially less compute. The novelty over older sparse work is that it is natively trainable: sparsity emerges from training rather than being a fixed hyperparameter.
Common to every sparse scheme: keep some “always-attend” tokens. CLS in BERT, sinks in StreamingLLM, global tokens in BigBird, anchors in NSA. Without a designated dump, the attention distribution misbehaves on edge cases.
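A sketch of a BigBird-style mask, assuming a bidirectional (encoder) setting and treating the first `num_global` positions as the global tokens. The function name and parameters are illustrative, and the real BigBird works on blocks of tokens rather than individual ones.

```python
import random


def bigbird_style_mask(n, window, num_global, num_random, seed=0):
    """Boolean n x n mask; mask[i][j] is True when query i may attend to key j."""
    rng = random.Random(seed)
    half = window // 2
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            local = abs(i - j) <= half                   # sliding window
            is_global = i < num_global or j < num_global  # always-attend tokens
            mask[i][j] = local or is_global
        # A few extra random keys per query.
        for j in rng.sample(range(n), num_random):
            mask[i][j] = True
    return mask


mask = bigbird_style_mask(n=16, window=3, num_global=2, num_random=2)
visible = sum(sum(row) for row in mask)
print(f"{visible} of {16 * 16} cells visible")
```

The global rows and columns are the “always-attend” tokens from the paragraph above: every query can reach them, and they can reach every key.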
11. Linear attention
Replace softmax with a kernel that factorizes:
$$\text{softmax}(QK^\top)\,V \;\approx\; \big(\phi(Q)\,\phi(K)^\top\big)\,V$$
By associativity, compute $\phi(K)^\top V$ first: a $d \times d$ matrix independent of the query. Then $\phi(Q) \cdot (\phi(K)^\top V)$ is $O(N \cdot d^2)$ instead of $O(N^2 \cdot d)$. Linear in sequence length.
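The associativity argument is easy to check numerically. A minimal NumPy sketch, assuming the non-causal case and using `elu(x) + 1` as the feature map $\phi$ (my choice for illustration; the text does not fix a particular $\phi$):

```python
import numpy as np


def feature_map(x):
    """elu(x) + 1, a simple positive feature map (an assumption here)."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def attention_quadratic(Q, K, V):
    """(phi(Q) phi(K)^T) V: materialises an N x N matrix."""
    return (feature_map(Q) @ feature_map(K).T) @ V


def attention_linear(Q, K, V):
    """phi(Q) (phi(K)^T V): only ever holds a d x d_v summary."""
    kv_summary = feature_map(K).T @ V
    return feature_map(Q) @ kv_summary


rng = np.random.default_rng(0)
N, d = 256, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

same = np.allclose(attention_quadratic(Q, K, V), attention_linear(Q, K, V))
print(same)
```

Both functions compute the same product; only the parenthesisation differs, and that is what moves the cost from quadratic to linear in $N$. Row normalisation is left out to match the formula in the text.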
- Performer (2020): random feature map that gives an unbiased approximation of the softmax kernel. Mathematically clean; in practice the high variance costs noticeable quality.
- Linformer (2020): project keys and values from length $N$ to $k$ before attention. Loses autoregressive structure (the projection mixes future and past) — encoders only.
- Mamba / S4 / RetNet: state-space models. Linear recurrences with structured state matrices; squint and they’re linear attention with a particular causal kernel. Inference is $O(1)$ per token: no KV cache, just a state vector.
- Kimi Linear (Moonshot AI, 2025): the recent serious challenger. Its core layer, Kimi Delta Attention, is a Gated DeltaNet variant: an associative-memory recurrence with a delta-rule write and a fine-grained (per-channel) gate that controls how quickly each slice of the linear state decays. The finer gating is the trick that gives competitive quality; older linear attempts had gates either too soft (bleed) or too hard (forget). Kimi pairs the linear layers with a few full-attention layers as anchors, keeping long-range exact recall when needed.
Why has linear been hard? Softmax does two things: similarity weighting (kernels approximate this) and normalization that bounds output magnitude (kernel methods break this). Without normalization, training is unstable. Modern variants (Mamba’s selective SSM, Kimi’s gated DeltaNet) add explicit normalization that recovers the stability softmax provided for free.
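The “no KV cache, just a state” claim is easiest to see in the causal case. A minimal sketch, again assuming `elu(x) + 1` as the feature map and no gating or normalisation (so this is plain linear attention, not Mamba or DeltaNet):

```python
import numpy as np


def feature_map(x):
    """elu(x) + 1, a simple positive feature map (an assumption here)."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def causal_linear_attention_recurrent(Q, K, V):
    """One fixed-size state update per token; memory does not grow with N."""
    n, d = Q.shape
    state = np.zeros((d, V.shape[1]))
    outputs = np.empty((n, V.shape[1]))
    for t in range(n):
        state += np.outer(feature_map(K[t]), V[t])   # write token t
        outputs[t] = feature_map(Q[t]) @ state       # read with query t
    return outputs


def causal_linear_attention_masked(Q, K, V):
    """Reference: the same thing as a masked N x N product."""
    weights = np.tril(feature_map(Q) @ feature_map(K).T)
    return weights @ V


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((64, 8)) for _ in range(3))
match = np.allclose(
    causal_linear_attention_recurrent(Q, K, V),
    causal_linear_attention_masked(Q, K, V),
)
print(match)
```

Note that `state` is a plain running sum. Nothing bounds its magnitude as the sequence grows, which is the stability problem the gated variants address.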
💡 What does linear attention give up versus softmax attention, and where does the quality gap come from?
Three things, in increasing severity.
First, exact recall. Softmax can put nearly all its mass on one specific past token, for copying, induction, retrieval. Linear attention approximates this only up to the rank of its feature map. With $d_k = 128$, you have a 128-dim memory per layer per head; a context with 10000 distinct tokens cannot all be perfectly remembered. This shows up as bad needle-in-a-haystack results.
Second, kernel approximation variance (Performer-style). Random features are unbiased but variance grows with sequence length and softmax sharpness. Sharp patterns are exactly the ones the model wants for precise operations, and exactly where the kernel is worst.
Third, gradient pathology. Linear-attention recurrences sum past contributions; without softmax-normalized weighting, magnitudes drift. Modern gated variants fix this, but the fix is architecture-specific.
In exchange you get $O(1)$ generation and $O(N)$ training, so you can train and serve at lengths softmax can’t touch. Great for fluency-over-precision tasks (long-form generation, summarization). Bad for sharp recall (RAG, code retrieval, multi-hop QA), which is why production systems use softmax or hybrids.
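The exact-recall limit can be seen in a toy associative memory. This sketch assumes an identity feature map and unit-norm keys, which is a simplification of any real linear-attention layer; `store` and `recall_error` are illustrative names.

```python
import numpy as np


def store(keys, values):
    """Write all (key, value) pairs into one d x d_v state matrix."""
    return keys.T @ values


def recall_error(keys, values):
    """Relative error when reading every stored value back by its key."""
    state = store(keys, values)
    retrieved = keys @ state
    return np.linalg.norm(retrieved - values) / np.linalg.norm(values)


rng = np.random.default_rng(0)
d = 16

# Case 1: d orthonormal keys fit the memory exactly.
ortho_keys, _ = np.linalg.qr(rng.standard_normal((d, d)))
err_fits = recall_error(ortho_keys, rng.standard_normal((d, d)))

# Case 2: 4 * d random unit keys overflow it, so reads interfere.
many_keys = rng.standard_normal((4 * d, d))
many_keys /= np.linalg.norm(many_keys, axis=1, keepdims=True)
err_overflow = recall_error(many_keys, rng.standard_normal((4 * d, d)))

print(f"{err_fits:.2e} vs {err_overflow:.2f}")
```

Up to $d$ mutually orthogonal keys can be read back exactly. Past that, every read picks up cross-talk from the other stored pairs, and no choice of keys can remove it because at most $d$ directions are orthogonal in a $d$-dim space.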
12. FlashAttention
Up to here we’ve been changing algorithms. FlashAttention is none of those — it computes the exact same softmax attention, but changes the memory access pattern.
Core observation (Dao et al. 2022): on modern GPUs, attention’s bottleneck is HBM bandwidth, not flops. Standard attention materialises the full $N \times N$ attention matrix in HBM, writing $N^2$ floats out and reading them back to compute softmax. For $N = 4096$ in FP16, that’s 32 MB written and 32 MB read per layer per head, and HBM bandwidth is the binding constraint.
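The 32 MB figure is just $N^2$ elements at two bytes each. A quick check, where `score_matrix_bytes` is a throwaway helper and “MB” is read as $2^{20}$ bytes:

```python
def score_matrix_bytes(n, bytes_per_element=2):
    """Size of one N x N score matrix; FP16 is 2 bytes per element."""
    return n * n * bytes_per_element


mib = score_matrix_bytes(4096) / 2**20
print(mib)  # 32.0
```

That is per head, per layer, and it is paid twice (one write, one read) before any useful output is produced.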
The online softmax recurrence
The whole algorithmic insight is one trick: you can compute softmax incrementally while keeping the result numerically exact. Process keys in blocks $j = 1, 2, \ldots$. For each block, compute scores $s^{(j)} \in \mathbb{R}^{B_q \times B_k}$ against the current $Q$ tile. Maintain two row-wise running statistics, the current max $m^{(j)}$ and the current denominator $\ell^{(j)}$, and a running output $O^{(j)}$.
The recurrence is:
$$m^{(j)} = \max\!\big(m^{(j-1)},\ \operatorname{rowmax}(s^{(j)})\big), \qquad \ell^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \operatorname{rowsum}\!\big(e^{s^{(j)} - m^{(j)}}\big)$$
And the output rescale:
$$O^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + e^{s^{(j)} - m^{(j)}}\,V^{(j)}$$
At the end of the inner loop, divide once: $O = O^{(j)} / \ell^{(j)}$ for the final block $j$. The factor $e^{m^{(j-1)} - m^{(j)}}$ rescales the previously-accumulated $\ell^{(j-1)}$ and $O^{(j-1)}$ when a new block raises the running max. That’s the trick that keeps everything numerically stable as the max grows.
Because all per-block work fits in SRAM and the only HBM IO is one read of $Q, K, V$ and one write of the final $O$, you never materialise the $N \times N$ score matrix anywhere. HBM IO drops from $O(N^2)$ to $O(N \cdot d)$. At $N = 16384$ and $d = 128$, that’s a $128\times$ reduction in attention bandwidth. Exact output: same numbers as full materialisation, modulo floating-point reordering.
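Here is the online softmax recurrence as a NumPy sketch, checked against the naive version. It only tiles over keys and runs on the CPU, so it shows the arithmetic and nothing about SRAM, warps, or real kernel structure; `blockwise_attention` and `block_size` are illustrative names.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference: materialise the full score matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V


def blockwise_attention(Q, K, V, block_size=16):
    """Same output, visiting keys one block at a time."""
    n_q, d = Q.shape
    m = np.full((n_q, 1), -np.inf)         # running row max
    ell = np.zeros((n_q, 1))               # running softmax denominator
    out = np.zeros((n_q, V.shape[1]))      # running unnormalised output
    for start in range(0, K.shape[0], block_size):
        k_blk = K[start:start + block_size]
        v_blk = V[start:start + block_size]
        s = Q @ k_blk.T / np.sqrt(d)       # scores for this block only
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)          # exp(-inf) = 0 on the first block
        p = np.exp(s - m_new)
        ell = scale * ell + p.sum(axis=1, keepdims=True)
        out = scale * out + p @ v_blk
        m = m_new
    return out / ell                       # divide once at the end


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 8)) for _ in range(3))
exact = np.allclose(naive_attention(Q, K, V), blockwise_attention(Q, K, V))
print(exact)
```

The largest intermediate in the loop is an `n_q x block_size` tile, never the full score matrix. The sequence length of 100 is deliberately not a multiple of the block size, to exercise the ragged last block.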
Implementation evolution
FlashAttention-2 (2023) reorganizes work partition: outer loop over $Q$ blocks (one per warp group), inner over $K, V$ blocks, reducing inter-warp synchronization. About 2× faster than V1.
FlashAttention-3 (2024) is Hopper-specific: warp specialization (some warps load via TMA while others compute), async pipelined matmul, FP8 for $QK^\top$. On H100, FA-3 hits ~75% of theoretical TFLOPS versus ~35% for FA-2. Cost: Hopper-only.
Two practical caveats. First, very short sequences: the $N \times N$ matrix fits in cache anyway and the tile-loop overhead dominates, so standard attention is competitive (this rarely matters because short attention is a negligible cost overall). Second, gradient checkpointing combined with FlashAttention: FA already recomputes attention on the backward pass to avoid storing it, so layer-level checkpointing on top means you recompute the forward and then recompute attention within it. With the recomputation nested like that, backward can become as much as 4× slower than necessary. Disable layer-level checkpointing on FA layers and rely on FA’s internal recomputation.
💡 FlashAttention’s online softmax is exact: same numerical answer as full materialisation. Why isn’t every attention implementation just FlashAttention?
Because “exact” hides which attention pattern you’re computing. The online softmax recurrence assumes a specific score function (scaled dot-product), a specific masking pattern (full or causal), and a specific reduction (softmax over keys). Anything that breaks those assumptions needs a new kernel.
Sliding window with sinks needs a different mask layout (supported in FA-2.5), but each variant is a Triton/CUDA kernel someone has to write and test. ALiBi adds a bias per (i, j) pair; that is supported. Differential attention (section 13) needs two attention computations subtracted: you can run two FA passes, but the optimal kernel fuses them. NSA’s top-k selection is structurally incompatible: top-k is a sort, FA is a streaming reduction, and you can’t fuse them into one tile loop without giving up the IO savings.
Production frameworks like xFormers and FlashInfer ship dozens of kernel variants for exactly this reason: researchers invent new attention patterns faster than kernels can be written. If you’re using a vanilla pattern, FA is the move. If you’re inventing a new pattern, you may end up materialising attention in HBM until someone writes the kernel, which can take six months from paper to merged PR.
There’s also the precision question. FA-3’s FP8 path is fast on Hopper but FP8 attention has its own quirks: without QK-norm, the scaled scores can overflow FP8’s narrow range. Some teams stick with FA-2 BF16 because it’s predictable, even though FA-3 is faster on paper.
13. Differential and exclusive: small tweaks, big effects
A handful of recent results show you don’t always need a fundamentally new mechanism — sometimes a small architectural tweak has an outsized effect on a specific capability.
Differential Transformer (Microsoft, 2024). Standard softmax puts non-trivial mass on irrelevant tokens, especially in long context. This “attention noise” hurts retrieval: even a small uniform leak across 32767 tokens dilutes the answer at token 47. The fix: compute two softmax attention maps $A_1$ and $A_2$ and return $(A_1 - \lambda A_2)\,V$, where $\lambda$ is learned (per-layer or per-head). The intuition is differential amplifiers: two signals with correlated noise, subtracted, cancel the common-mode noise. In practice, the model learns $A_2$ as a “noise floor” estimate. Substantially better long-context retrieval at the same parameter count, slightly worse short-text PPL, ~2× attention compute.
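A single-head sketch of the subtraction, with two softmax maps $A_1$ and $A_2$ built from separate query/key projections and combined as $(A_1 - \lambda A_2)V$. The per-head normalisation and the $\lambda$ parameterisation from the paper are left out, `lam=0.5` is an arbitrary value, and the function names are mine.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise numerically stable softmax."""
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


def differential_attention(Q1, K1, Q2, K2, V, lam):
    """(A1 - lam * A2) @ V, with A1 and A2 two independent softmax maps."""
    scale = 1.0 / np.sqrt(Q1.shape[1])
    A1 = softmax_rows(Q1 @ K1.T * scale)
    A2 = softmax_rows(Q2 @ K2.T * scale)
    return (A1 - lam * A2) @ V


rng = np.random.default_rng(0)
n, d = 12, 8
Q1, K1, Q2, K2, V = (rng.standard_normal((n, d)) for _ in range(5))
out = differential_attention(Q1, K1, Q2, K2, V, lam=0.5)
print(out.shape)
```

The combined weights can be negative and each row sums to $1 - \lambda$ rather than 1, so this is no longer a convex combination of values. A mass that is uniform in both maps cancels in proportion to $\lambda$.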
Exclusive self-attention. In standard self-attention, $\alpha_{ii}$ is typically the largest weight in row $i$ (because $q_i \cdot k_i \approx \|q_i\|^2$). This self-loop carries no information, since the residual already provides “self.” When trying to retrieve a far-away answer, diagonal mass competes with off-diagonal mass that actually carries information. Fix: mask the diagonal (set its logit to $-\infty$) before softmax, so the self-weight is exactly zero. Modest but consistent needle-in-a-haystack gains, no cost on standard short-context benchmarks.
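A sketch of the diagonal mask, assuming a bidirectional setting with at least two tokens. With a causal mask the first token has nothing left to attend to once its own slot is removed, so a real implementation needs a special case there; this sketch does not handle it.

```python
import numpy as np


def exclusive_self_attention(Q, K, V):
    """Self-attention where token i is never allowed to attend to itself."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    np.fill_diagonal(scores, -np.inf)               # remove the self-loop
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)                        # exp(-inf) = 0 on the diagonal
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V, weights


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
out, weights = exclusive_self_attention(Q, K, V)
print(np.diag(weights).max(), weights.sum(axis=1).min())
```

The rows still sum to 1; the mass that would have sat on the diagonal is redistributed over the other tokens.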
The general lesson: attention’s behavior in the long-tail of the distribution — small masses, edge cases, sinks — turns out to matter more than the mainline arithmetic for tasks that depend on those edges. The next few years of attention research are probably a series of small tweaks like these.
14. What I’d use today
If you’re building a decoder-only LLM in 2025, the architecture is essentially solved at the granularity of “what knobs to turn.” Pre-LN with RMSNorm. RoPE for position with NTK or YaRN extension if you want context beyond training length. GQA group-8, or MLA if you’re cache-constrained and willing to do the engineering. FlashAttention-2 as the kernel, FA-3 if you’re on Hopper. Sliding window for cheap local long context, full attention for retrieval, hybrid (Gemma-2 style) for both. NSA or Kimi Linear hybrid past 256K context.
The honest summary: most of the win in modern LLMs has come from training data and scale, not architecture. Switching MHA to GQA, adding RoPE, swapping in FlashAttention — each is a small percentage win. Doubling tokens or going from 7B to 70B parameters is a big percentage win. The best architecture choices are the ones that let you train more efficiently and serve at lower cost. Pick the variants that unblock your scale, then go scale.