NNX MHA With KV Cache
Why this matters
Autoregressive generation produces tokens one at a time, and each new token
must attend over all previous tokens. The naive approach recomputes K and V
for the full prefix at every step: step t re-projects t tokens, so generating
T tokens costs O(T^2) K/V projections in total. The KV cache stores K and V
for past tokens once, so each new step computes K/V only for the single new
token and reads the rest from the cache: O(T) projections for the whole
sequence. (The attention scores still grow with the prefix, since step t
reads t cached keys.)
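A minimal counting sketch of the projection savings described above. It assumes a simple cost model that counts only per-token K/V projections and ignores the attention reads themselves; the function names are hypothetical.

```python
def kv_projections_naive(T: int) -> int:
    """Per-token K/V projections when the whole prefix is recomputed each step.

    Step t (1-indexed) re-projects all t tokens seen so far.
    """
    return sum(range(1, T + 1))  # = T * (T + 1) / 2, i.e. O(T^2)


def kv_projections_cached(T: int) -> int:
    """With a KV cache, each step projects only the newest token."""
    return T  # O(T)


for T in (8, 64, 512):
    print(T, kv_projections_naive(T), kv_projections_cached(T))
# 8 36 8
# 64 2080 64
# 512 131328 512
```

The ratio grows linearly with T, which is why the cache matters most for long generations.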
KV caching is one of the most important inference-time optimizations in LLM serving: it removes the redundant recomputation of past keys and values, which is a large part of why local LLMs can generate at interactive speeds even though attention over the full sequence remains quadratic.
In Linen this required nn.MultiHeadDotProductAttention(decode=True)
plus a mutable=["cache"] argument at apply time, with the updated cache
returned alongside the output. In nnx the cache buffers are just
nnx.Cache variables on the module: assign to .value and the module's
state is updated.
The cache
Pre-allocate two buffers of shape (max_len, num_heads, head_dim), one for K
and one for V, plus a scalar position counter:
self.k_cache = nnx.Cache(jnp.zeros((max_len, H, head_dim)))
self.v_cache = nnx.Cache(jnp.zeros((max_len, H, head_dim)))
self.pos = nnx.Cache(jnp.zeros((), dtype=jnp.int32))
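Because the buffers are allocated up front, their memory cost is fixed by max_len rather than by how many tokens have actually been generated. A small arithmetic sketch, assuming float32 storage (the jnp.zeros default) and a single attention layer; the helper name and the example sizes are hypothetical.

```python
def kv_cache_bytes(max_len: int, num_heads: int, head_dim: int,
                   bytes_per_elem: int = 4) -> int:
    """Bytes for one layer's K and V buffers of shape (max_len, num_heads, head_dim).

    bytes_per_elem=4 assumes float32; the factor 2 accounts for K plus V.
    The scalar pos counter is negligible and ignored.
    """
    return 2 * max_len * num_heads * head_dim * bytes_per_elem


# Hypothetical sizes: 2048 positions, 8 heads of width 64, float32.
print(kv_cache_bytes(2048, 8, 64))  # 8388608 bytes = 8 MiB
```

A multi-layer model would need one such pair of buffers per layer.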
nnx.Cache is a built-in Variable subclass intended for inference
state. It behaves like nnx.Variable (mutate via .value), but filters
can select it separately from Param and BatchStat, which is useful when
you want to swap caches between sessions without touching the weights.
Decoding step
For one new token x_token:
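- Project Q, K, V for just this token: shape (H, head_dim) each.
- Write the new K and V into slot pos of their cache buffers. JAX arrays are immutable, so this is a functional update via .at[idx].set(...) whose result is assigned back to .value.
- Increment pos.
- Run scaled dot-product attention (SDPA): the query is the new q; keys/values are the populated portion of the cache (positions 0..pos-1 after the increment).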
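The write-then-increment order from the steps above, as a minimal pure-JAX sketch. The toy sizes (max_len=4, H=2, head_dim=3) and the all-ones stand-in for the projected K are hypothetical. It also shows that .at[idx].set returns a new array and leaves the original untouched, which is why the result has to be assigned back.

```python
import jax.numpy as jnp

max_len, H, Dh = 4, 2, 3
k_cache = jnp.zeros((max_len, H, Dh))
pos = jnp.zeros((), dtype=jnp.int32)

k_new = jnp.ones((H, Dh))             # stand-in for the projected K of one token

idx = pos                             # read the slot BEFORE incrementing
updated = k_cache.at[idx].set(k_new)  # returns a NEW array; k_cache is unchanged
pos = idx + 1

print(float(k_cache.sum()), float(updated.sum()), int(pos))  # 0.0 6.0 1
```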
The trick is the masking. Because the buffers are pre-allocated to
max_len, every slot after the one just written is still zero padding:
those zeros are not real keys, but they still take part in the dot
product. Mask them out with a position-based check:
valid = jnp.arange(max_len) <= idx # (max_len,)
scores = jnp.where(valid, scores, -1e9)
where idx is pos.value read before the increment, i.e. the slot just written.
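A tiny numeric check of the mask, with hypothetical scores: slots 0 and 1 hold real keys (idx = 1) and slots 2 and 3 are zero padding, which would otherwise score 0. After masking, the padded slots get exactly zero weight and the real slots share all of it.

```python
import jax
import jax.numpy as jnp

max_len = 4
idx = 1                                   # slots 0..1 hold real keys
scores = jnp.array([2.0, 1.0, 0.0, 0.0])  # slots 2..3 are zero padding (score 0)

valid = jnp.arange(max_len) <= idx        # [True, True, False, False]
masked = jnp.where(valid, scores, -1e9)
weights = jax.nn.softmax(masked)

print(weights)  # approximately [0.731 0.269 0. 0.]
```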
Worked sketch
import jax
import jax.numpy as jnp
from flax import nnx

class CachedMHA(nnx.Module):
    def __init__(self, d_model, num_heads, max_len, rngs):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.max_len = max_len
        # Q/K/V/output projections (d_model -> d_model).
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        # Pre-allocated K/V buffers plus the next-free-slot counter.
        self.k_cache = nnx.Cache(jnp.zeros((max_len, num_heads, self.head_dim)))
        self.v_cache = nnx.Cache(jnp.zeros((max_len, num_heads, self.head_dim)))
        self.pos = nnx.Cache(jnp.zeros((), dtype=jnp.int32))

    def step(self, x_token):
        # x_token: (d_model,) -> returns (d_model,)
        H, Dh = self.num_heads, self.head_dim
        q = self.q_proj(x_token).reshape(H, Dh)
        k = self.k_proj(x_token).reshape(H, Dh)
        v = self.v_proj(x_token).reshape(H, Dh)

        # Write this token's K/V into slot idx, THEN advance the counter.
        idx = self.pos.value
        self.k_cache.value = self.k_cache.value.at[idx].set(k)
        self.v_cache.value = self.v_cache.value.at[idx].set(v)
        self.pos.value = idx + 1

        # Slots 0..idx are populated; everything after is zero padding.
        valid = jnp.arange(self.max_len) <= idx
        K = self.k_cache.value.transpose(1, 0, 2)  # (H, max_len, Dh)
        V = self.v_cache.value.transpose(1, 0, 2)
        scores = jnp.einsum('hd,hkd->hk', q, K) / jnp.sqrt(Dh)
        scores = jnp.where(valid, scores, -1e9)
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.einsum('hk,hkd->hd', weights, V)  # (H, Dh)
        return self.out_proj(per_head.reshape(H * Dh))
To produce the attention output for x_new given a history, you
step the module through every history token first to populate the
cache, then call it once more on x_new:
for t in range(T_history):
    _ = model.step(x_history[t])
out = model.step(x_new)
The output of the new-token step is identical (up to numerics) to
what you’d get from running causal self-attention over
concat(x_history, x_new) and reading the last row.
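A pure-JAX sketch of that equivalence check, independent of nnx. It assumes hypothetical bias-free weight matrices Wq, Wk, Wv and omits the output projection (a fixed linear map applied identically on both paths); the helper names are illustrative. One path runs full causal self-attention over all tokens and keeps the last row; the other decodes token by token through a pre-allocated, masked KV cache.

```python
import jax
import jax.numpy as jnp


def split_heads(x, H):
    # (..., d_model) -> (..., H, d_model // H)
    return x.reshape(*x.shape[:-1], H, x.shape[-1] // H)


def full_causal_last_row(X, Wq, Wk, Wv, H):
    """Causal self-attention over X (T, d_model); return the last row, heads flattened."""
    T = X.shape[0]
    q, k, v = (split_heads(X @ W, H) for W in (Wq, Wk, Wv))  # each (T, H, Dh)
    scores = jnp.einsum('qhd,khd->hqk', q, k) / jnp.sqrt(q.shape[-1])
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))           # key j <= query i
    scores = jnp.where(causal, scores, -1e9)
    out = jnp.einsum('hqk,khd->qhd', jax.nn.softmax(scores, axis=-1), v)
    return out[-1].reshape(-1)


def cached_decode_last(X, Wq, Wk, Wv, H):
    """Same output via a pre-allocated KV cache, one token at a time."""
    T, d_model = X.shape
    Dh = d_model // H
    k_cache = jnp.zeros((T, H, Dh))
    v_cache = jnp.zeros((T, H, Dh))
    for idx in range(T):
        q, k, v = (split_heads(X[idx] @ W, H) for W in (Wq, Wk, Wv))  # each (H, Dh)
        k_cache = k_cache.at[idx].set(k)
        v_cache = v_cache.at[idx].set(v)
        valid = jnp.arange(T) <= idx
        scores = jnp.einsum('hd,khd->hk', q, k_cache) / jnp.sqrt(Dh)
        scores = jnp.where(valid, scores, -1e9)
        out = jnp.einsum('hk,khd->hd', jax.nn.softmax(scores, axis=-1), v_cache)
    return out.reshape(-1)


kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
T, d_model, H = 5, 8, 2
X = jax.random.normal(kx, (T, d_model))
Wq, Wk, Wv = (jax.random.normal(key, (d_model, d_model)) / jnp.sqrt(d_model)
              for key in (kq, kk, kv))

a = full_causal_last_row(X, Wq, Wk, Wv, H)
b = cached_decode_last(X, Wq, Wk, Wv, H)
print(jnp.allclose(a, b, atol=1e-5))  # True
```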
What nnx makes easy
The mutation self.k_cache.value = ... is just an attribute write.
No mutable=["cache"] at apply time, no return-tuple of updated
state, no separate params and cache collections. The module
is the state container.
Common pitfalls
- Forgetting to mask invalid slots. Pre-zeroed slots still participate in the dot product. Without valid masking, the first step would attend to max_len - 1 zero slots, which pulls the output toward zero.
- Incrementing before writing. Bumping pos first puts the new K/V in the wrong slot. Write at idx = pos.value, then increment.
- Skipping .value, or dropping the result. Always go through .value: self.k_cache.at[idx] addresses the nnx.Cache wrapper rather than the array it holds. And because .at[idx].set(k) is a functional update that returns a new array, the result must be assigned back: self.k_cache.value = self.k_cache.value.at[idx].set(k).
- Wrong valid direction. jnp.arange(max_len) <= idx is what you want (positions 0..idx, inclusive of the slot just written). < idx skips the new token; < pos.value after the increment works but is fragile to refactor.
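A small numeric illustration of the first pitfall, using hypothetical single-head values at the very first decoding step (only slot 0 written, max_len = 4). With the mask, the token attends only to itself and the output equals its own value vector; without it, the three zero slots soak up attention weight and shrink the output.

```python
import jax
import jax.numpy as jnp

max_len, Dh = 4, 2
q = jnp.array([1.0, 0.0])
k_new = jnp.array([3.0, 0.0])
v_new = jnp.array([2.0, 2.0])

# First decoding step: only slot 0 has been written; slots 1..3 are still zeros.
K = jnp.zeros((max_len, Dh)).at[0].set(k_new)
V = jnp.zeros((max_len, Dh)).at[0].set(v_new)
scores = K @ q / jnp.sqrt(Dh)

unmasked = jax.nn.softmax(scores) @ V
valid = jnp.arange(max_len) <= 0
masked = jax.nn.softmax(jnp.where(valid, scores, -1e9)) @ V

print(masked)    # [2. 2.]: the new token attends only to itself
print(unmasked)  # roughly [1.47 1.47]: pulled toward zero by the padding
```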
Problem
Write mha_decode_step(seed, x_new, x_history, num_heads, d_model):
- Define CachedMHA(nnx.Module) with four nnx.Linear projections, two nnx.Cache buffers of shape (max_len, H, head_dim), and a scalar nnx.Cache pos counter.
- step(x_token): project, write to the pos slot of each cache, increment pos, mask invalid slots, run SDPA, apply out_proj.
- In the entry function: build with max_len = T_history + 1, step through every history token to populate the cache, then call step once on x_new. Return the new step's output flattened.
Inputs:
- seed: int (passed as float).
- x_new: 1-D (d_model,), the single new token.
- x_history: 2-D (T_history, d_model), the prefix tokens.
- num_heads, d_model: ints (passed as floats).
Output: 1-D (d_model,), the attention output for x_new.