Multi-Head Self-Attention with Flax
Why this matters
Single-head attention asks one question per query: “which key is most
relevant?” Multi-head attention asks H questions in parallel — each
head learns to focus on a different relationship (syntactic, semantic,
positional, …). Concatenate the H per-head outputs, project, done.
Doing this by hand means: split D into H chunks of D/H, run scaled
dot-product attention (SDPA) on each chunk, concatenate. Flax packages
this as nn.MultiHeadDotProductAttention: you give it num_heads and
qkv_features (the TOTAL Q/K/V dim across all heads), and it does the slicing.
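The by-hand recipe above can be sketched in plain JAX. This is a minimal illustration, not Flax's implementation: it omits the learned Q/K/V and output projections, and the names sdpa and manual_multi_head are hypothetical helpers introduced only for this sketch.

```python
import jax
import jax.numpy as jnp


def sdpa(q, k, v):
    """Scaled dot-product attention for a single head; q, k, v are (T, d)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (T, T) similarity scores
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (T, d)


def manual_multi_head(q, k, v, num_heads):
    """Split the feature axis into num_heads chunks, attend per chunk, concatenate."""
    assert q.shape[-1] % num_heads == 0, "D must be divisible by H"
    q_heads = jnp.split(q, num_heads, axis=-1)  # H arrays of shape (T, D/H)
    k_heads = jnp.split(k, num_heads, axis=-1)
    v_heads = jnp.split(v, num_heads, axis=-1)
    outs = [sdpa(qh, kh, vh) for qh, kh, vh in zip(q_heads, k_heads, v_heads)]
    return jnp.concatenate(outs, axis=-1)       # back to (T, D)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # (T=4, D=8)
out = manual_multi_head(x, x, x, num_heads=2)
print(out.shape)
```

With num_heads=1 the function reduces to a single sdpa call, which is a quick sanity check on the splitting logic.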
The Flax API
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
- num_heads=H: number of parallel attention heads.
- qkv_features=D: total Q, K, V projection dim. Per-head dim = D / H. D MUST be divisible by H.
- out_features (optional): output dim after the final projection. Defaults to the input’s last dim.
Call as attn(x) for self-attention (Q, K, V all from x) or
attn(x_q, x_kv) for cross-attention.
Internally, Flax reshapes Q/K/V to (..., T, H, D/H), runs SDPA per
head, concatenates, and applies a final linear projection back to
out_features.
What does H control?
With D=8, H=2: each head sees 4-dim Q/K/V. Two heads can specialize.
With D=8, H=4: each head sees 2-dim Q/K/V. Four narrower specialists.
With D=8, H=8: per-head dim is 1, so every attention score is just a product of two scalars. Basically degenerate.
Real Transformers usually pick D/H ∈ {32, 64, 128} per head. GPT-2
base uses D=768, H=12 (per-head 64).
Worked example
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 8)) # (T=4, D=8)
attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=8)
params = attn.init(jax.random.PRNGKey(0), x) # builds the params dict
out = attn.apply(params, x) # (4, 8)
Note: 2-D input (T, D) works. Flax treats any leading axes as batch
axes, so an input with no batch axis is handled as a single sequence.
3-D (B, T, D) is the more common production shape.
Common pitfalls
- D not divisible by H: Flax raises. Check before constructing.
- Forgetting to init: attn(x) doesn’t work outside init/apply. Flax modules are stateless; params live in the params dict.
- Wrong qkv_features: this is the TOTAL dim across all heads, not per-head. Per-head dim is implicit (qkv_features // num_heads).
Problem
Implement mha_self_attention(seed, x, num_heads, qkv_features):
- Cast num_heads and qkv_features to int.
- Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D).
- Init with jax.random.PRNGKey(seed) and apply on x.
- Return the output flattened to 1-D via .reshape(-1).
Inputs:
- seed: int.
- x: 2-D (T, D_in) input.
- num_heads: int H.
- qkv_features: int D, divisible by H.
Output: 1-D array, the flattened (T, D_in) attention output.