Causal Multi-Head Self-Attention
Why this matters
Decoders (GPT, LLaMA, and autoregressive LMs in general) MUST NOT let position t
attend to positions > t. Without a mask, training would leak future
tokens into the prediction at every position: the model would learn
to “predict” the token it’s literally being given. That is useless at inference
time, when the future token doesn’t exist yet.
The fix is a causal mask: a lower-triangular matrix that blocks every
score above the diagonal (the upper-right triangle of the score matrix)
before the softmax, so those attention weights become zero. Position i
only attends to positions j ≤ i.
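To see the leak concretely, here is a minimal pure-JAX sketch. `single_head_attention` is a hypothetical toy, not a Flax API, and it assumes identity Q/K/V projections so the example stays short. Changing only the last token alters position 0's output when there is no mask, but not when the causal mask is applied.

```python
import jax
import jax.numpy as jnp

def single_head_attention(x, causal):
    """Toy single-head self-attention. Assumption for illustration: Q = K = V = x
    (real layers learn separate projections)."""
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)                      # (T, T) raw scores
    if causal:
        mask = jnp.tril(jnp.ones((T, T), dtype=bool))   # True = allowed
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)           # each row sums to 1
    return weights @ x                                  # (T, d)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
x_future_changed = x.at[3].set(x[3] + 10.0)             # edit only the last token

for causal in (False, True):
    a = single_head_attention(x, causal)
    b = single_head_attention(x_future_changed, causal)
    # Is position 0's output unaffected by a change to a *future* token?
    print("causal" if causal else "no mask", bool(jnp.allclose(a[0], b[0])))
```

Only the causal run reports position 0 as unchanged; without the mask, information from token 3 flows into position 0.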
Building the mask
import jax.numpy as jnp

T = x.shape[0]                      # x has shape (T, D_in)
mask = jnp.tril(jnp.ones((T, T)))   # 1.0 = may attend, 0.0 = blocked
# T=4 looks like:
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
Row i describes what position i may attend to: columns where the mask
is 1 (positions 0..i) are allowed, and columns where it is 0 are blocked.
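The row rule "column j is allowed iff j ≤ i" can also be written directly as an index comparison. This small sketch checks that it matches the `jnp.tril` construction; the names `i`, `j`, `mask_float`, and `mask_bool` are illustrative only.

```python
import jax.numpy as jnp

T = 4
mask_float = jnp.tril(jnp.ones((T, T)))   # 1.0 = allowed, 0.0 = blocked
i = jnp.arange(T)[:, None]                # query (row) index
j = jnp.arange(T)[None, :]                # key (column) index
mask_bool = j <= i                        # the same rule as a comparison

print(jnp.array_equal(mask_float.astype(bool), mask_bool))  # True
print(mask_bool[2])                       # row 2: columns 0..2 allowed, column 3 blocked
```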
How Flax interprets the mask
nn.MultiHeadDotProductAttention accepts a mask= kwarg. Internally,
recent Flax versions replace every score whose mask entry is 0/False with
a huge negative number (jnp.finfo(dtype).min) BEFORE the softmax,
so masked-out positions get -inf-ish scores → softmax weight ≈ 0.
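The masking step can be mimicked in a few lines of plain JAX. This is a sketch of the behavior described above, not Flax's own code; `masked_softmax` is a hypothetical helper. Note that `jnp.where` treats any nonzero float as True, which is why a float 1.0/0.0 mask works as well as a boolean one.

```python
import jax
import jax.numpy as jnp

def masked_softmax(scores, mask):
    """Sketch of mask handling: blocked scores are replaced with the most
    negative finite value of the dtype, then a normal softmax is applied."""
    big_neg = jnp.finfo(scores.dtype).min
    scores = jnp.where(mask, scores, big_neg)   # nonzero mask entry = allowed
    return jax.nn.softmax(scores, axis=-1)

T = 4
scores = jax.random.normal(jax.random.PRNGKey(1), (T, T))
weights = masked_softmax(scores, jnp.tril(jnp.ones((T, T))))
print(weights.sum(axis=-1))             # every row still sums to 1
print(jnp.triu(weights, k=1).max())     # every weight above the diagonal is 0
```

Using the dtype's minimum rather than a literal `-inf` avoids `NaN`s if a row were ever fully masked, since `exp(-inf - (-inf))` is undefined.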
Mask shapes Flax accepts:
- (T, T): broadcast over batch and heads.
- (B, 1, T_q, T_k): per-batch, broadcast over heads.
- (B, H, T_q, T_k): fully specified.
For this problem the simple (T, T) form is enough.
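The shape rules follow NumPy-style broadcasting against the (B, H, T_q, T_k) score tensor. The sketch below checks the shapes with `np.broadcast_shapes` and builds a (B, 1, T, T) mask by combining the causal mask with a per-example padding mask; the `lengths` values are hypothetical and only serve to make the mask differ per batch element.

```python
import numpy as np
import jax.numpy as jnp

B, H, T = 2, 4, 5
scores_shape = (B, H, T, T)                    # attention-score shape for batched input
causal = jnp.tril(jnp.ones((T, T), dtype=bool))

# (T, T): one mask shared by every batch element and every head.
print(np.broadcast_shapes(causal.shape, scores_shape))   # (2, 4, 5, 5)

# (B, 1, T, T): causal AND a per-example padding mask, shared across heads.
lengths = jnp.array([5, 3])                    # hypothetical valid lengths per example
pad = jnp.arange(T)[None, :] < lengths[:, None]          # (B, T_k): True = real token
per_batch = causal[None, None] & pad[:, None, None, :]   # (B, 1, T, T)
print(per_batch.shape, np.broadcast_shapes(per_batch.shape, scores_shape))
```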
Worked walk-through
Without a mask, position 0 attends to positions 0, 1, 2, 3, including positions that don’t exist yet at inference time (when position 0 is used to predict token 1, token 1 hasn’t been generated).
With the lower-triangular mask, position 0’s softmax row becomes
[1.0, 0, 0, 0] (it can only see itself), position 1’s becomes
[w₀, w₁, 0, 0] with w₀ + w₁ = 1, and so on: row i spreads all of its
weight over positions 0..i.
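To make the walk-through concrete, assume (purely for readability) that all raw scores are equal. Each row's softmax then spreads its weight uniformly over the allowed positions, so row i gets 1/(i+1) on positions 0..i and 0 elsewhere.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jnp.zeros((T, T))                # assumption: every raw score is equal
mask = jnp.tril(jnp.ones((T, T)))
weights = jax.nn.softmax(
    jnp.where(mask, scores, jnp.finfo(scores.dtype).min), axis=-1)
print(weights)
# Values by row:
#   row 0 -> [1,   0,   0,   0  ]
#   row 1 -> [1/2, 1/2, 0,   0  ]
#   row 2 -> [1/3, 1/3, 1/3, 0  ]
#   row 3 -> [1/4, 1/4, 1/4, 1/4]
```

With real (unequal) scores the allowed entries are no longer uniform, but the zero pattern above the diagonal is identical.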
Why this is also the inference-time mask
During autoregressive generation, you generate one token, append it, and run the model again. Whether you re-process the whole prefix each step or reuse a KV cache, the causal mask guarantees that each position’s representation never depends on later tokens, so outputs for earlier positions don’t change as new tokens are appended. Training with the mask therefore matches what the model computes one step at a time during generation.
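The equivalence can be checked with a toy model. This sketch assumes identity Q/K/V projections and a single head; `causal_attention` and `decode_step` are hypothetical helpers. The full-sequence pass uses the causal mask, while the step-by-step pass needs no mask because its cache only ever holds past and current tokens. Both produce the same outputs.

```python
import jax
import jax.numpy as jnp

def causal_attention(x):
    """Full-sequence pass: toy single-head attention (Q = K = V = x, an
    assumption for illustration) with a causal mask."""
    T, d = x.shape
    scores = x @ x.T / jnp.sqrt(d)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ x

def decode_step(x_t, cache):
    """One generation step: the newest query attends to every cached key/value.
    No mask is needed because the cache holds only past and current tokens."""
    cache = jnp.concatenate([cache, x_t[None, :]], axis=0)  # append this token's K/V
    scores = cache @ x_t / jnp.sqrt(x_t.shape[0])           # (t+1,)
    return jax.nn.softmax(scores) @ cache, cache

x = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
full = causal_attention(x)

cache = jnp.zeros((0, x.shape[1]))
step_outputs = []
for t in range(x.shape[0]):
    out_t, cache = decode_step(x[t], cache)
    step_outputs.append(out_t)
incremental = jnp.stack(step_outputs)

print(jnp.allclose(full, incremental, atol=1e-5))   # True
```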
Common pitfalls
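- Inverted mask (triu instead of tril): allows attending to the future and blocks the past, the opposite of what you want.
- Wrong shape: Flax documents the mask as (batch..., num_heads, T_q, T_k) and broadcasts it against the attention weights from the right. A 3-D (B, T, T) mask therefore lines its first axis up with heads, not batch: if B ≠ H you get a shape error, and if B = H it broadcasts without complaint but applies example b’s mask to head b. Stick to 2-D (T, T) or 4-D (B, 1, T_q, T_k) / (B, H, T_q, T_k).
- Boolean mask vs float mask: Flax accepts both. With float, 1.0 means “allowed” and 0.0 means “blocked”.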
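The three pitfalls can be reproduced in a few lines. This sketch uses `np.broadcast_shapes` to model how a mask lines up with (B, H, T, T) scores; `mask_broadcast` is a hypothetical helper written for this example.

```python
import numpy as np
import jax.numpy as jnp

T = 4
good = jnp.tril(jnp.ones((T, T)))
inverted = jnp.triu(jnp.ones((T, T)))
print(good[0])        # position 0 sees only itself
print(inverted[0])    # inverted: position 0 sees every position, including the future

# Boolean and float masks mark the same entries (any nonzero float = allowed).
print(jnp.array_equal(good.astype(bool), jnp.tril(jnp.ones((T, T), dtype=bool))))  # True

def mask_broadcast(mask_shape, B, H):
    """Shape a mask would broadcast to against (B, H, T, T) scores,
    using NumPy's right-aligned broadcasting rules."""
    try:
        return np.broadcast_shapes(mask_shape, (B, H, T, T))
    except ValueError:
        return "shape error"

print(mask_broadcast((2, T, T), B=2, H=3))     # B != H: shape error
print(mask_broadcast((3, T, T), B=3, H=3))     # B == H: (3, 3, 4, 4), but mask[b] hits head b
print(mask_broadcast((2, 1, T, T), B=2, H=3))  # (2, 3, 4, 4): the intended layout
```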
Problem
Implement mha_causal(seed, x, num_heads, qkv_features):
- T = x.shape[0]. Build mask = jnp.tril(jnp.ones((T, T))).
- Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D).
- Init with attn.init(rng, x, mask=mask).
- Apply with attn.apply(params, x, mask=mask).
- Return out.reshape(-1).
The mask is the only difference from pos 22 — same shapes, same call pattern.
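Inputs:
- seed: int.
- x: 2-D (T, D_in).
- num_heads: int H.
- qkv_features: int D (must be divisible by H).
Output: 1-D, the flattened attention output, of length T · D_in (Flax’s output projection defaults to the input’s feature size D_in, not D).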