NNX MHA From Scratch

Why this matters

Multi-head attention is the workhorse of every Transformer. Linen provides nn.MultiHeadDotProductAttention as a one-liner: convenient, but opaque. nnx ships a ready-made nnx.MultiHeadAttention too, but in this problem you skip it and build the layer yourself out of four nnx.Linear projections and the SDPA you wrote in the previous problem.

The result reads like the textbook formula. No qkv_features argument to remember, no apply-time mask keyword, no params dict: the four projection layers are just attributes of an nnx.Module.

The pieces

Multi-head attention with H heads on hidden dim d_model runs H independent SDPA computations in parallel, each on a head_dim = d_model // H slice of the projected Q, K, V. The four learnable maps:

  • q_proj: d_model -> d_model, projects the input into the query subspace.
  • k_proj: d_model -> d_model, projects the input into the key subspace.
  • v_proj: d_model -> d_model, projects the input into the value subspace.
  • out_proj: d_model -> d_model, mixes the concatenated heads.

Each is a plain nnx.Linear(d_model, d_model, rngs=rngs).

Reshape choreography

After projection, Q/K/V have shape (T, d_model). To split across heads:

(T, d_model)
    .reshape(T, H, head_dim)              # carve out heads
    .transpose(1, 0, 2)                   # heads first: (H, T, head_dim)
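To see which features each head receives, here is a small jax.numpy sketch; the values in the stand-in Q are chosen only so that every entry is traceable (entry [t, c] = 100 * t + c).

```python
import jax.numpy as jnp

T, H, Dh = 3, 2, 4
d_model = H * Dh
# Stand-in for a projected Q: entry [t, c] = 100 * t + c.
q = jnp.arange(T)[:, None] * 100 + jnp.arange(d_model)[None, :]

q_heads = q.reshape(T, H, Dh).transpose(1, 0, 2)  # (H, T, Dh)

# Head h owns the contiguous feature columns h*Dh : (h+1)*Dh of every token.
for h in range(H):
    assert jnp.array_equal(q_heads[h], q[:, h * Dh:(h + 1) * Dh])
```

So the reshape assigns contiguous blocks of features to heads, and the transpose only reorders axes; no values are mixed.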

Now Q is (H, T, head_dim): the leading axis is heads, so each head has its own (T, head_dim) slice. SDPA per head:

scores  = Q @ K.transpose(0, 2, 1) / sqrt(head_dim)   # (H, T, T)
weights = softmax(scores, axis=-1)                     # (H, T, T)
per_head = weights @ V                                 # (H, T, head_dim)
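The batched formulas above are the same as running SDPA once per head. A sketch checking that on random inputs (shapes and PRNG seed are arbitrary):

```python
import jax
import jax.numpy as jnp

H, T, Dh = 2, 5, 4
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
Q = jax.random.normal(kq, (H, T, Dh))
K = jax.random.normal(kk, (H, T, Dh))
V = jax.random.normal(kv, (H, T, Dh))

# Batched over the leading head axis, as in the formulas above.
scores = Q @ K.transpose(0, 2, 1) / jnp.sqrt(Dh)   # (H, T, T)
weights = jax.nn.softmax(scores, axis=-1)           # (H, T, T)
per_head = weights @ V                              # (H, T, Dh)

# The same computation one head at a time: each head is a plain 2-D SDPA.
def sdpa(q, k, v):
    w = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return w @ v

looped = jnp.stack([sdpa(Q[h], K[h], V[h]) for h in range(H)])
```

Here `@` on 3-D arrays treats the leading axis as a batch axis, which is why putting heads first matters.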

Then back: (H, T, head_dim) -> (T, H, head_dim) -> (T, H * head_dim) = (T, d_model). Final out_proj mixes head outputs.
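A round-trip sketch in jax.numpy: splitting into heads and merging back recovers the original layout only when the transpose is undone before the reshape.

```python
import jax.numpy as jnp

T, H, Dh = 3, 2, 4
x = jnp.arange(T * H * Dh, dtype=jnp.float32).reshape(T, H * Dh)

split = x.reshape(T, H, Dh).transpose(1, 0, 2)        # (H, T, Dh)
merged = split.transpose(1, 0, 2).reshape(T, H * Dh)  # back to (T, d_model)

# Reshaping without transposing back gives the right shape but scrambled rows.
wrong = split.reshape(T, H * Dh)
```

The shape of `wrong` is still (T, d_model), so this bug does not raise an error; it only shows up as bad outputs.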

Worked sketch

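import jax
import jax.numpy as jnp
from flax import nnx


class MHA(nnx.Module):
    def __init__(self, d_model: int, num_heads: int, rngs: nnx.Rngs):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads               # plain int attribute, not a parameter
        self.head_dim = d_model // num_heads
        # Four independent d_model -> d_model projections.
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x):                       # x: (T, d_model)
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project, carve out heads, move heads first: (H, T, Dh).
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        # Scaled dot-product attention, batched over heads.
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)                     # over keys
        per_head = jnp.matmul(weights, v)                             # (H, T, Dh)
        # Heads back to the middle, then concatenate: (T, d_model).
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)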

Compare to Linen, where the Q/K/V projections are DenseGeneral layers with multi-axis kernels of shape (d_model, H, head_dim), so the head split is folded into the kernel and the reshape is implicit. Doing it by hand here is a pedagogical win: you see exactly where each axis goes.
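Why the two styles agree: a (d_model, d_model) kernel followed by a reshape is the same linear map as a (d_model, H, head_dim) kernel applied with an einsum. This jax.numpy sketch mimics the multi-axis-kernel idea directly rather than calling Linen; the random weights stand in for a trained kernel.

```python
import jax
import jax.numpy as jnp

T, d_model, H = 4, 8, 2
Dh = d_model // H
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (T, d_model))
W = jax.random.normal(kw, (d_model, d_model))  # 2-D kernel, like nnx.Linear's

# Explicit route (this problem): project, then carve out heads.
explicit = (x @ W).reshape(T, H, Dh)

# Multi-axis route: fold the head split into the kernel itself.
W3 = W.reshape(d_model, H, Dh)                 # (d_model, H, Dh)
implicit = jnp.einsum("td,dhk->thk", x, W3)    # (T, H, Dh)
```

Both routes read the same weights in the same order, so they differ only in float rounding.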

Why softmax over the last axis?

scores is (H, T_q, T_k). We want, per head and per query, a distribution over keys. Last axis = T_k, so axis=-1 is correct. axis=0 would normalize across heads (meaningless); axis=1 would normalize across queries (also meaningless).
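A small check of the axis choice on random scores (shapes and seed are arbitrary): with axis=-1, every (head, query) row sums to one; with axis=0, the sums run over heads instead.

```python
import jax
import jax.numpy as jnp

H, T = 2, 3
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T, T))  # (H, T_q, T_k)
weights = jax.nn.softmax(scores, axis=-1)

# For every head h and query i, weights[h, i, :] is a distribution over keys.
row_sums = weights.sum(axis=-1)  # (H, T_q)

# Normalizing over axis=0 makes the heads, not the keys, sum to one.
wrong = jax.nn.softmax(scores, axis=0)
```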

Common pitfalls

  • d_model not divisible by num_heads. Assert it; head_dim is an integer division.
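  • Forgetting to transpose before SDPA. After the reshape to (T, H, Dh), a batched matmul treats T as the batch axis and attends across heads, giving (T, H, H) scores instead of (H, T, T). Move heads to the front first via transpose(1, 0, 2).
  • Forgetting to transpose K's last two axes. K.transpose(0, 2, 1) swaps T and Dh, so each head computes (T, Dh) @ (Dh, T) = (T, T). Note that K.T on a 3-D array reverses all three axes, so spell the permutation out.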
  • Skipping out_proj. Multi-head outputs need a final mix; without it, heads are independent silos.
  • num_heads, d_model arriving as float. Cast to int in the entry function.
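The two transpose pitfalls are easiest to spot by their shapes. A jax.numpy sketch with dummy all-ones tensors (only the shapes matter here):

```python
import jax.numpy as jnp

T, H, Dh = 5, 2, 4
q = jnp.ones((T, H, Dh))  # reshaped but NOT transposed
k = jnp.ones((T, H, Dh))

# Pitfall 1: without moving heads first, matmul batches over T and mixes heads.
wrong_scores = q @ k.transpose(0, 2, 1)          # (T, H, H)

# Pitfall 2: .T on a 3-D array reverses ALL axes instead of swapping the last two.
k_heads = k.transpose(1, 0, 2)                   # (H, T, Dh)
reversed_shape = k_heads.T.shape                 # (Dh, T, H)

# Correct: heads first, then swap only the last two axes of K.
q_heads = q.transpose(1, 0, 2)                   # (H, T, Dh)
right_scores = q_heads @ k_heads.transpose(0, 2, 1)  # (H, T, T)
```

When T happens to equal H, the wrong version even has the right shape, so checking shapes alone is not a complete test.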

Problem

Write mha_self(seed, x, num_heads, d_model):

  1. Define MHA(nnx.Module) with four nnx.Linear(d_model, d_model) projections and num_heads, head_dim as plain int attrs.
  2. __call__(x):
    • Project Q/K/V, reshape to (T, H, Dh), transpose to (H, T, Dh).
    • SDPA per head (matmul, scale, softmax, matmul).
    • Transpose back, reshape to (T, d_model), out_proj.
  3. Cast num_heads, d_model from float to int. Build nnx.Rngs(int(seed)).
  4. Return the output flattened: out.reshape(-1).

Inputs:

  • seed: int (passed as float).
  • x: 2-D (T, d_model).
  • num_heads, d_model: ints (passed as floats).

Output: 1-D flattened (T * d_model,).
