NNX Scaled Dot-Product Attention
Why this matters
Scaled dot-product attention (SDPA) is the inner loop of every modern
Transformer. Multi-head attention is just SDPA repeated H times in
parallel with sliced projections, then concatenated. Causal LMs are SDPA
with a triangular mask. Cross-attention is SDPA where Q comes from one
sequence and K/V from another. KV caching is SDPA over a growing K/V
buffer. Get this one function right and the rest of the track is
bookkeeping.
This problem deliberately involves NO nnx.Module. SDPA is pure math
over arrays — there is nothing to parameterize. Establishing the math
cleanly here lets the next nine problems focus on the module structure
around the math.
The formula
Given queries Q of shape (T_q, d_k), keys K of shape (T_k, d_k),
and values V of shape (T_k, d_v):
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) @ V
Three steps:
- Score: Q @ K^T, every query dotted with every key. Shape (T_q, T_k). Entry (i, j) measures how much query i “wants” key j.
- Scale: divide by sqrt(d_k). Without this scale, large d_k drives the dot products to large magnitudes and softmax saturates toward a one-hot, so the gradient vanishes. If the entries of Q and K are roughly independent with unit variance, the scale keeps the score variance roughly 1 regardless of d_k.
- Softmax + value: row-wise softmax converts scores to a probability distribution over keys; matmul with V produces a weighted sum of values per query. Shape (T_q, d_v).
T_q and T_k need not match — that’s what makes cross-attention work.
d_k (key/query dim) and d_v (value dim) need not match either, though
in MHA we usually pick them equal for tidy reshapes.
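The three steps can be sketched directly in JAX. This is a minimal illustration, not a reference solution: the name sdpa_sketch and the shapes (T_q=2, T_k=5, d_k=8, d_v=3) are assumptions chosen only to show that T_q and T_k, and d_k and d_v, may differ.

```python
import jax
import jax.numpy as jnp


def sdpa_sketch(q, k, v):
    """Single-head scaled dot-product attention on 2-D arrays (no mask)."""
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)            # step 1 + 2: (T_q, T_k)
    weights = jax.nn.softmax(scores, axis=-1)   # step 3a: each row sums to 1
    return weights @ v                          # step 3b: (T_q, d_v)


# Illustrative shapes only: T_q=2, T_k=5, d_k=8, d_v=3.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8))
k = jax.random.normal(kk, (5, 8))
v = jax.random.normal(kv, (5, 3))

out = sdpa_sketch(q, k, v)
print(out.shape)
```

Because every output row is a convex combination of the rows of v, each output entry lies between the column-wise minimum and maximum of v.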
Why softmax over the LAST axis?
Softmax normalizes along the dimension you want to be a probability
distribution. We want every query to distribute its attention over keys —
so for each query (axis 0), softmax across keys (axis -1, which is
T_k). After softmax, every row sums to 1.
A common bug: softmax(scores, axis=0). That makes every column
sum to 1, so each key divvies its mass over queries — meaningless.
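A small check makes the axis difference concrete. The score values below are made up for illustration; only jax.nn.softmax and its axis argument come from the text above.

```python
import jax
import jax.numpy as jnp

# 2 queries (rows) by 3 keys (columns); values are arbitrary.
scores = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 5.0]])

over_keys = jax.nn.softmax(scores, axis=-1)    # intended: normalize per query
over_queries = jax.nn.softmax(scores, axis=0)  # the bug: normalize per key

row_sums = over_keys.sum(axis=-1)         # one sum per query
col_sums = over_queries.sum(axis=0)       # one sum per key
bug_row_sums = over_queries.sum(axis=-1)  # per-query sums under the bug

print(row_sums, col_sums, bug_row_sums)
```

With axis=-1 each query's weights sum to 1. With axis=0 the columns sum to 1 instead, and the per-query sums are no longer 1, so the rows are not distributions over keys.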
Worked example
Q = jnp.eye(3, 4) # 3 queries, d_k=4
K = jnp.eye(4) # 4 keys
V = jnp.arange(16).reshape(4, 4) + 1.0
out = sdpa(Q, K, V) # shape (3, 4)
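Each query (a one-hot row) aligns with exactly one key: Q @ K.T is the
3x4 rectangular identity, so after scaling by 1/sqrt(4) = 0.5 the
matching score is 0.5 and the others are 0. The softmax therefore only
mildly prefers the matching value row (weight about 0.355, versus about
0.215 for each of the other three keys), and the output is a mixture of
all four V rows with one row weighted most.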
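To see the attention weights in this example, the steps can be written out inline. This is a sketch that spells out the three operations rather than calling sdpa, so it runs on its own.

```python
import jax
import jax.numpy as jnp

Q = jnp.eye(3, 4)                        # 3 one-hot queries, d_k=4
K = jnp.eye(4)                           # 4 one-hot keys
V = jnp.arange(16).reshape(4, 4) + 1.0   # rows [1..4], [5..8], [9..12], [13..16]

scores = Q @ K.T / jnp.sqrt(Q.shape[-1])   # 0.5 where query matches key, else 0
weights = jax.nn.softmax(scores, axis=-1)  # (3, 4), rows sum to 1
out = weights @ V                          # (3, 4)

print(weights[0])   # matching key gets the largest weight, the rest share equally
print(out.shape)
```

The matching key receives exp(0.5) / (exp(0.5) + 3), roughly 0.355, and each of the other three keys receives roughly 0.215.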
What this is NOT yet
- No mask. A causal LM would set the scores of future positions to -inf (or a very large negative number) before softmax, so their attention weights come out as zero. We’ll add that in pos 33.
- No multi-head. A single Q/K/V here. MHA in pos 32 splits the hidden dim into H heads and runs SDPA per head.
- No projections. Real attention has W_Q, W_K, W_V learned linear maps. We’re operating on already-projected Q, K, V.
Common pitfalls
- Forgetting the scale. softmax(Q @ K.T) without / sqrt(d_k) works for tiny d_k but saturates the softmax for d_k = 64 or larger.
- Wrong softmax axis. axis=-1 (over keys), not axis=0.
- d_k from k.shape[-1] vs q.shape[-1]. They must match (else Q @ K.T is undefined), so either is fine.
- Using jnp.exp + manual normalize. jax.nn.softmax is more numerically stable (it subtracts the row-max before exponentiating).
- @ vs jnp.dot. Equivalent for 2-D. The @ operator is more readable and matches the math.
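The first pitfall can be checked empirically. The sketch below assumes q and k have independent standard-normal entries, which is an idealization rather than a property of real activations; score_variance is a hypothetical helper introduced only for this demonstration.

```python
import jax
import jax.numpy as jnp


def score_variance(d_k, scaled, n=256, seed=0):
    """Empirical variance of q @ k.T for unit-variance random q and k."""
    kq, kk = jax.random.split(jax.random.PRNGKey(seed))
    q = jax.random.normal(kq, (n, d_k))
    k = jax.random.normal(kk, (n, d_k))
    scores = q @ k.T
    if scaled:
        scores = scores / jnp.sqrt(d_k)
    return float(scores.var())


for d_k in (4, 64, 512):
    unscaled = score_variance(d_k, scaled=False)
    scaled = score_variance(d_k, scaled=True)
    print(d_k, unscaled, scaled)
```

Under this assumption the unscaled variance grows roughly in proportion to d_k, while the scaled variance stays near 1, which is what keeps the softmax out of its saturated regime.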
Problem
Write sdpa(q, k, v):
- Compute scores = q @ k.T / jnp.sqrt(d_k), where d_k = q.shape[-1].
- Apply jax.nn.softmax(scores, axis=-1) to get attention weights.
- Return weights @ v, the value-weighted output.
No nnx module here — just three array operations.
Inputs:
- q: 2-D (T_q, d_k).
- k: 2-D (T_k, d_k).
- v: 2-D (T_k, d_v).
Output: 2-D (T_q, d_v).