
Causal Multi-Head Self-Attention

Why this matters

Decoders (GPT, LLaMA, and autoregressive language models in general) must not let position t look at positions > t. Without a mask, training would leak future tokens into the prediction at every position: the model would learn to “predict” the next token by reading it straight from its input. That shortcut is useless at inference time, when the future tokens do not exist yet.

The fix is a causal mask: a lower-triangular matrix that blocks the upper-right (strictly above-diagonal) entries of the score matrix, so their attention weights become 0 after the softmax. Position i only attends to positions j ≤ i.
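For a single head with query, key and value vectors q_i, k_j, v_j of dimension d, the masking rule can be written as below. This assumes the standard 1/√d scaling of scaled dot-product attention, which the text does not spell out:

```latex
M_{ij} = \begin{cases} 1, & j \le i \\ 0, & j > i \end{cases}
\qquad
s_{ij} = \begin{cases} \dfrac{q_i \cdot k_j}{\sqrt{d}}, & M_{ij} = 1 \\ -\infty, & M_{ij} = 0 \end{cases}
\qquad
a_{ij} = \frac{\exp(s_{ij})}{\sum_{j'=0}^{i} \exp(s_{ij'})},
\qquad
o_i = \sum_{j=0}^{i} a_{ij}\, v_j
```

Because exp(−∞) = 0, every a_ij with j > i vanishes and each row of weights still sums to 1 over positions 0..i.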

Building the mask

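import jax.numpy as jnp

T = x.shape[0]                      # sequence length; x has shape (T, D_in)
mask = jnp.tril(jnp.ones((T, T)))   # 1.0 on and below the diagonal, 0.0 above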
# T=4 looks like:
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]

Row i says that position i may attend to the columns where the mask is 1 (positions 0..i) and is blocked from the columns where it is 0.
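The same pattern can also be built directly as a boolean mask by comparing position indices, which makes the "j ≤ i" rule explicit. A small sketch, assuming T = 4 for illustration:

```python
import jax.numpy as jnp

T = 4
positions = jnp.arange(T)

# Boolean form: entry (i, j) is True exactly when key position j <= query position i.
bool_mask = positions[None, :] <= positions[:, None]   # shape (T, T), dtype bool

# Float form, built the same way as in the snippet above.
float_mask = jnp.tril(jnp.ones((T, T)))

# Both encode the same pattern: allowed on and below the diagonal.
same_pattern = bool(jnp.array_equal(bool_mask, float_mask == 1))
```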

How Flax interprets the mask

nn.MultiHeadDotProductAttention accepts a mask= kwarg. Internally (in nn.dot_product_attention_weights) it uses jnp.where to replace the scores at masked-out positions with the most negative finite value of the score dtype, BEFORE the softmax, so those positions end up with a softmax weight of 0.

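Flax combines the mask with the attention scores, which have shape (batch..., H, T_q, T_k), by ordinary broadcasting. Common mask shapes: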

  • (T, T) — broadcast over batch and heads.
  • (B, 1, T_q, T_k) — per-batch, broadcast over heads.
  • (B, H, T_q, T_k) — fully specified.

For this problem, where x has no batch axis, the simple (T, T) form is enough.

Worked walk-through

Without a mask, position 0 attends to positions 0, 1, 2, 3, including positions that do not exist yet at inference time: when we predict the token that follows token 0, token 1 has not been generated.

With the lower-triangular mask, position 0’s softmax row becomes [1.0, 0, 0, 0] (it can only see itself), position 1’s becomes [w₀, w₁, 0, 0] with w₀ + w₁ = 1, and so on.

Why this is also the inference-time mask

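During autoregressive generation you generate one token, append it, and run the model again. Whether you re-process the whole prefix at each step or use a KV cache, the causal mask guarantees that each position’s representation never depends on later tokens. The representations computed in one parallel training pass therefore match what step-by-step generation produces, and cached keys and values for earlier positions stay valid after new tokens are appended.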

Common pitfalls

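  • Inverted mask (triu instead of tril): lets each position attend to itself and the future while blocking the past, the opposite of what you want.
  • Wrong shape: a per-batch mask of shape (B, T, T) is right-aligned against scores of shape (B, H, T_q, T_k), so B lines up with the head axis. That is a shape error when B and H differ (and neither is 1) and a silently wrong mask when B = H. Use (B, 1, T_q, T_k) instead; a plain (T, T) or (1, T, T) mask broadcasts correctly.
  • Boolean mask vs float mask: Flax accepts both. With a float mask, 1.0 means “allowed” and 0.0 means “blocked” (any nonzero value counts as allowed).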

Problem

Implement mha_causal(seed, x, num_heads, qkv_features):

  1. T = x.shape[0]. Build mask = jnp.tril(jnp.ones((T, T))).
  2. Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D).
  3. Init with attn.init(rng, x, mask=mask).
  4. Apply with attn.apply(params, x, mask=mask).
  5. Return out.reshape(-1).

The mask is the only difference from pos 22 — same shapes, same call pattern.

Inputs:

  • seed: int.
  • x: 2-D (T, D_in).
  • num_heads: int H.
  • qkv_features: int D.

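Output: 1-D, the flattened attention output, of length T·D_in (by default the output projection maps back to the input feature size D_in, not to D).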
