
Multi-Head Attention with KV Cache

Why this matters

Autoregressive generation is typically the slowest part of LLM serving, because tokens are produced one at a time. Naive decoding recomputes the Q/K/V projections for every token seen so far at every step:

  • Step 1: process token 1, produce token 2.
  • Step 2: process tokens 1, 2, produce token 3.
  • Step 3: process tokens 1, 2, 3, produce token 4.
  • Step T: process all T tokens.

Total work over T steps: O(T²) projections + O(T³) attention, nearly all of it recomputing values that earlier steps already produced.
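Summing the naive per-step cost gives those totals. Count one projected token or one attention-score entry as one unit and treat D as a constant; step t re-projects t tokens and fills a t × t score matrix:

```latex
\text{projections} = \sum_{t=1}^{T} t = \frac{T(T+1)}{2} = O(T^2),
\qquad
\text{score entries} = \sum_{t=1}^{T} t^2 = \frac{T(T+1)(2T+1)}{6} = O(T^3).
```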

The fix: a KV cache. K and V for past tokens never change once computed (with causal masking, future tokens never affect them), so cache them. At step t:

  1. Compute Q, K, V for ONLY the new token t.
  2. Append the new K, V to the cache.
  3. Compute attention: Q (single position) against the whole cache (all past tokens plus the one just appended).
  4. Output one new vector. Repeat.
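A minimal single-head NumPy sketch of those four steps. It assumes plain weight matrices Wq, Wk, Wv with no biases, no output projection, and 1/sqrt(d) scaling; cached_decode and causal_attention are hypothetical names, not Flax internals. Decoding one token at a time against the cache should reproduce full causal attention row for row.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def causal_attention(x, Wq, Wk, Wv):
    """Reference: full causal self-attention over all T positions at once."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((T, T), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ v


def cached_decode(x, Wq, Wk, Wv):
    """Process one token per step, reusing cached K/V for all earlier tokens."""
    T, d = x.shape[0], Wq.shape[1]
    k_cache = np.zeros((T, d))
    v_cache = np.zeros((T, d))
    outputs = []
    for t in range(T):
        x_t = x[t:t + 1]                                   # (1, D_in): only the new token
        q_t, k_t, v_t = x_t @ Wq, x_t @ Wk, x_t @ Wv       # 1. project token t only
        k_cache[t], v_cache[t] = k_t[0], v_t[0]            # 2. append to the cache
        scores = q_t @ k_cache[:t + 1].T / np.sqrt(d)      # 3. (1, t+1) scores vs the cache
        outputs.append(softmax(scores) @ v_cache[:t + 1])  # 4. one output row
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
Wq, Wk, Wv = (rng.normal(size=(8, 4)) for _ in range(3))
print(np.allclose(cached_decode(x, Wq, Wk, Wv), causal_attention(x, Wq, Wk, Wv)))  # True
```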

Per-step work drops to O(T) attention + O(1) projections. Total decode is O(T²), no longer O(T³).
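A quick plain-Python count makes the gap concrete, using the same unit costs as above (one projected token or one score entry = 1, D treated as constant). naive_cost and cached_cost are illustrative helpers, not part of any library.

```python
def naive_cost(T):
    """Naive loop: step t re-projects t tokens and fills a t x t score matrix."""
    projections = sum(range(1, T + 1))
    score_entries = sum(t * t for t in range(1, T + 1))
    return projections, score_entries


def cached_cost(T):
    """KV cache: step t projects 1 new token and fills a 1 x t score row."""
    projections = T
    score_entries = sum(range(1, T + 1))
    return projections, score_entries


for T in (16, 256, 4096):
    (_, naive_scores), (_, cached_scores) = naive_cost(T), cached_cost(T)
    ratio = naive_scores / cached_scores
    print(f"T={T:5d}  naive={naive_scores:>14,}  cached={cached_scores:>10,}  ratio={ratio:7.1f}")
```

The score-entry ratio works out to exactly (2T+1)/3, so the savings grow linearly with context length.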

How Flax does it

nn.MultiHeadDotProductAttention(decode=True) activates cached-decoding mode. When you init the module on a “full-length” tensor, Flax allocates:

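  • cached_key: shape (T, H, D/H), all zeros initially. T is the length of the init input, H = num_heads, D = qkv_features; batched inputs prepend their batch dims.
  • cached_value: shape (T, H, D/H), all zeros.
  • cache_index: scalar int32, starts at 0.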

These live under the 'cache' collection in the variables dict, separate from 'params'. Each call to apply(..., mutable=['cache']) writes the new K/V at slot cache_index, increments cache_index, and returns the updated cache.
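A NumPy sketch of that layout and update rule, assuming an unbatched input. init_cache and write_step are hypothetical helpers that reuse Flax's variable names for readability; they are not Flax functions. write_step returns a new cache instead of mutating the old one, matching Flax's functional style.

```python
import numpy as np


def init_cache(T_max, num_heads, head_dim):
    """Zero-filled cache, sized once from the full-length context."""
    return {
        "cached_key": np.zeros((T_max, num_heads, head_dim)),
        "cached_value": np.zeros((T_max, num_heads, head_dim)),
        "cache_index": 0,
    }


def write_step(cache, k_t, v_t):
    """Return a NEW cache with k_t, v_t (each (1, H, D/H)) written at cache_index."""
    i = cache["cache_index"]
    key, value = cache["cached_key"].copy(), cache["cached_value"].copy()
    key[i:i + 1] = k_t
    value[i:i + 1] = v_t
    return {"cached_key": key, "cached_value": value, "cache_index": i + 1}


# 'params' is a placeholder here; real params hold the Q/K/V/output projections.
variables = {"params": {}, "cache": init_cache(T_max=8, num_heads=4, head_dim=4)}
k_t = v_t = np.ones((1, 4, 4))
new_cache = write_step(variables["cache"], k_t, v_t)
print(new_cache["cache_index"], new_cache["cached_key"][0].sum())  # 1 16.0
print(variables["cache"]["cache_index"])                           # 0 (original untouched)
```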

Init pattern

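import jax, jax.numpy as jnp
import flax.linen as nn
T_max, D, H = 8, 16, 4                   # context length, features, heads
x = jnp.ones((T_max, D))                 # full-length context, shape (T_max, D)
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D, decode=True)
# Run init on the full-length context to allocate cache slots.
variables = attn.init(jax.random.PRNGKey(0), x)
# variables == {'params': ..., 'cache': {'cached_key': ..., 'cached_value': ..., 'cache_index': ...}}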

One decode step

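new_token = x[0:1]                       # shape (1, D): one decode step
out, updated = attn.apply(variables, new_token, mutable=['cache'])
# out has shape (1, D): the new position's output representation.
# updated == {'cache': ...} holds the advanced cache; variables itself is unchanged.
variables = {**variables, **updated}     # carry the cache into the next step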

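mutable=['cache'] is REQUIRED: Flax collections are immutable inside apply by default. Listing 'cache' makes that collection writable for this call, and apply then returns (output, updated_collections) instead of the output alone. The update is functional, not in place: to decode the next token, pass the returned cache into the next apply call.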
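A plain-Python sketch of why the returned cache must be carried forward. fake_apply is a hypothetical stand-in that follows the same return convention as attn.apply(..., mutable=['cache']) (output plus updated collections) and records which cache slot each step writes; it does no attention math.

```python
def fake_apply(variables, x_t):
    """Hypothetical stand-in for attn.apply(variables, x_t, mutable=['cache']).
    Returns (output, {'cache': updated_cache}) and never mutates its input."""
    cache = variables["cache"]
    i = cache["cache_index"]
    new_cache = {"cache_index": i + 1, "filled": cache["filled"] + (i,)}
    return x_t, {"cache": new_cache}


start = {"params": {}, "cache": {"cache_index": 0, "filled": ()}}

# Wrong: discarding the returned cache means every step writes slot 0.
variables = start
for tok in range(3):
    out, _ = fake_apply(variables, tok)
print(fake_apply(variables, 3)[1]["cache"]["filled"])   # (0,)

# Right: merge the updated cache back in before the next step.
variables = start
for tok in range(3):
    out, updated = fake_apply(variables, tok)
    variables = {**variables, **updated}
print(variables["cache"]["filled"])                     # (0, 1, 2)
```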

With decode=True, Flax also masks attention automatically on every apply step: the single query at position cache_index may attend only to slots 0..cache_index, so the still-zero slots beyond it are ignored.
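A single-head NumPy sketch of that mask, assuming a cache with slots 0..2 filled and a query at position 2 (attend is a hypothetical helper; Flax applies a large negative bias rather than -inf). Masking the zero slots gives the same result as attending over the filled prefix alone.

```python
import numpy as np


def attend(q, k, v, mask=None):
    """Single-head scaled dot-product attention with an optional boolean mask."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)   # masked slots get zero weight
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v


T_max, cur_index, d = 6, 2, 4
rng = np.random.default_rng(1)
k_cache, v_cache = np.zeros((T_max, d)), np.zeros((T_max, d))
k_cache[:cur_index + 1] = rng.normal(size=(cur_index + 1, d))   # slots 0..2 filled
v_cache[:cur_index + 1] = rng.normal(size=(cur_index + 1, d))
q = rng.normal(size=(1, d))                                     # query at position 2

valid = np.arange(T_max) <= cur_index        # True for slots 0..2, False after
masked = attend(q, k_cache, v_cache, mask=valid)
prefix = attend(q, k_cache[:cur_index + 1], v_cache[:cur_index + 1])
print(np.allclose(masked, prefix))           # True
```

Without the mask, each zero-filled slot would score q · 0 = 0 and still receive a positive softmax weight, diluting the output.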

Why decode-time attention sees the right shape

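After projection, Q has shape (1, H, D/H) and K and V come from the cache, shape (T_max, H, D/H), so the score tensor is (H, 1, T_max). The cache-index mask removes unfilled slots, the heads are concatenated back to (1, D), and the output projection maps that to (1, D_out), where D_out is out_features (by default the input's feature size).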
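A NumPy sketch that traces those shapes through one decode step. It assumes hypothetical weights laid out in the style of Flax's DenseGeneral kernels ((D_in, H, D/H) for the query, (H, D/H, D_out) for the output) and random stand-ins for the cached K/V; it is an illustration, not Flax's code.

```python
import numpy as np

T_max, H, D_in = 8, 4, 16
D = 16                      # qkv_features
Dh = D // H                 # per-head depth
D_out = D_in                # out_features defaults to the input feature size

rng = np.random.default_rng(0)
x_t = rng.normal(size=(1, D_in))                  # one new token
Wq = rng.normal(size=(D_in, H, Dh))               # hypothetical query kernel
Wo = rng.normal(size=(H, Dh, D_out))              # hypothetical output kernel
k_cache = rng.normal(size=(T_max, H, Dh))         # stands in for cached_key
v_cache = rng.normal(size=(T_max, H, Dh))         # stands in for cached_value
cur_index = 3

q = np.einsum("td,dhk->thk", x_t, Wq)                               # (1, H, Dh)
scores = np.einsum("qhk,thk->hqt", q, k_cache) / np.sqrt(Dh)        # (H, 1, T_max)
scores = np.where(np.arange(T_max) <= cur_index, scores, -np.inf)   # decode mask
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
heads = np.einsum("hqt,thk->qhk", weights, v_cache)                 # (1, H, Dh)
out = np.einsum("qhk,hkd->qd", heads, Wo)                           # (1, D_out)
print(q.shape, scores.shape, heads.shape, out.shape)
# (1, 4, 4) (4, 1, 8) (1, 4, 4) (1, 16)
```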

Common pitfalls

  • Forgetting mutable=['cache']: Flax raises an error because the 'cache' collection is immutable during apply.
  • Using decode=False at construction: no cache is allocated at init. The decode=True flag passed to the constructor is what triggers allocation.
  • Mismatched cache shape vs. new token: INIT runs on the full-length context (e.g., (T_max, D)); APPLY runs on a single step (1, D). The cache is sized once at init, and each decode step must pass exactly one position.
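Flax compares each decode step's projected query shape against the cache (to our understanding it raises a ValueError on a mismatch). The hypothetical check_decode_query below sketches that logic in plain Python, using the (T_max, H, D/H) cache layout with optional leading batch dims.

```python
def check_decode_query(cache_shape, query_shape):
    """Hypothetical mirror of the per-step check: one query position per apply."""
    *batch_dims, _max_len, num_heads, head_dim = cache_shape
    expected = tuple(batch_dims) + (1, num_heads, head_dim)
    if tuple(query_shape) != expected:
        raise ValueError(f"expected query shape {expected}, got {tuple(query_shape)}")
    return expected


cache_shape = (8, 4, 4)   # init on a (T_max=8, D=16) context with num_heads=4
print(check_decode_query(cache_shape, (1, 4, 4)))   # (1, 4, 4)
try:
    check_decode_query(cache_shape, (8, 4, 4))      # full context passed to apply
except ValueError as err:
    print("ValueError:", err)
```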

Problem

Implement mha_with_cache(seed, x, num_heads, qkv_features):

  1. Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D, decode=True).
  2. variables = attn.init(rng, x), with rng built from seed (e.g., jax.random.PRNGKey(seed)); x is the full-length context that sizes the cache.
  3. out, _ = attn.apply(variables, x[0:1], mutable=['cache']) — run ONE step on the first token.
  4. Return out.reshape(-1).

Since this is the first decode step (cache_index = 0), the output should match what causal MHA produces for position 0: the single query can attend only to slot 0, so its softmax weight is 1 and the result is the output projection of token 0's value vector.
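A NumPy sketch of that reasoning, assuming hypothetical random projection weights in DenseGeneral-style layouts and omitting biases (Flax initializes them to zero). Running the masked step-0 computation against a zero-filled cache gives the same vector as simply projecting token 0's values, independent of Q and K.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D_in, H, Dh = 5, 8, 2, 4                  # qkv_features = H * Dh = 8
x = rng.normal(size=(T, D_in))
# Hypothetical projection weights (biases omitted).
Wq, Wk, Wv = (rng.normal(size=(D_in, H, Dh)) for _ in range(3))
Wo = rng.normal(size=(H, Dh, D_in))          # out_features defaults to D_in

x0 = x[0:1]                                  # first decode step: token 0 only
q0, k0, v0 = (np.einsum("td,dhk->thk", x0, W) for W in (Wq, Wk, Wv))  # each (1, H, Dh)

# Zero-filled cache sized from the full context; step 0 writes slot 0.
k_cache, v_cache = np.zeros((T, H, Dh)), np.zeros((T, H, Dh))
k_cache[0], v_cache[0] = k0[0], v0[0]

scores = np.einsum("qhk,thk->hqt", q0, k_cache) / np.sqrt(Dh)   # (H, 1, T)
scores = np.where(np.arange(T) <= 0, scores, -np.inf)          # only slot 0 is valid
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                       # slot 0 gets weight 1
heads = np.einsum("hqt,thk->qhk", weights, v_cache)             # (1, H, Dh)
out = np.einsum("qhk,hkd->qd", heads, Wo).reshape(-1)           # length D_in

shortcut = np.einsum("qhk,hkd->qd", v0, Wo).reshape(-1)         # Q and K never matter
print(out.shape, np.allclose(out, shortcut))   # (8,) True
```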

Inputs:

  • seed: int.
  • x: 2-D (T, D_in) — full-length context that sizes the KV cache.
  • num_heads, qkv_features: ints; qkv_features must be divisible by num_heads.

Output: 1-D, the one-step output flattened. Length = D_in, since out_features defaults to the input's feature size.

Hints

flax attention kv-cache decoding
