
NNX Scaled Dot-Product Attention

Why this matters

Scaled dot-product attention (SDPA) is the inner loop of every modern Transformer. Multi-head attention is just SDPA repeated H times in parallel with sliced projections, then concatenated. Causal LMs are SDPA with a triangular mask. Cross-attention is SDPA where Q comes from one sequence and K/V from another. KV caching is SDPA over a growing K/V buffer. Get this one function right and the rest of the track is bookkeeping.

This problem deliberately involves NO nnx.Module. SDPA is pure math over arrays — there is nothing to parameterize. Establishing the math cleanly here lets the next nine problems focus on the module structure around the math.

The formula

Given queries Q of shape (T_q, d_k), keys K of shape (T_k, d_k), and values V of shape (T_k, d_v):

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) @ V
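The formula translates almost line for line into jax.numpy. This is a minimal sketch that assumes unbatched 2-D inputs as described above. The name sdpa_sketch is ours, chosen so it does not collide with the function you write in the problem:

```python
import jax
import jax.numpy as jnp


def sdpa_sketch(q: jax.Array, k: jax.Array, v: jax.Array) -> jax.Array:
    """softmax(q @ k.T / sqrt(d_k)) @ v for q (T_q, d_k), k (T_k, d_k), v (T_k, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)           # (T_q, T_k)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (T_q, d_v)


q = jnp.ones((2, 4))                 # 2 identical queries
k = jnp.ones((3, 4))                 # 3 identical keys
v = jnp.arange(6.0).reshape(3, 2)    # rows [0, 1], [2, 3], [4, 5]
out = sdpa_sketch(q, k, v)
```

With all-ones q and k every score is equal, so each query attends uniformly and each output row is the mean of v's rows, here [2, 3].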

Three steps:

  1. Score: Q @ K^T — every query dotted with every key. Shape (T_q, T_k). Entry (i, j) measures how much query i “wants” key j.
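  2. Scale: divide by sqrt(d_k). If the components of q and k are independent with mean 0 and variance 1, each score q · k is a sum of d_k unit-variance terms and so has variance d_k. Without the scale, large d_k pushes the scores to large magnitudes, softmax saturates toward a one-hot, and gradients vanish. Dividing by sqrt(d_k) brings the score variance back to roughly 1 regardless of d_k.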
  3. Softmax + value: row-wise softmax converts scores to a probability distribution over keys; matmul with V produces a weighted sum of values per query. Shape (T_q, d_v).
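The variance argument in step 2 can be checked empirically. This sketch assumes i.i.d. standard-normal entries for q and k, the same assumption the argument makes, and uses arbitrary sizes (512 queries and keys, d_k = 256):

```python
import jax
import jax.numpy as jnp

d_k = 256
key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
# Assumption for this demo: entries of q and k are i.i.d. standard normal.
q = jax.random.normal(key_q, (512, d_k))
k = jax.random.normal(key_k, (512, d_k))

raw = q @ k.T                  # each entry sums d_k unit-variance products
scaled = raw / jnp.sqrt(d_k)   # divide by sqrt(d_k) = 16

print(raw.std())     # close to sqrt(256) = 16
print(scaled.std())  # close to 1

# Mean of each query's largest attention weight: near 1 means near one-hot.
peak_raw = jax.nn.softmax(raw, axis=-1).max(axis=-1).mean()
peak_scaled = jax.nn.softmax(scaled, axis=-1).max(axis=-1).mean()
```

Unscaled, most of each query's mass lands on a single key. Scaled, the weights stay spread out, so gradients reach many keys.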

T_q and T_k need not match — that’s what makes cross-attention work. d_k (key/query dim) and d_v (value dim) need not match either, though in MHA we usually pick them equal for tidy reshapes.
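To see that only T_k and d_k have to line up, here is a sketch with deliberately mismatched, arbitrary sizes. It repeats the three-line sdpa from the problem statement so the block runs on its own:

```python
import jax
import jax.numpy as jnp


def sdpa(q, k, v):
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
# Hypothetical sizes: 2 queries attend over 5 keys, d_k = 3, d_v = 7.
q = jax.random.normal(kq, (2, 3))  # (T_q, d_k)
k = jax.random.normal(kk, (5, 3))  # (T_k, d_k)
v = jax.random.normal(kv, (5, 7))  # (T_k, d_v)

out = sdpa(q, k, v)
print(out.shape)  # (2, 7): one d_v-vector per query
```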

Why softmax over the LAST axis?

Softmax normalizes along the dimension you want to be a probability distribution. We want every query to distribute its attention over keys, so for each query (one row of scores, indexed by axis 0) we normalize across the keys (axis -1, of length T_k). After softmax, every row sums to 1.

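A common bug: softmax(scores, axis=0). That makes every column sum to 1, so each key distributes its mass over queries instead. A query's weights no longer sum to 1, and its output is no longer a weighted average of the value rows.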
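A small hypothetical 2-query by 3-key score matrix makes the difference between the two axes concrete:

```python
import jax
import jax.numpy as jnp

# Hypothetical scores: rows are queries, columns are keys.
scores = jnp.array([[2.0, 0.0, -1.0],
                    [0.5, 0.5, 3.0]])

right = jax.nn.softmax(scores, axis=-1)  # normalize over keys (correct)
wrong = jax.nn.softmax(scores, axis=0)   # normalize over queries (the bug)

print(right.sum(axis=-1))  # approximately [1, 1]: each query's weights form a distribution
print(wrong.sum(axis=0))   # approximately [1, 1, 1]: columns sum to 1 instead
print(wrong.sum(axis=-1))  # rows no longer sum to 1
```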

Worked example

import jax.numpy as jnp

Q = jnp.eye(3, 4)                        # 3 queries, d_k=4
K = jnp.eye(4)                           # 4 keys, d_k=4
V = jnp.arange(16).reshape(4, 4) + 1.0   # 4 value rows, d_v=4
out = sdpa(Q, K, V)                      # sdpa is the function written in this problem; shape (3, 4)

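Each query is a one-hot row and K is the identity, so Q @ K.T is just Q: query i scores 1 against key i and 0 against every other key (the fourth key matches no query). The scale 1/sqrt(4) = 0.5 turns those scores into 0.5 and 0. Softmax then gives the matching key a weight of e^0.5 / (e^0.5 + 3) ≈ 0.355 and each of the other three keys ≈ 0.215. The output is therefore a mixture of all four V rows with row i weighted most heavily. The preference is mild, not sharp.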
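The following recomputes this example's attention weights step by step. It inlines the three operations instead of calling sdpa so the block runs on its own. The matching key should get e^0.5 / (e^0.5 + 3) and every other key 1 / (e^0.5 + 3):

```python
import jax
import jax.numpy as jnp

Q = jnp.eye(3, 4)
K = jnp.eye(4)
V = jnp.arange(16).reshape(4, 4) + 1.0

scores = Q @ K.T / jnp.sqrt(4.0)           # equals 0.5 * Q: 0.5 at (i, i), 0 elsewhere
weights = jax.nn.softmax(scores, axis=-1)  # (3, 4)
out = weights @ V                          # (3, 4)

match = jnp.exp(0.5) / (jnp.exp(0.5) + 3)  # weight on the matching key, about 0.355
other = 1 / (jnp.exp(0.5) + 3)             # weight on each other key, about 0.215
```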

What this is NOT yet

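  • No mask. A causal LM would set the scores of future positions to -inf (or a very large negative number) before softmax, so those positions receive zero attention weight. We'll add that in pos 33.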
  • No multi-head. A single Q/K/V here. MHA in pos 32 splits the hidden dim into H heads and runs SDPA per head.
  • No projections. Real attention has W_Q, W_K, W_V learned linear maps. We’re operating on already-projected Q, K, V.
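As a preview only (the later masking problem may structure this differently), one common way to apply a causal mask is to replace disallowed scores with -inf before the softmax. This sketch assumes T_q == T_k, and causal_sdpa is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def causal_sdpa(q, k, v):
    """Hypothetical causal variant: query i may only attend to keys j <= i."""
    t_q, t_k = q.shape[0], k.shape[0]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    # allowed[i, j] is True when key j is not in query i's future.
    allowed = jnp.arange(t_k)[None, :] <= jnp.arange(t_q)[:, None]
    scores = jnp.where(allowed, scores, -jnp.inf)  # future keys get -inf
    weights = jax.nn.softmax(scores, axis=-1)      # exp(-inf) = 0, so future weight is 0
    return weights @ v, weights


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
out, w = causal_sdpa(x, x, x)  # self-attention over 4 positions
```

The first query can only see itself, so its whole weight lands on key 0. Many implementations use a large finite negative value instead of -inf so that a row with every key masked does not turn into NaN.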

Common pitfalls

  • Forgetting the scale. softmax(Q @ K.T) without /sqrt(d_k) can get away with it for tiny d_k, but for d_k = 64 or larger the scores are big enough to saturate the softmax and starve the gradients.
  • Wrong softmax axis. axis=-1 (over keys), not axis=0.
  • d_k from k.shape[-1] vs q.shape[-1]. They must match (else Q @ K.T is undefined). Either is fine.
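  • Using jnp.exp + manual normalization. A naive jnp.exp(scores) / sum overflows to inf once a score exceeds about 88.7 in float32, producing NaN. jax.nn.softmax is more numerically stable: it subtracts the row max before exponentiating.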
  • @ vs jnp.dot. Equivalent for 2-D. The @ operator is more readable and matches the math.
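The overflow pitfall is easy to reproduce with a single hypothetical large score. The last two lines write out by hand the max-subtraction trick that jax.nn.softmax applies:

```python
import jax
import jax.numpy as jnp

scores = jnp.array([1000.0, 0.0])  # hypothetical large score, float32

naive = jnp.exp(scores) / jnp.exp(scores).sum()  # exp(1000) = inf, and inf / inf = nan
stable = jax.nn.softmax(scores)                  # effectively exp([0, -1000]) / sum

shifted = jnp.exp(scores - scores.max())         # the same trick written by hand
manual_stable = shifted / shifted.sum()
```

Subtracting the max does not change the result mathematically, because softmax is invariant to adding a constant to every score. It only keeps the exponentials in range.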

Problem

Write sdpa(q, k, v):

  1. Compute scores = q @ k.T / jnp.sqrt(d_k), where d_k = q.shape[-1].
  2. Apply jax.nn.softmax(scores, axis=-1) to get attention weights.
  3. Return weights @ v — the value-weighted output.

No nnx module here — just three array operations.

Inputs:

  • q: 2-D (T_q, d_k).
  • k: 2-D (T_k, d_k).
  • v: 2-D (T_k, d_v).

Output: 2-D (T_q, d_v).
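One way to sanity-check a solution is to compare it against a deliberately slow per-query loop that applies the formula one row at a time. In this sketch the sdpa body follows the three steps listed above so the block runs. The loop reference and its name are ours and purely illustrative:

```python
import jax
import jax.numpy as jnp


def sdpa(q, k, v):
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


def sdpa_loop_reference(q, k, v):
    """Hypothetical slow reference: one query at a time, explicit dot products and sums."""
    d_k = q.shape[-1]
    rows = []
    for i in range(q.shape[0]):
        s = jnp.array([jnp.dot(q[i], k[j]) for j in range(k.shape[0])]) / jnp.sqrt(d_k)
        w = jnp.exp(s - s.max())  # stable softmax over keys
        w = w / w.sum()
        rows.append(sum(w[j] * v[j] for j in range(v.shape[0])))
    return jnp.stack(rows)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(42), 3)
q = jax.random.normal(kq, (3, 4))  # T_q=3, d_k=4
k = jax.random.normal(kk, (5, 4))  # T_k=5
v = jax.random.normal(kv, (5, 6))  # d_v=6

out = sdpa(q, k, v)
ref = sdpa_loop_reference(q, k, v)
```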
