NNX MHA From Scratch
Why this matters
Multi-head attention is the workhorse of every Transformer. Linen
provides nn.MultiHeadDotProductAttention as a one-liner: convenient,
but opaque. This problem takes the opposite route. Instead of reaching for
nnx's built-in nnx.MultiHeadAttention, you build the layer yourself out of
four nnx.Linear projections and the SDPA you wrote in the previous problem.
The result reads like the textbook formula. No qkv_features argument
to remember, no apply-time mask keyword, no params dict: the four
projection layers are just attributes of an nnx.Module.
The pieces
Multi-head attention with H heads on hidden dim d_model runs H
independent SDPA computations in parallel, each on a head_dim = d_model / H
slice of the projected Q, K, V. The four learnable maps:
- q_proj: d_model -> d_model, projects the input into the query subspace.
- k_proj: d_model -> d_model, projects the input into the key subspace.
- v_proj: d_model -> d_model, projects the input into the value subspace.
- out_proj: d_model -> d_model, mixes the concatenated heads.
Each is a plain nnx.Linear(d_model, d_model, rngs=rngs).
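To make the four maps concrete, here is a framework-free NumPy sketch that treats each projection as an affine map with a (d_in, d_out) kernel and a (d_out,) bias, the shapes nnx.Linear uses with its default use_bias=True. The sizes and the helper names make_linear/apply_linear are hypothetical, and the random initialization is a simplification, not nnx's default initializer.

```python
import numpy as np

# Hypothetical sizes for illustration only.
T, d_model = 5, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))

def make_linear(rng, d_in, d_out):
    # Stand-in for nnx.Linear: kernel (d_in, d_out) plus bias (d_out,).
    # Scaled-normal init is an assumption, not nnx's initializer.
    return {"kernel": rng.standard_normal((d_in, d_out)) / np.sqrt(d_in),
            "bias": np.zeros(d_out)}

def apply_linear(p, x):
    # Affine map applied row-wise: (T, d_in) -> (T, d_out).
    return x @ p["kernel"] + p["bias"]

projs = {name: make_linear(rng, d_model, d_model)
         for name in ("q_proj", "k_proj", "v_proj", "out_proj")}

q = apply_linear(projs["q_proj"], x)
print(q.shape)  # (5, 8)

# Each projection holds d_model*d_model + d_model parameters.
n_params = sum(p["kernel"].size + p["bias"].size for p in projs.values())
print(n_params)  # 4 * (8*8 + 8) = 288
```

The parameter count does not depend on num_heads: splitting into heads only reshapes the projected activations, it adds no weights.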
Reshape choreography
After projection, Q/K/V have shape (T, d_model). To split across heads:
(T, d_model)
.reshape(T, H, head_dim) # carve out heads
.transpose(1, 0, 2) # heads first: (H, T, head_dim)
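A quick check of what the split means, sketched in NumPy (its reshape and transpose behave like jax.numpy's here; the sizes are arbitrary): after the reshape and transpose, head h sees exactly the contiguous column block h*head_dim:(h+1)*head_dim of the projected input.

```python
import numpy as np

T, d_model, H = 4, 6, 3
head_dim = d_model // H  # 2

x = np.arange(T * d_model).reshape(T, d_model)      # stands in for q_proj(x)
q = x.reshape(T, H, head_dim).transpose(1, 0, 2)    # (H, T, head_dim)
print(q.shape)  # (3, 4, 2)

# Head h is the column slice [h*head_dim, (h+1)*head_dim) of every row.
for h in range(H):
    assert np.array_equal(q[h], x[:, h * head_dim:(h + 1) * head_dim])
```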
Now Q is (H, T, head_dim): the leading axis is heads, so each head
has its own (T, head_dim) slice. SDPA per head:
scores = Q @ K.transpose(0, 2, 1) / sqrt(head_dim) # (H, T, T)
weights = softmax(scores, axis=-1) # (H, T, T)
per_head = weights @ V # (H, T, head_dim)
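The batched form above is equivalent to running single-head SDPA once per head, because matmul treats the leading H axis as a batch dimension. A NumPy sketch comparing the two, with a hand-written numerically stable softmax standing in for jax.nn.softmax (an assumption of this sketch, not the document's code):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # stabilize before exponentiating
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

H, T, Dh = 2, 5, 4
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((H, T, Dh)) for _ in range(3))

# Batched over heads in one shot.
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Dh)  # (H, T, T)
weights = softmax(scores, axis=-1)
per_head = weights @ V                             # (H, T, Dh)

# Reference: a separate single-head SDPA for each head.
looped = np.stack([
    softmax(Q[h] @ K[h].T / np.sqrt(Dh)) @ V[h] for h in range(H)
])
assert np.allclose(per_head, looped)
```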
Then back: (H, T, head_dim) -> (T, H, head_dim) -> (T, H * head_dim) = (T, d_model). Final out_proj mixes head outputs.
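The merge is the exact inverse of the split, provided the transpose comes before the reshape. A small NumPy sketch (arbitrary sizes) that checks the round trip and shows what goes wrong if the transpose is skipped:

```python
import numpy as np

T, H, Dh = 4, 3, 2
x = np.arange(T * H * Dh, dtype=float).reshape(T, H * Dh)

split = x.reshape(T, H, Dh).transpose(1, 0, 2)        # (H, T, Dh)
merged = split.transpose(1, 0, 2).reshape(T, H * Dh)  # back to (T, d_model)
assert np.array_equal(merged, x)

# Reshaping (H, T, Dh) directly, without the transpose, mixes tokens and heads.
wrong = split.reshape(T, H * Dh)
print(np.array_equal(wrong, x))  # False
```

The wrong version still has shape (T, d_model), so nothing raises: row 0 ends up holding token 0, 1 and 2's slices of head 0 instead of all heads of token 0.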
Worked sketch
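import jax
import jax.numpy as jnp
from flax import nnx

class MHA(nnx.Module):
    def __init__(self, d_model: int, num_heads: int, rngs: nnx.Rngs):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads              # plain int attribute
        self.head_dim = d_model // num_heads    # plain int attribute
        # Four independent d_model -> d_model projections.
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x):
        T, _ = x.shape
        H, Dh = self.num_heads, self.head_dim
        # Project, carve out heads, move heads first: (T, d_model) -> (H, T, Dh)
        q = self.q_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x).reshape(T, H, Dh).transpose(1, 0, 2)
        # Scaled dot-product attention, batched over the head axis.
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, T, T)
        weights = jax.nn.softmax(scores, axis=-1)                     # over keys
        per_head = jnp.matmul(weights, v)                             # (H, T, Dh)
        # Heads back to (T, H, Dh), merge to (T, d_model), then mix the heads.
        concat = per_head.transpose(1, 0, 2).reshape(T, H * Dh)
        return self.out_proj(concat)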
Compare to Linen, where the Q/K/V projections are DenseGeneral layers
with multi-axis kernels of shape (d_model, H, head_dim), so the head split
happens inside the projection and the reshape is implicit.
Doing it by hand here is a pedagogical win: you see exactly where
each axis goes.
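Why the two styles agree: a (d_model, d_model) kernel followed by a reshape computes the same thing as a (d_model, H, head_dim) kernel applied with an einsum. The NumPy sketch below illustrates that equivalence with arbitrary sizes; it is not Linen's actual code.

```python
import numpy as np

T, d_model, H = 4, 6, 3
Dh = d_model // H
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d_model))
W = rng.standard_normal((d_model, d_model))  # Linear-style kernel (d_in, d_out)

# Explicit route (this document): project, then carve out heads.
explicit = (x @ W).reshape(T, H, Dh)

# Multi-axis-kernel route: the kernel already carries a head axis,
# so the projection lands directly in (T, H, Dh).
W_multi = W.reshape(d_model, H, Dh)
implicit = np.einsum("td,dhk->thk", x, W_multi)

assert np.allclose(explicit, implicit)
```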
Why softmax over the last axis?
scores is (H, T_q, T_k). We want, per head and per query, a
distribution over keys. Last axis = T_k, so axis=-1 is correct.
axis=0 would normalize across heads (meaningless); axis=1 would
normalize across queries (also meaningless).
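A NumPy check of the axis choice, with a hand-written softmax standing in for jax.nn.softmax (an assumption of this sketch): with axis=-1 every (head, query) row sums to 1, while axis=1 makes each key column sum to 1 instead.

```python
import numpy as np

def softmax(z, axis):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

H, T = 2, 3
rng = np.random.default_rng(0)
scores = rng.standard_normal((H, T, T))  # (H, T_q, T_k)

w = softmax(scores, axis=-1)
# Per head and per query: a distribution over keys.
assert np.allclose(w.sum(axis=-1), 1.0)

w_bad = softmax(scores, axis=1)
# Normalizing over queries: each key column sums to 1 instead of each row.
assert np.allclose(w_bad.sum(axis=1), 1.0)
```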
Common pitfalls
- d_model not divisible by num_heads. Assert it; head_dim is an integer division.
- Forgetting to transpose before SDPA. After the reshape to (T, H, Dh), the leading axis is tokens, so Q @ K.transpose(0, 2, 1) batches over tokens and yields (T, H, H) head-to-head scores instead of (H, T, T). Move heads to the front first via transpose(1, 0, 2).
- Forgetting to transpose K's last two axes. K.transpose(0, 2, 1) swaps T and Dh, so each head computes (T, Dh) @ (Dh, T) = (T, T).
- Skipping out_proj. Multi-head outputs need a final mix; without it, heads are independent silos.
- num_heads, d_model arriving as floats. Cast them to int in the entry function.
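The missing-transpose pitfall is easy to miss because nothing raises. A NumPy sketch with arbitrary sizes showing the silently wrong shape next to the intended one:

```python
import numpy as np

T, H, Dh = 5, 2, 4
rng = np.random.default_rng(0)
q = rng.standard_normal((T, H, Dh))  # reshaped but NOT transposed
k = rng.standard_normal((T, H, Dh))

# matmul batches over the leading axis, which is now T rather than H.
wrong = q @ k.transpose(0, 2, 1)
print(wrong.shape)  # (5, 2, 2): head-vs-head scores per token

qh, kh = q.transpose(1, 0, 2), k.transpose(1, 0, 2)  # (H, T, Dh)
right = qh @ kh.transpose(0, 2, 1)
print(right.shape)  # (2, 5, 5): token-vs-token scores per head

# Inputs arriving as floats: cast before using them as sizes.
num_heads, d_model = int(2.0), int(8.0)
assert d_model % num_heads == 0
```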
Problem
Write mha_self(seed, x, num_heads, d_model):
- Define MHA(nnx.Module) with four nnx.Linear(d_model, d_model) projections and num_heads, head_dim as plain int attributes.
- In __call__(x):
  - Project Q/K/V, reshape to (T, H, Dh), transpose to (H, T, Dh).
  - SDPA per head (matmul, scale, softmax, matmul).
  - Transpose back, reshape to (T, d_model), apply out_proj.
- Cast num_heads, d_model from float to int. Build nnx.Rngs(int(seed)).
- Return the output flattened: out.reshape(-1).
Inputs:
- seed: int (passed as float).
- x: 2-D, shape (T, d_model).
- num_heads, d_model: ints (passed as floats).
Output: 1-D flattened (T * d_model,).