
NNX MHA With KV Cache

Why this matters

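Autoregressive generation produces tokens one at a time, and each new token must attend over all previous tokens. The naive approach recomputes K and V for the full prefix at every step, so step t performs t K/V projections and generating T tokens costs O(T^2) projections in total. The KV cache stores each token's K and V once, so each new step computes K/V only for the single new token and reads the rest from the cache: O(T) projections for the whole sequence. (Attention over the cache still reads all t cached keys at step t, so that part remains O(T^2) overall; what the cache removes is the redundant recomputation of past keys and values.)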
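To make the counting concrete, the sketch below tallies only K/V projection calls (one per token per step) and ignores the attention reads. The function names are hypothetical.

```python
def kv_projections_naive(T: int) -> int:
    # Step t (1-based) recomputes K/V for all t tokens of the current prefix.
    return sum(t for t in range(1, T + 1))


def kv_projections_cached(T: int) -> int:
    # Each step projects only its own new token; past K/V come from the cache.
    return sum(1 for _ in range(1, T + 1))


for T in (8, 128, 1024):
    print(T, kv_projections_naive(T), kv_projections_cached(T))
```

The naive count is T(T + 1)/2, which is where the quadratic term comes from.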

KV caching is one of the most important inference-time optimizations in LLM serving: it removes the repeated recomputation of past keys and values, which is a large part of why even local models can stream tokens at interactive speeds.

In Linen this required nn.MultiHeadDotProductAttention(decode=True) with mutable=["cache"] ceremony at apply time. In nnx the cache buffers are just nnx.Cache variables on the module — write to .value and you’ve mutated them.

The cache

Pre-allocate two buffers of shape (max_len, num_heads, head_dim), one for K and one for V, plus a scalar position counter:

self.k_cache = nnx.Cache(jnp.zeros((max_len, H, head_dim)))
self.v_cache = nnx.Cache(jnp.zeros((max_len, H, head_dim)))
self.pos     = nnx.Cache(jnp.zeros((), dtype=jnp.int32))

nnx.Cache is a built-in Variable subclass intended for inference state. It behaves like any other nnx.Variable (read and write through .value), but filters can select it separately from nnx.Param and nnx.BatchStat, which is useful when you want to reset or swap caches between sessions without touching the weights.

Decoding step

For one new token x_token:

  1. Project Q, K, V for just this token: shape (H, head_dim) each.
  2. Write the new K and V into slot pos of their cache buffers with .at[idx].set(...). This is a functional update that returns a new array, so assign the result back to .value.
  3. Increment pos.
  4. Run SDPA: query is the new q; keys/values are the populated portion of the cache (positions 0..pos-1).
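Steps 2 and 3 can be sketched on plain arrays, without NNX, to show that .at[idx].set(...) leaves the original buffer unchanged and that the write happens before the increment. write_and_advance is a hypothetical helper.

```python
import jax.numpy as jnp


def write_and_advance(k_cache, v_cache, pos, k, v):
    """Hypothetical helper for steps 2-3 as a pure function.

    k_cache, v_cache: (max_len, H, Dh); pos: int32 scalar; k, v: (H, Dh).
    Returns new buffers and the advanced counter; the inputs are not modified.
    """
    idx = pos
    # Semantically a new array with slot idx replaced (XLA may reuse the
    # buffer under jit, but the Python-level input is never mutated).
    new_k = k_cache.at[idx].set(k)
    new_v = v_cache.at[idx].set(v)
    return new_k, new_v, idx + 1  # write first, then advance


max_len, H, Dh = 4, 2, 3
k_cache = jnp.zeros((max_len, H, Dh))
v_cache = jnp.zeros((max_len, H, Dh))
pos = jnp.array(0, dtype=jnp.int32)

k = jnp.ones((H, Dh))
v = 2.0 * jnp.ones((H, Dh))
new_k, new_v, new_pos = write_and_advance(k_cache, v_cache, pos, k, v)
```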

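The trick is the masking. Because the buffers are pre-allocated to max_len and zero-initialized, slots beyond the one just written hold no real token data (zeros, or stale values if the cache is reused), yet they still take part in the dot product. Mask them out with a position-based check: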

valid = jnp.arange(max_len) <= idx        # (max_len,)
scores = jnp.where(valid, scores, -1e9)

where idx is pos.value read before the increment, i.e. the slot we just wrote.
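A small sketch of the mask in isolation, with made-up scores of shape (H, max_len). The (max_len,) mask broadcasts across heads, and after the softmax every slot past idx receives exactly zero weight.

```python
import jax
import jax.numpy as jnp

max_len, H = 5, 2
idx = 2  # slot just written: slots 0..2 hold real tokens, 3..4 are padding

# Made-up scores for one query against every slot: (H, max_len)
scores = jnp.arange(H * max_len, dtype=jnp.float32).reshape(H, max_len)

valid = jnp.arange(max_len) <= idx         # (max_len,), broadcasts over heads
masked = jnp.where(valid, scores, -1e9)
weights = jax.nn.softmax(masked, axis=-1)  # (H, max_len), rows sum to 1
```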

Worked sketch

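import jax
import jax.numpy as jnp
from flax import nnx

class CachedMHA(nnx.Module):
    def __init__(self, d_model, num_heads, max_len, rngs):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.max_len = max_len
        # Q, K, V and output projections, each d_model -> d_model
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        # Zero-initialized K/V buffers, one slot per position: (max_len, H, Dh)
        self.k_cache = nnx.Cache(jnp.zeros((max_len, num_heads, self.head_dim)))
        self.v_cache = nnx.Cache(jnp.zeros((max_len, num_heads, self.head_dim)))
        # Index of the next free slot
        self.pos = nnx.Cache(jnp.zeros((), dtype=jnp.int32))

    def step(self, x_token):
        """Attend from one token of shape (d_model,) over it and all cached tokens."""
        H, Dh = self.num_heads, self.head_dim
        # 1. Project only the new token and split into heads: (H, Dh) each
        q = self.q_proj(x_token).reshape(H, Dh)
        k = self.k_proj(x_token).reshape(H, Dh)
        v = self.v_proj(x_token).reshape(H, Dh)
        # 2. Write K/V into slot idx (functional update, reassigned to .value)
        idx = self.pos.value
        self.k_cache.value = self.k_cache.value.at[idx].set(k)
        self.v_cache.value = self.v_cache.value.at[idx].set(v)
        # 3. Advance the counter only after writing
        self.pos.value = idx + 1
        # 4. Scaled dot-product attention over the cache, masking slots > idx
        valid = jnp.arange(self.max_len) <= idx      # (max_len,)
        K = self.k_cache.value.transpose(1, 0, 2)    # (H, max_len, Dh)
        V = self.v_cache.value.transpose(1, 0, 2)
        scores = jnp.einsum('hd,hkd->hk', q, K) / jnp.sqrt(Dh)   # (H, max_len)
        scores = jnp.where(valid, scores, -1e9)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.einsum('hk,hkd->hd', weights, V)           # (H, Dh)
        return self.out_proj(per_head.reshape(H * Dh))            # (d_model,)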

To produce the attention output for x_new given a history, you step the module through every history token first to populate the cache, then call it once more on x_new:

for t in range(T_history):
    _ = model.step(x_history[t])
out = model.step(x_new)

The output of the new-token step is identical (up to numerics) to what you’d get from running causal self-attention over concat(x_history, x_new) and reading the last row.

What nnx makes easy

The mutation self.k_cache.value = ... is just an attribute write. There is no mutable=["cache"] at apply time, no return tuple of updated state, and no separate params and cache collections to thread through each call. The module is the state container.

Common pitfalls

  • Forgetting to mask invalid slots. Pre-zeroed slots still participate in the dot product. Without the valid mask, the first step would attend to max_len - 1 zero slots: their zero values add nothing to the weighted sum, but they still absorb softmax weight, which pulls the output toward zero.
  • Increment before write. Bumping pos first puts the new K/V in the wrong slot. Write at idx = pos.value, then increment.
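  • Bypassing .value. The array lives in self.k_cache.value: build the update from it (self.k_cache.value.at[idx].set(k)) and assign the result back to self.k_cache.value, so the nnx.Cache variable itself holds the new buffer.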
  • Wrong valid direction. jnp.arange(max_len) <= idx is what you want (positions 0..idx, inclusive of the slot just written). < idx skips the new token, < pos.value AFTER increment works but is fragile to refactor.
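The first pitfall is easy to see numerically. At the first step only slot 0 holds a real key and value; the zero slots score exactly 0 and carry zero values, so without the mask they soak up softmax weight and scale the output down. A sketch with random inputs (attend is a hypothetical helper; sizes and seed are arbitrary):

```python
import jax
import jax.numpy as jnp

max_len, H, Dh = 6, 1, 4
key_q, key_k, key_v = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(key_q, (H, Dh))
k = jax.random.normal(key_k, (H, Dh))
v = jax.random.normal(key_v, (H, Dh))

# First decoding step: only slot 0 has been written, the rest are zeros.
K = jnp.zeros((H, max_len, Dh)).at[:, 0].set(k)
V = jnp.zeros((H, max_len, Dh)).at[:, 0].set(v)


def attend(q, K, V, mask):
    scores = jnp.einsum('hd,hkd->hk', q, K) / jnp.sqrt(K.shape[-1])
    weights = jax.nn.softmax(jnp.where(mask, scores, -1e9), axis=-1)
    return jnp.einsum('hk,hkd->hd', weights, V)


idx = 0
masked = attend(q, K, V, jnp.arange(max_len) <= idx)       # correct
unmasked = attend(q, K, V, jnp.ones(max_len, dtype=bool))  # pitfall
```

With the mask, the output is exactly the new token's v. Without it, the output is v scaled by e^s / (e^s + max_len - 1), where s is the new token's own score, so it is always shrunk toward zero.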

Problem

Write mha_decode_step(seed, x_new, x_history, num_heads, d_model):

  1. Define CachedMHA(nnx.Module) with four nnx.Linear projections, two nnx.Cache buffers (max_len, H, head_dim), and a scalar nnx.Cache pos counter.
  2. step(x_token): project, write to pos slot of each cache, increment pos, mask invalid slots, run SDPA, out_proj.
  3. In the entry function: build with max_len = T_history + 1, step through every history token to populate, then call once on x_new. Return the new step’s output flattened.

Inputs:

  • seed: int (passed as float).
  • x_new: 1-D (d_model,) — the single new token.
  • x_history: 2-D (T_history, d_model) — prefix tokens.
  • num_heads, d_model: ints (passed as floats).

Output: 1-D (d_model,) — attention output for x_new.
