medium primitives

NNX Causal MHA

Why this matters

Decoder-only language models — GPT, LLaMA, Mistral — are autoregressive: token t depends only on tokens <= t. The mechanism that enforces that constraint inside the attention layer is the causal mask. Without it, training would let the model peek at future tokens and teach itself to copy — useless at generation time.

A causal mask is one line of code on top of the previous problem. The point isn’t the line; it’s understanding why the mask is needed and where in the scaled dot-product attention (SDPA) pipeline it goes.

Where the mask lives

Inside SDPA the score matrix is (T, T) (or (H, T, T) per head). Entry (i, j) measures how strongly query position i attends to key position j. To prevent i from seeing j > i, we set those entries to a very negative number BEFORE softmax — so after softmax they round to zero.
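Written out for one query row i, the fill-then-softmax step computes the following. Here \tilde{s}_{ij} is notation introduced only for this illustration and stands for the score after the jnp.where fill:

```latex
\tilde{s}_{ij} =
\begin{cases}
  s_{ij} & j \le i \\
  -10^{9} & j > i
\end{cases}
\qquad
w_{ij} = \frac{\exp(\tilde{s}_{ij})}{\sum_{j'=0}^{T-1} \exp(\tilde{s}_{ij'})}
\approx
\begin{cases}
  \dfrac{\exp(s_{ij})}{\sum_{j'=0}^{i} \exp(s_{ij'})} & j \le i \\[4pt]
  0 & j > i
\end{cases}
```

Each row is therefore an ordinary softmax over positions 0..i only. The future keys drop out of both the numerator and the denominator, and in float32 their exp terms underflow to exactly 0.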

causal = jnp.tril(jnp.ones((T, T)))   # 1 on/below diagonal, 0 above
scores = jnp.where(causal == 0, -1e9, scores)
weights = jax.nn.softmax(scores, axis=-1)

jnp.tril(jnp.ones(...)) is the classic idiom: tril keeps the lower triangle (including diagonal). Position i can attend to positions 0..i (mask=1) and is blocked from i+1..T-1 (mask=0).
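Below is a minimal, self-contained check of the idiom on a tiny T = 4 example. The all-zero scores are an arbitrary choice for illustration, so every visible key in a row gets the same weight.

```python
import jax
import jax.numpy as jnp

T = 4
causal = jnp.tril(jnp.ones((T, T)))  # 1 on/below the diagonal, 0 above
print(causal)

# Placeholder scores: all zeros, so unmasked entries in a row share weight equally.
scores = jnp.zeros((T, T))
scores = jnp.where(causal == 0, -1e9, scores)  # block j > i BEFORE softmax
weights = jax.nn.softmax(scores, axis=-1)

# Row i puts weight 1/(i+1) on each of positions 0..i and 0 on i+1..T-1.
print(weights)
```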

Why -1e9, not -inf?

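softmax exponentiates, and exp(-inf) = 0, so -inf looks like the natural fill. The catch is a row in which every entry is masked. Softmax subtracts the row maximum before exponentiating, and -inf - (-inf) is NaN. A pure causal mask always leaves the diagonal visible, so this cannot happen with the causal mask alone. It does happen once the causal mask is combined with, say, a padding mask, and the NaN then spreads into the loss and every gradient. A large finite negative number avoids the problem. -1e9 is a common convention, and Flax’s built-in attention uses the dtype-aware jnp.finfo(dtype).min instead. After the max subtraction, exp(-1e9) underflows to exactly 0 in float32, so masked weights are zero for all practical purposes and fully masked rows stay finite. One caveat applies in float16: -1e9 itself overflows to -inf there, which is one reason to prefer the dtype-aware fill.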
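This small check shows where -inf actually goes wrong. A partially masked row behaves the same with either fill. A fully masked row, which can arise once a padding mask is combined with the causal one, turns into NaN with -inf but stays finite with -1e9. The last line is an aside confirming that -1e9 overflows in float16.

```python
import jax
import jax.numpy as jnp

# Row with one visible key: both fills give [1, 0].
partial_inf = jax.nn.softmax(jnp.array([0.0, -jnp.inf]))
partial_big = jax.nn.softmax(jnp.array([0.0, -1e9]))  # exp(-1e9) underflows to exactly 0

# Fully masked row: the max is -inf, and -inf - (-inf) = NaN.
full_inf = jax.nn.softmax(jnp.array([-jnp.inf, -jnp.inf]))
full_big = jax.nn.softmax(jnp.array([-1e9, -1e9]))  # finite: uniform weights

print(partial_inf, partial_big)
print(full_inf, full_big)

# Aside: in float16 (max magnitude 65504) the "finite" fill is no longer finite.
fp16_fill = jnp.array(-1e9, dtype=jnp.float16)
print(fp16_fill, jnp.finfo(jnp.float16).min)
```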

What the mask DOESN’T do

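The causal mask is purely a forward-pass mechanism. It adds no parameters and no state, so q_proj/k_proj/v_proj/out_proj are the same as before, with one extra line in the score pipeline. Its only effect on gradients is indirect. Masked positions get zero attention weight, so no gradient flows from query i into the keys or values at positions j > i.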

During training you compute attention over the whole sequence in parallel, and the mask blocks forward leaks row by row. At inference, the prompt is usually processed the same way in a prefill pass, which still needs the mask. Each subsequent KV-cached decoding step attends from a single new token to the cached past only, so for those steps the mask is implicit in the cache structure.
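The sketch below checks the KV-cache claim numerically for a single head. The helper names causal_attention and decode_step are hypothetical and exist only for this illustration. Row t of the masked full-sequence computation should match an unmasked attention step whose "cache" holds only k[:t+1] and v[:t+1].

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """Single-head causal attention over a full (T, D) sequence (training / prefill)."""
    T, D = q.shape
    scores = q @ k.T / jnp.sqrt(D)
    causal = jnp.tril(jnp.ones((T, T)))
    scores = jnp.where(causal == 0, -1e9, scores)
    return jax.nn.softmax(scores, axis=-1) @ v


def decode_step(q_t, k_cache, v_cache):
    """One new query attending to cached keys/values 0..t: no mask needed."""
    D = q_t.shape[-1]
    scores = k_cache @ q_t / jnp.sqrt(D)     # (t+1,)
    return jax.nn.softmax(scores) @ v_cache  # (D,)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
T, D = 6, 4
q = jax.random.normal(key_q, (T, D))
k = jax.random.normal(key_k, (T, D))
v = jax.random.normal(key_v, (T, D))

full = causal_attention(q, k, v)
incremental = jnp.stack([decode_step(q[t], k[: t + 1], v[: t + 1]) for t in range(T)])
print(jnp.max(jnp.abs(full - incremental)))  # float32 rounding noise only
```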

Worked sketch

import jax
import jax.numpy as jnp
from flax import nnx

class CausalMHA(nnx.Module):
    # __init__ identical to the previous problem.

    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project, then split into heads: (T, d_model) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        causal = jnp.tril(jnp.ones((T, T)))                          # (T, T), broadcasts over H
        scores = jnp.where(causal == 0, -1e9, scores)                 # mask BEFORE softmax
        weights = jax.nn.softmax(scores, axis=-1)
        per_head = jnp.matmul(weights, v)                             # (H, T, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)       # merge heads back
        return self.out_proj(concat)

The mask broadcasts over the head axis automatically — (T, T) against (H, T, T) lines up via NumPy broadcasting rules.
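A quick shape check of that broadcast, using arbitrary placeholder scores. Broadcasting right-aligns the trailing axes, so the single (T, T) mask is applied identically to every head.

```python
import jax.numpy as jnp

H, T = 3, 4
scores = jnp.arange(H * T * T, dtype=jnp.float32).reshape(H, T, T)  # placeholder scores
causal = jnp.tril(jnp.ones((T, T)))                                 # (T, T)

# (T, T) lines up with the last two axes of (H, T, T); H is broadcast.
masked = jnp.where(causal == 0, -1e9, scores)
print(masked.shape)  # (3, 4, 4)
```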

Common pitfalls

  • Mask after softmax. Zeroing weights post-softmax breaks normalization. This includes weights * mask. Softmax has already spread probability over the future positions, so the surviving weights in each row sum to less than 1. Mask the scores, BEFORE softmax.
  • Using -jnp.inf. This works while every row keeps at least one visible key. A fully masked row, for example causal plus padding, becomes NaN, and the NaN reaches the loss and the gradients. Use a large finite negative such as -1e9.
  • triu vs tril. With the "fill where causal == 0" convention used here, triu keeps the upper triangle, which is the wrong half. tril is what you want.
  • Off-by-one. Position i attending to j == i is allowed; tril(ones) includes the diagonal, which is correct.
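The first pitfall is easy to see numerically. The random scores below are arbitrary. Multiplying the post-softmax weights by the mask leaves every row except the last summing to less than 1, while masking the scores first keeps each row a proper distribution.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jax.random.normal(jax.random.PRNGKey(1), (T, T))
causal = jnp.tril(jnp.ones((T, T)))

# Wrong: softmax over all keys, then zero out the future.
wrong = jax.nn.softmax(scores, axis=-1) * causal
# Right: block future keys in the scores, then softmax.
right = jax.nn.softmax(jnp.where(causal == 0, -1e9, scores), axis=-1)

print(wrong.sum(axis=-1))  # rows 0..T-2 sum to < 1; the last row has nothing masked
print(right.sum(axis=-1))  # every row sums to 1
```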

Problem

Write mha_causal(seed, x, num_heads, d_model):

  1. Define CausalMHA(nnx.Module) with the same four projections as the previous problem.
  2. In __call__: compute scaled scores, build causal = jnp.tril(jnp.ones((T, T))), fill scores with -1e9 where causal == 0, then softmax.
  3. Return the output flattened.

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model: ints (passed as floats).

Output: 1-D flattened.

