NNX Causal MHA
Why this matters
Decoder-only language models — GPT, LLaMA, Mistral — are autoregressive:
token t depends only on tokens <= t. The mechanism that enforces
that constraint inside the attention layer is the causal mask.
Without it, training would let the model peek at future tokens and
teach itself to copy — useless at generation time.
A causal mask is one line of code on top of the previous problem’s multi-head attention. The point isn’t the line itself; it’s understanding why the mask is needed and where in the scaled dot-product attention (SDPA) pipeline it goes.
Where the mask lives
Inside SDPA the score matrix is (T, T), or (H, T, T) once the heads are stacked.
Entry (i, j) measures how strongly query position i attends to
key position j. To prevent i from seeing j > i, we set those
entries to a very negative number BEFORE softmax — so after softmax
they round to zero.
```python
import jax
import jax.numpy as jnp

causal = jnp.tril(jnp.ones((T, T)))            # 1 on/below diagonal, 0 above
scores = jnp.where(causal == 0, -1e9, scores)  # scores: raw (T, T) or (H, T, T)
weights = jax.nn.softmax(scores, axis=-1)
```
jnp.tril(jnp.ones(...)) is the classic idiom: tril keeps the
lower triangle (including diagonal). Position i can attend to
positions 0..i (mask=1) and is blocked from i+1..T-1 (mask=0).
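For a concrete picture, here is the mask for T = 4. Row i is the query position and column j the key position, so row i has exactly i + 1 ones.

```python
import jax.numpy as jnp

T = 4
causal = jnp.tril(jnp.ones((T, T)))  # rows = query positions, columns = key positions
print(causal)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
print(causal.sum(axis=1))  # number of visible keys per query: 1, 2, 3, 4
```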
Why -1e9, not -inf?
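Softmax exponentiates, and exp(-inf) = 0, so -inf looks like the natural choice. But infinities are fragile. If every entry in a row is -inf (a fully masked row, which can happen once a padding mask is combined with the causal mask), the max-subtraction that numerically stable softmax performs computes -inf - (-inf) = NaN. More generally, any 0 * inf that shows up in the forward or backward pass is NaN, and a single NaN spreads through everything downstream. In practice, many implementations use a large finite negative number instead: -1e9 is a common convention, and jnp.finfo(dtype).min is another. In float32, exp(-1e9 - max) underflows to exactly 0.0, so the masked softmax weights are zero for all practical purposes while every value and gradient stays finite.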
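A minimal numeric check of both claims, using a made-up three-entry score row: with -1e9 the masked weight is exactly 0.0 in float32, and a fully masked row stays finite (uniform weights), whereas the same fully masked row filled with -inf comes out as NaN.

```python
import jax
import jax.numpy as jnp

# One query row: two allowed keys and one masked key (toy values).
scores = jnp.array([2.0, 1.0, 0.5], dtype=jnp.float32)
allowed = jnp.array([True, True, False])

weights = jax.nn.softmax(jnp.where(allowed, scores, -1e9))
print(weights[2])     # 0.0: exp(-1e9 - 2.0) underflows to exactly zero in float32
print(weights.sum())  # approximately 1.0

# A fully masked row: -1e9 stays finite, -inf does not.
none_allowed = jnp.zeros(3, dtype=bool)
finite = jax.nn.softmax(jnp.where(none_allowed, scores, -1e9))      # uniform, 1/3 each
infinite = jax.nn.softmax(jnp.where(none_allowed, scores, -jnp.inf))  # -inf - (-inf) = NaN
print(finite)
print(infinite)       # all NaN
```

Note that the uniform weights in the fully masked -1e9 case are still meaningless attention; the point is only that they do not poison the rest of the computation with NaN.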
What the mask DOESN’T do
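The causal mask is purely a forward-pass mechanism. It has no parameters, is not learned, and carries no state; its only effect on gradients is that nothing flows through the masked (i, j) pairs, because their softmax weights are zero. The q_proj/k_proj/v_proj/out_proj projections are the same as before; the only change is one extra line in the score pipeline.
During training you compute attention over the whole sequence in parallel, and the mask blocks future positions row by row. During single-token decoding with a KV cache you typically don’t need an explicit mask: the new token’s query attends only to keys already in the cache, all of which sit at or before its position, so causality is implicit in the cache structure. The prefill step, which processes the whole prompt at once, still needs the causal mask.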
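To see why a decode step needs no mask, here is a minimal single-head sketch. It assumes no batching and no projections, and `causal_attention` and `decode_step` are hypothetical helper names; the "cache" is simply a slice of the full k and v rather than a real preallocated cache. Query t attends to a cache holding positions 0..t, and the result matches row t of the masked full-sequence attention.

```python
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    """Full-sequence single-head attention with a causal mask. q, k, v: (T, Dh)."""
    T, Dh = q.shape
    scores = q @ k.T / jnp.sqrt(Dh)                # (T, T)
    causal = jnp.tril(jnp.ones((T, T)))
    scores = jnp.where(causal == 0, -1e9, scores)  # block j > i
    return jax.nn.softmax(scores, axis=-1) @ v     # (T, Dh)

def decode_step(q_t, k_cache, v_cache):
    """One new query against cached keys/values for positions 0..t (no mask).
    q_t: (Dh,); k_cache, v_cache: (t + 1, Dh)."""
    Dh = q_t.shape[-1]
    scores = k_cache @ q_t / jnp.sqrt(Dh)          # (t + 1,)
    return jax.nn.softmax(scores) @ v_cache        # (Dh,)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
T, Dh = 5, 8
q = jax.random.normal(kq, (T, Dh))
k = jax.random.normal(kk, (T, Dh))
v = jax.random.normal(kv, (T, Dh))

full = causal_attention(q, k, v)
# Decoding position t only ever sees the cache built from positions 0..t.
stepwise = jnp.stack([decode_step(q[t], k[: t + 1], v[: t + 1]) for t in range(T)])
print(jnp.max(jnp.abs(full - stepwise)))  # float32 rounding noise at most
```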
Worked sketch
```python
import jax
import jax.numpy as jnp
from flax import nnx

class CausalMHA(nnx.Module):
    # __init__ identical to the previous problem.
    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project and split into heads: (T, d_model) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        causal = jnp.tril(jnp.ones((T, T)))                          # (T, T)
        scores = jnp.where(causal == 0, -1e9, scores)                # mask BEFORE softmax
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)                            # (H, T, Dh)
        # Merge heads: (H, T, Dh) -> (T, H * Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)
```
The mask broadcasts over the head axis automatically — (T, T)
against (H, T, T) lines up via NumPy broadcasting rules.
Common pitfalls
- Mask after softmax. Setting weights to 0 post-softmax breaks normalization (rows no longer sum to 1). Mask BEFORE softmax.
- Using -jnp.inf. A pure causal mask always keeps the diagonal, but as soon as a row is fully masked (for example, when a padding mask is combined with the causal mask) -inf produces NaN, and infinities can turn into NaN elsewhere in the forward or backward pass. Use -1e9.
- triu vs tril. triu keeps the upper triangle, which is the wrong half. tril is what you want.
- Mask applied to weights, not scores. weights * mask zeroes the masked weights, but the unmasked ones still sum to less than 1, because softmax already normalized over the masked entries too. Mask the scores.
- Off-by-one. Position i attending to j == i is allowed; tril(ones) includes the diagonal, which is correct.
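The normalization pitfalls are easy to see by comparing row sums. This sketch uses random scores as a stand-in for real scaled q·k products.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jax.random.normal(jax.random.PRNGKey(0), (T, T))  # stand-in attention scores
causal = jnp.tril(jnp.ones((T, T)))

# Correct: mask the scores, then softmax. Every row sums to 1.
good = jax.nn.softmax(jnp.where(causal == 0, -1e9, scores), axis=-1)

# Pitfall: softmax first, then zero out future positions.
bad = jax.nn.softmax(scores, axis=-1) * causal

print(good.sum(axis=-1))  # all approximately 1.0
print(bad.sum(axis=-1))   # below 1.0, except the last row, which has nothing masked
```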
Problem
Write mha_causal(seed, x, num_heads, d_model):
- Define CausalMHA(nnx.Module) with the same four projections as the previous problem.
- In __call__: compute scaled scores, build causal = jnp.tril(jnp.ones((T, T))), fill scores with -1e9 where causal == 0, then softmax.
- Return the output flattened.
Inputs:
- seed: int (passed as float).
- x: 2-D (T, d_model).
- num_heads, d_model: ints (passed as floats).
Output: 1-D flattened.