
Scaled Dot-Product Attention

Why this matters

Scaled dot-product attention (SDPA) is THE building block of every modern Transformer. Encoder, decoder, cross-attention, multi-head — they’re all just SDPA with the right inputs and reshapes.

The formula is famous:

Attention(Q, K, V) = softmax(Q @ Kᵀ / √d) @ V

Three points to internalize:

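  1. Why divide by √d? If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance d, so typical scores grow like √d. For large d the softmax saturates (one entry → 1, others → 0) and gradients vanish. Dividing by √d keeps the variance of the scores ≈ 1 regardless of d.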
  2. What does the softmax do? It turns each query’s similarity-to-keys vector into a probability distribution: “spend ~70% of attention on key 0, ~20% on key 1, …”. The output is a weighted average of values.
  3. Why “self-attention”? Self-attention means Q, K, V all come from the same input (different linear projections). Cross-attention means Q is from one source, K and V from another.
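The variance claim in point 1 can be checked empirically. The sketch below assumes queries and keys with independent standard-normal components (an idealization, not something real trained projections guarantee); the helper name `score_variances` and the sample size are made up for illustration.

```python
import jax
import jax.numpy as jnp

def score_variances(d, n=20000, seed=0):
    """Return (raw, scaled) sample variances of q·k for standard-normal q, k."""
    key_q, key_k = jax.random.split(jax.random.PRNGKey(seed))
    q = jax.random.normal(key_q, (n, d))
    k = jax.random.normal(key_k, (n, d))
    raw = jnp.sum(q * k, axis=-1)      # n independent dot products
    scaled = raw / jnp.sqrt(d)         # the SDPA scaling
    return float(jnp.var(raw)), float(jnp.var(scaled))

for d in (16, 64, 256):
    raw_var, scaled_var = score_variances(d)
    # raw_var should land near d, scaled_var near 1
    print(d, raw_var, scaled_var)
```

Under these assumptions the raw variance tracks d while the scaled variance stays close to 1, which is the whole point of the √d factor.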

Shape conventions

For UN-BATCHED, UN-HEADED inputs:

  • Q: shape (T_q, d). T_q query positions, d dimensions per query.
  • K: shape (T_k, d). T_k key positions, same d (must match Q).
  • V: shape (T_k, d_v). T_k value positions (same as K’s), d_v output dim per value (often d_v == d but not required).
  • Output: shape (T_q, d_v). Same query positions, value-shaped output.

Q @ Kᵀ has shape (T_q, T_k) — one row per query, columns are keys.
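A quick way to convince yourself of these shapes is to push dummy arrays through the two matrix products. The sizes below are arbitrary illustrative choices, and the softmax is left out because it does not change shapes.

```python
import jax.numpy as jnp

T_q, T_k, d, d_v = 4, 6, 8, 5    # arbitrary illustrative sizes
Q = jnp.ones((T_q, d))
K = jnp.ones((T_k, d))
V = jnp.ones((T_k, d_v))

scores = Q @ K.T                 # (T_q, T_k): one row per query, one column per key
out = scores @ V                 # (T_q, d_v): one value-shaped row per query
print(scores.shape, out.shape)   # (4, 6) (4, 5)
```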

Worked numerical example

Take Q and K both equal to the 3×3 identity:

import jax
import jax.numpy as jnp

Q = jnp.eye(3)                                       # shape (3, 3)
K = jnp.eye(3)                                       # shape (3, 3)
V = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # shape (3, 2)

scores = Q @ K.T / jnp.sqrt(3.0)
# = (1/√3) * I_3: diagonal of 1/√3 ≈ 0.577, zeros elsewhere

weights = jax.nn.softmax(scores, axis=-1)
# row 0: softmax([1/√3, 0, 0]) ≈ [0.471, 0.264, 0.264]
# rows 1 and 2 hold the same values, with the large entry on the diagonal.

out = weights @ V
# row 0: 0.471*[1,0] + 0.264*[0,1] + 0.264*[1,1] ≈ [0.736, 0.529]

Why axis=-1 for softmax?

scores has shape (T_q, T_k). We want each query to spread its attention over all keys, which means softmax along the last axis (the key axis). Softmax along axis 0 would normalize across queries instead: each key's column would sum to 1, and a single query's weights would no longer form a distribution.
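The difference between the two axes is easy to see on a small score matrix. The numbers below are made up for illustration.

```python
import jax
import jax.numpy as jnp

scores = jnp.array([[2.0, 0.0, -1.0],
                    [0.5, 0.5, 3.0]])           # (T_q=2, T_k=3)

per_query = jax.nn.softmax(scores, axis=-1)     # normalize over keys (correct)
per_key = jax.nn.softmax(scores, axis=0)        # normalize over queries (wrong for attention)

print(per_query.sum(axis=-1))  # each query's weights sum to 1
print(per_key.sum(axis=-1))    # rows no longer sum to 1
print(per_key.sum(axis=0))     # the columns sum to 1 instead
```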

Common pitfalls

  • Q @ K instead of Q @ K.T: shapes won’t match unless d == T_k. When they do happen to match, the code runs without error and silently computes the wrong scores.
  • / d instead of / √d: a real bug seen in implementations. It shrinks the scores by an extra factor of √d, which pushes the softmax toward uniform weights.
  • softmax(axis=0): wrong axis. Always last.
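  • Forgetting to apply softmax: just (Q @ K.T / √d) @ V is NOT attention. The weights can be negative and do not sum to 1, so the output is no longer a weighted average of values. The softmax supplies both the non-linearity and the normalization.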
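To see what the / d pitfall does in practice, compare the two scalings on one query's scores. The head dimension and the raw scores here are assumed values chosen to make the arithmetic easy, not taken from any real model.

```python
import jax
import jax.numpy as jnp

d = 64                                       # assumed head dimension
raw = jnp.array([8.0, 0.0, 0.0, 0.0])        # made-up unscaled scores for one query

w_sqrt = jax.nn.softmax(raw / jnp.sqrt(d))   # correct scaling: softmax([1, 0, 0, 0])
w_d = jax.nn.softmax(raw / d)                # buggy scaling: softmax([0.125, 0, 0, 0])

print(w_sqrt)  # largest weight is about 0.48
print(w_d)     # largest weight is about 0.27, close to the uniform 0.25
```

With the buggy scaling the query can barely distinguish its best-matching key from the rest.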

Problem

Implement scaled_dot_product(q, k, v) returning the SDPA output.

  1. d = q.shape[-1].
  2. scores = q @ k.T / sqrt(d).
  3. weights = jax.nn.softmax(scores, axis=-1).
  4. Return weights @ v.

No Flax Module needed — this is the pure-JAX building block. Subsequent problems will package this into nn.MultiHeadDotProductAttention.

Inputs:

  • q: 2-D (T_q, d).
  • k: 2-D (T_k, d).
  • v: 2-D (T_k, d_v).

Output: 2-D (T_q, d_v).
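For reference, the four steps above translate almost line for line into JAX. This is a minimal sketch for the un-batched 2-D case only, not necessarily the site's reference solution; the check at the end reuses the identity-matrix inputs from the worked example.

```python
import jax
import jax.numpy as jnp

def scaled_dot_product(q, k, v):
    """SDPA for 2-D inputs: q (T_q, d), k (T_k, d), v (T_k, d_v) -> (T_q, d_v)."""
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)               # (T_q, T_k)
    weights = jax.nn.softmax(scores, axis=-1)    # each row sums to 1
    return weights @ v                           # (T_q, d_v)

q = k = jnp.eye(3)
v = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
out = scaled_dot_product(q, k, v)
print(out.shape)  # (3, 2)
print(out[0])     # approximately [0.736, 0.529]
```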

