Multi-Head Attention with KV Cache
Why this matters
Autoregressive generation is the slowest part of LLM serving. Naive
decoding re-computes the Q/K/V projections of every token seen so far
at every step:
- Step 1: process token 1, produce token 2.
- Step 2: process tokens 1, 2, produce token 3.
- Step 3: process tokens 1, 2, 3, produce token 4.
- …
- Step T: process all T tokens.
Total work: O(T²) projections + O(T³) attention. Catastrophic.
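To make the growth rates concrete, the sketch below simply counts operations for the naive scheme. It is an illustration only: the function name naive_decode_work is hypothetical, and it assumes step t re-projects all t tokens and builds a full t x t score matrix.

```python
def naive_decode_work(T):
    """Count work for naive decoding that reprocesses the whole prefix each step."""
    projections = 0  # per-token Q/K/V projection computations
    scores = 0       # query-key score computations
    for t in range(1, T + 1):
        projections += t   # step t re-projects all t tokens
        scores += t * t    # step t builds a t x t score matrix
    return projections, scores


for T in (10, 100, 1000):
    projections, scores = naive_decode_work(T)
    print(f"T={T}: projections={projections}, scores={scores}")
```

The projection count is the sum 1 + 2 + ... + T, which grows as T², and the score count is the sum of squares, which grows as T³.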
The fix: a KV cache. K and V for past tokens never change once
computed (with causal masking, future tokens never affect them), so
cache them. At step t:
- Compute Q, K, V for ONLY the new token t.
- Append the new K, V to the cache.
- Compute attention: Q (single position) against the cache (all past).
- Output one new vector. Repeat.
Per-step work drops to O(T) attention + O(1) projections. Total
decode is O(T²), no longer O(T³).
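The four steps above can be sketched in plain NumPy for a single head. This is a minimal illustration, not how any particular library implements it: the weight matrices Wq, Wk, Wv and the helper names are made up for the example, and the cache here is a growing Python list rather than a preallocated buffer.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def causal_attention(x, Wq, Wk, Wv):
    """Reference: causal self-attention over all T positions at once."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    mask = np.tril(np.ones_like(scores, dtype=bool))  # position i sees j <= i
    return softmax(np.where(mask, scores, -np.inf)) @ v


def cached_decode(x, Wq, Wk, Wv):
    """Token-by-token decoding that projects only the new token each step."""
    k_cache, v_cache, outputs = [], [], []
    for t in range(x.shape[0]):
        x_t = x[t:t + 1]                     # (1, D): the new token only
        q_t = x_t @ Wq
        k_cache.append(x_t @ Wk)             # append new K to the cache
        v_cache.append(x_t @ Wv)             # append new V to the cache
        K = np.concatenate(k_cache, axis=0)  # (t + 1, D): all past keys
        V = np.concatenate(v_cache, axis=0)
        scores = q_t @ K.T / np.sqrt(q_t.shape[-1])  # (1, t + 1)
        outputs.append(softmax(scores) @ V)  # one new output vector
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
T, D = 6, 8
x = rng.normal(size=(T, D))
Wq, Wk, Wv = (rng.normal(size=(D, D)) for _ in range(3))
print(np.allclose(causal_attention(x, Wq, Wk, Wv), cached_decode(x, Wq, Wk, Wv)))
```

The two functions agree because, under causal masking, row t of the full computation only ever reads keys and values at positions up to t, which is exactly what the cache holds at step t.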
How Flax does it
nn.MultiHeadDotProductAttention(decode=True) activates cached-decoding
mode. When you init the module on a “full-length” tensor, Flax allocates:
- cached_key: shape (T, H, D/H), all zeros initially.
- cached_value: shape (T, H, D/H), all zeros.
- cache_index: scalar int, starts at 0.
These live under the cache collection in the variables dict, separate
from params. Each call to apply(..., mutable=['cache']) writes the new
K/V at slot cache_index, increments cache_index, and returns the
updated cache as the second element of its result.
Init pattern
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D, decode=True)
# Run init on a full-length context to allocate cache slots.
variables = attn.init(rng, x) # x has shape (T_max, D)
# variables == {'params': ..., 'cache': {'cached_key': ..., ...}}
One decode step
new_token = x[0:1] # shape (1, D)
out, _ = attn.apply(variables, new_token, mutable=['cache'])
# out has shape (1, D) — one new output token's representation.
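mutable=['cache'] is REQUIRED: Flax collections are immutable during
apply by default, and listing 'cache' marks that collection as writable.
apply then returns (output, updated_variables). The variables dict you
passed in is not modified, so when decoding more than one token, feed
the returned cache into the next call.
Internally Flax also applies a causal mask automatically when
decode=True: the cache index tracks the current position, so K/V
slots beyond cache_index are ignored.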
Why decode-time attention sees the right shape
Q has shape (1, D). K and V come from the cache, which is (T_max, D)
once the H heads of width D/H are flattened, so the score matrix is
(1, T_max) per head. With the cache-index gate, only valid
cached positions contribute. Output is (1, D_out).
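The shape bookkeeping and the index gate can be traced with a small NumPy model of a preallocated cache. This is a sketch of the idea only, not Flax's actual code: decode_step and the tuple-based cache are made up for the example, and d stands for the per-head width D/H.

```python
import numpy as np

T_max, H, d = 6, 2, 4  # d stands for the per-head width D/H (illustrative sizes)


def decode_step(cache, q, k, v):
    """One step against a preallocated cache; q, k, v have shape (1, H, d)."""
    keys, values, index = cache
    keys, values = keys.copy(), values.copy()
    keys[index], values[index] = k[0], v[0]                    # fill the next slot
    scores = np.einsum("qhd,khd->hqk", q, keys) / np.sqrt(d)   # (H, 1, T_max)
    valid = np.arange(T_max) <= index                          # slots filled so far
    scores = np.where(valid, scores, -np.inf)                  # gate out empty slots
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.einsum("hqk,khd->qhd", weights, values)           # (1, H, d)
    return out, weights, (keys, values, index + 1)


rng = np.random.default_rng(0)
cache = (np.zeros((T_max, H, d)), np.zeros((T_max, H, d)), 0)
for _ in range(3):
    q, k, v = rng.normal(size=(3, 1, H, d))
    out, weights, cache = decode_step(cache, q, k, v)

print(out.shape, weights.shape)
print(weights[0, 0].round(3))  # head 0: weights over all T_max slots
```

After three steps only the first three slots carry non-zero attention weight; the zero-initialised slots past the index are masked out rather than averaged in.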
Common pitfalls
- Forgetting mutable=['cache']: Flax raises an error because the cache collection is immutable during apply unless it is listed.
- Using decode=False for init: without decode=True at init, no cache is allocated. The construct-time flag is what triggers it.
- Mismatched cache shape vs new token: init runs on the full-length context (e.g., (T_max, D)); apply runs on a single step (1, D). The cache is sized once at init.
Problem
Implement mha_with_cache(seed, x, num_heads, qkv_features):
- Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D, decode=True).
- variables = attn.init(rng, x), where x is the full-length context that sizes the cache.
- out, _ = attn.apply(variables, x[0:1], mutable=['cache']) to run ONE step on the first token.
- Return out.reshape(-1).
Since this is the first decode step (cache_index = 0), the output should match what causal MHA produces for position 0 (which is just self-attention on token 0).
Inputs:
- seed: int.
- x: 2-D (T, D_in), the full-length context that sizes the KV cache.
- num_heads, qkv_features: ints.
Output: 1-D, the one-step output flattened. Length = D_in.