Multi-Head Self-Attention with Flax

Why this matters

Single-head attention asks one question per query: “which key is most relevant?” Multi-head attention asks H questions in parallel. Each head has its own Q/K/V projections, so it can learn to focus on a different relationship (syntactic, semantic, positional, …). Concatenate the H per-head outputs, project, done.

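Doing this by hand means: project the input to Q, K and V (each of total width D), split each into H chunks of D/H, run scaled dot-product attention (SDPA) on each chunk, concatenate the results, and apply an output projection. Flax packages this as nn.MultiHeadDotProductAttention: you give it num_heads and qkv_features (the TOTAL Q/K/V dim across all heads), and it does the splitting for you.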
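The by-hand recipe can be sketched in plain JAX as follows. This is a minimal illustration, not Flax's code: manual_mha, sdpa and the weight matrices wq, wk, wv, wo are hypothetical names, the weights are random stand-ins for learned parameters, and biases are omitted.

```python
import jax
import jax.numpy as jnp


def sdpa(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are (T, d)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (T, T)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (T, d)


def manual_mha(x, wq, wk, wv, wo, num_heads):
    """Hypothetical hand-rolled multi-head self-attention.

    x: (T, D_in); wq, wk, wv: (D_in, D); wo: (D, D_out).
    """
    q, k, v = x @ wq, x @ wk, x @ wv           # each (T, D)
    d_model = q.shape[-1]
    assert d_model % num_heads == 0, "D must be divisible by H"
    head_dim = d_model // num_heads
    heads = []
    for h in range(num_heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)  # h-th chunk of D/H columns
        heads.append(sdpa(q[:, cols], k[:, cols], v[:, cols]))
    concat = jnp.concatenate(heads, axis=-1)   # (T, D)
    return concat @ wo                         # (T, D_out)


kx, kq, kk, kv, ko = jax.random.split(jax.random.PRNGKey(0), 5)
T, D_in, D, H = 4, 8, 8, 2
x = jax.random.normal(kx, (T, D_in))
wq = jax.random.normal(kq, (D_in, D))
wk = jax.random.normal(kk, (D_in, D))
wv = jax.random.normal(kv, (D_in, D))
wo = jax.random.normal(ko, (D, D_in))
out = manual_mha(x, wq, wk, wv, wo, num_heads=H)
print(out.shape)  # (4, 8)
```

Taking column block h is the same as reshaping (T, D) to (T, H, D/H) and indexing head h, which is how libraries run all heads in one batched einsum instead of a Python loop.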

The Flax API

attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
  • num_heads=H: number of parallel attention heads.
  • qkv_features=D: total Q, K, V projection dim across all heads. Per-head dim = D / H, so D MUST be divisible by H. If omitted, it defaults to the query input’s last dim.
  • out_features (optional): output dim after the final projection. Defaults to the query input’s last dim.

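Call as attn(x) for self-attention (Q, K, V all from x) or attn(x_q, x_kv) for cross-attention (Q from x_q; K and V from x_kv). The one-argument form needs a recent Flax release; older versions required the key/value input and provided nn.SelfAttention for the self-attention case.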

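Internally, Flax projects Q/K/V with DenseGeneral layers straight to shape (..., T, H, D/H), runs SDPA per head (scaling scores by 1/sqrt(D/H)), and applies a final DenseGeneral that contracts the (H, D/H) axes back to out_features. That last step is equivalent to concatenating the heads and applying a linear projection.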

What does H control?

  • D=8, H=2: each head sees 4-dim Q/K/V. Two heads can specialize.
  • D=8, H=4: each head sees 2-dim Q/K/V. Four narrower specialists.
  • D=8, H=8: the per-head dim is 1, so every attention score is a product of two scalars. This is essentially degenerate.
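One way to see why per-head dim 1 is degenerate: each score is q_i * k_j, so a query can only reorder the keys by the sign of q_i. Every query with q_i > 0 prefers the key with the largest k_j, and every query with q_i < 0 prefers the key with the smallest k_j. The sketch below checks this with random q and k for a single hypothetical head (the 1/sqrt(D/H) scale equals 1 here, so it is omitted).

```python
import jax
import jax.numpy as jnp

# One head with per-head dim 1: every query and key is a single number.
T = 6
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (T, 1))
k = jax.random.normal(kk, (T, 1))

scores = q @ k.T                          # (T, T); scores[i, j] = q[i] * k[j]
weights = jax.nn.softmax(scores, axis=-1)
preferred = jnp.argmax(weights, axis=-1)  # key each query attends to most

# Queries with q[i] > 0 all pick the largest k; queries with q[i] < 0 all pick the smallest.
predicted = jnp.where(q[:, 0] > 0, jnp.argmax(k[:, 0]), jnp.argmin(k[:, 0]))
print(preferred, predicted)  # identical arrays
```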

Real Transformers usually pick a per-head dim D/H of 32, 64 or 128. GPT-2 small, for example, uses D=768, H=12 (per-head 64).

Worked example

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 8))                                # (T=4, D=8)
attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=8)
params = attn.init(jax.random.PRNGKey(0), x)        # variables dict {"params": ...}
out = attn.apply(params, x)                         # (4, 8)

Note: 2-D input (T, D) works because Flax treats every axis before the last two (length, features) as a batch axis, so a 2-D input simply has no batch axis. 3-D (B, T, D) is the more common production shape.

Common pitfalls

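  • D not divisible by H: Flax does not check this when you construct the module; it raises an error when the module is first called (typically in init). Check D % H == 0 yourself before constructing.
  • Forgetting to init: calling attn(x) directly on an unbound module raises an error. Flax Linen modules are stateless; the params live in the variables dict that init returns and apply consumes.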
  • Wrong qkv_features: this is the TOTAL dim across all heads, not per-head. Per-head dim is implicit (qkv_features // num_heads).

Problem

Implement mha_self_attention(seed, x, num_heads, qkv_features):

  1. Cast num_heads and qkv_features to int.
  2. Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D).
  3. Init with jax.random.PRNGKey(seed) and apply on x.
  4. Return the output flattened to 1-D via .reshape(-1).

Inputs:

  • seed: int.
  • x: 2-D (T, D_in) input.
  • num_heads: int H.
  • qkv_features: int D, divisible by H.

Output: 1-D array of length T * D_in: the (T, D_in) attention output, flattened.
