Scaled Dot-Product Attention
Why this matters
Scaled dot-product attention (SDPA) is THE building block of every modern Transformer. Encoder self-attention, decoder self-attention, cross-attention, multi-head attention: they’re all just SDPA with the right inputs and reshapes.
The formula is famous:
Attention(Q, K, V) = softmax(Q @ Kᵀ / √d) @ V
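Read one query at a time, the formula says that output row i is a weighted average of the value rows. Here q_i, k_j and v_j denote rows of Q, K and V, d is the query/key dimension, and T_k is the number of keys:

```latex
\mathrm{out}_i = \sum_{j=1}^{T_k} \alpha_{ij}\, v_j,
\qquad
\alpha_{ij} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d}\right)}
                   {\sum_{l=1}^{T_k} \exp\left(q_i \cdot k_l / \sqrt{d}\right)}
```

Each row of α is non-negative and sums to 1; it is exactly row i of softmax(Q Kᵀ / √d).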
Three points to internalize:
- Why divide by √d? If the components of q and k are independent with mean 0 and variance 1, the dot product q · k has variance d, so the raw scores grow with d. Large scores saturate the softmax (one entry → 1, the others → 0), and its gradients vanish. Dividing by √d keeps the variance of the scores ≈ 1 regardless of d.
- What does the softmax do? It turns each query’s vector of similarities to the keys into a probability distribution: “spend ~70% of attention on key 0, ~20% on key 1, …”. The output for that query is the corresponding weighted average of the values.
- Why “self-attention”? Self-attention means Q, K, V all come from the same input (through different linear projections). Cross-attention means Q comes from one source and K and V from another.
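A quick empirical check of the √d argument. It assumes query and key components are independent standard normals (an assumption of this sketch; learned projections need not satisfy it). The raw dot product’s standard deviation tracks √d, while the scaled one stays near 1. The second half compares the softmax of one hypothetical query against 8 random keys, with and without scaling.

```python
import jax
import jax.numpy as jnp

def score_stds(d, n=4096, seed=0):
    """Std of raw and scaled dot products q·k for random unit-variance q, k."""
    kq, kk = jax.random.split(jax.random.PRNGKey(seed))
    q = jax.random.normal(kq, (n, d))   # n independent query vectors
    k = jax.random.normal(kk, (n, d))   # n independent key vectors
    raw = jnp.sum(q * k, axis=-1)       # n dot products, variance ≈ d
    return float(raw.std()), float((raw / jnp.sqrt(d)).std())

for d in (4, 64, 1024):
    raw_std, scaled_std = score_stds(d)
    # raw_std grows like √d (about 2, 8, 32); scaled_std stays near 1
    print(f"d={d:4d}  std(q·k)={raw_std:6.2f}  std(q·k/√d)={scaled_std:.2f}")

# Saturation: one hypothetical query against 8 keys at d = 1024.
kq, kk = jax.random.split(jax.random.PRNGKey(1))
q = jax.random.normal(kq, (1024,))
keys = jax.random.normal(kk, (8, 1024))
w_raw = jax.nn.softmax(keys @ q)                        # unscaled scores
w_scaled = jax.nn.softmax(keys @ q / jnp.sqrt(1024.0))  # scaled by √d = 32
# The unscaled row is at least as peaked, and typically much more so.
print(w_raw.max(), w_scaled.max())
```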
Shape conventions
For UN-BATCHED, UN-HEADED inputs:
- Q: shape (T_q, d). T_q query positions, d dimensions per query.
- K: shape (T_k, d). T_k key positions, same d (must match Q).
- V: shape (T_k, d_v). T_k value positions (same as K’s), d_v output dimensions per value (often d_v == d, but not required).
- Output: shape (T_q, d_v). Same query positions, value-shaped output.

Q @ Kᵀ has shape (T_q, T_k): one row per query, one column per key.
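To see these shapes with T_q ≠ T_k and d_v ≠ d, here is a minimal cross-attention sketch: queries are projected from one hypothetical sequence x, keys and values from another sequence y. The projection matrices W_q, W_k, W_v are random stand-ins for learned weights, and all sizes are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

T_q, T_k, d_model, d, d_v = 2, 5, 16, 4, 3   # arbitrary illustrative sizes
rngs = jax.random.split(jax.random.PRNGKey(0), 5)

x = jax.random.normal(rngs[0], (T_q, d_model))   # source of the queries
y = jax.random.normal(rngs[1], (T_k, d_model))   # source of keys and values
W_q = jax.random.normal(rngs[2], (d_model, d))   # hypothetical projections
W_k = jax.random.normal(rngs[3], (d_model, d))
W_v = jax.random.normal(rngs[4], (d_model, d_v))

Q, K, V = x @ W_q, y @ W_k, y @ W_v   # (T_q, d), (T_k, d), (T_k, d_v)
scores = Q @ K.T / jnp.sqrt(d)        # (T_q, T_k): one row per query
weights = jax.nn.softmax(scores, axis=-1)
out = weights @ V                     # (T_q, d_v)

print(Q.shape, K.shape, V.shape, scores.shape, out.shape)
# (2, 4) (5, 4) (5, 3) (2, 5) (2, 3)
```

For self-attention, project all three from the same sequence (y = x), which forces T_k == T_q.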
Worked numerical example
Identity Q vs identity K:
import jax
import jax.numpy as jnp

Q = jnp.eye(3)                                       # shape (3, 3)
K = jnp.eye(3)                                       # shape (3, 3)
V = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # shape (3, 2)
scores = Q @ K.T / jnp.sqrt(3.0)
# = (1/√3) * I_3: diagonal entries 1/√3 ≈ 0.577, zeros elsewhere
weights = jax.nn.softmax(scores, axis=-1)
# row 0: softmax([1/√3, 0, 0]) ≈ [0.471, 0.264, 0.264]
# rows 1 and 2: the same weights, with 0.471 moved onto the diagonal
out = weights @ V
# row 0: 0.471*[1,0] + 0.264*[0,1] + 0.264*[1,1] ≈ [0.74, 0.53]
Why axis=-1 for softmax?
scores has shape (T_q, T_k). We want each query to spread its attention
over all keys — that’s softmax along the last axis (the key axis). If
you softmax along axis 0, you’d be normalizing across queries, which is
nonsense.
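A small check of the axis choice, on a hypothetical 2 × 3 score matrix (2 queries, 3 keys). With axis=-1 every query’s weights form a distribution over keys; with axis=0 it is the columns (one per key) that sum to 1 instead.

```python
import jax
import jax.numpy as jnp

# Hypothetical scores for T_q = 2 queries against T_k = 3 keys.
scores = jnp.array([[2.0, 0.0, -1.0],
                    [0.5, 0.5, 3.0]])

right = jax.nn.softmax(scores, axis=-1)  # normalize over keys (per query)
wrong = jax.nn.softmax(scores, axis=0)   # normalize over queries (per key)

print(right.sum(axis=-1))  # each query's weights sum to 1
print(wrong.sum(axis=-1))  # generally not 1: not a distribution over keys
print(wrong.sum(axis=0))   # instead, each key's column sums to 1
```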
Common pitfalls
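- Q @ K instead of Q @ K.T: the shapes won’t match unless d == T_k, and when they do match, the result is silently wrong.
- / d instead of / √d: a real bug seen in implementations. It shrinks the scores by an extra factor of √d (8× for d = 64), which pushes the softmax toward uniform weights.
- softmax(axis=0): wrong axis. Always normalize over the last (key) axis.
- Forgetting to apply softmax: (Q @ K.T / √d) @ V is NOT attention. Without the softmax the weights are neither non-negative nor normalized, and the expression is just a chain of matrix products, linear in each of Q, K and V. The softmax supplies both the normalization and the non-linearity.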
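Two of these pitfalls can be seen numerically on random inputs (sizes and seed are arbitrary assumptions of this sketch). Dividing by d instead of √d can only flatten each row of weights toward 1/T_k, since a softmax’s largest entry never grows when all its inputs are shrunk by the same positive factor. Skipping the softmax leaves rows that do not sum to 1.

```python
import jax
import jax.numpy as jnp

d, T_q, T_k = 64, 4, 6
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(kq, (T_q, d))
k = jax.random.normal(kk, (T_k, d))

good = jax.nn.softmax(q @ k.T / jnp.sqrt(d), axis=-1)  # correct: / √d
over = jax.nn.softmax(q @ k.T / d, axis=-1)            # pitfall: / d

# Dividing by d = 64 instead of √d = 8 shrinks every score 8x more, so each
# row's largest weight moves down toward the uniform value 1 / T_k.
print(good.max(axis=-1))
print(over.max(axis=-1))

# Pitfall: skipping the softmax leaves raw scaled scores as the "weights".
no_softmax = q @ k.T / jnp.sqrt(d)
print(no_softmax.sum(axis=-1))  # rows do not sum to 1; entries can be negative
```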
Problem
Implement scaled_dot_product(q, k, v) returning the SDPA output.
- d = q.shape[-1].
- scores = q @ k.T / jnp.sqrt(d).
- weights = jax.nn.softmax(scores, axis=-1).
- Return weights @ v.
No Flax Module is needed: this is the pure-JAX building block. Subsequent
problems will package it into nn.MultiHeadDotProductAttention.
Inputs:
- q: 2-D, shape (T_q, d).
- k: 2-D, shape (T_k, d).
- v: 2-D, shape (T_k, d_v).

Output: 2-D, shape (T_q, d_v).
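One possible sketch of scaled_dot_product that follows the four steps literally. It assumes plain 2-D inputs as specified (no batch or head axes, no masking); the usage lines rerun the identity-matrix worked example.

```python
import jax
import jax.numpy as jnp

def scaled_dot_product(q, k, v):
    """softmax(q @ k.T / sqrt(d)) @ v for q (T_q, d), k (T_k, d), v (T_k, d_v)."""
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)             # (T_q, T_k)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1 over keys
    return weights @ v                         # (T_q, d_v)

# Usage on the worked example: identity Q and K, 3 x 2 V.
out = scaled_dot_product(jnp.eye(3), jnp.eye(3),
                         jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
print(out)  # row 0 ≈ [0.74, 0.53]
```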