NNX Cross-Attention

Why this matters

Cross-attention is the bridge between two sequences. It’s how the decoder of a T5 / BART encoder-decoder model lets every generated token look at every encoded source token. It’s how DALL-E and Stable Diffusion let image patches look at text embeddings. It’s how Perceiver IO lets a small set of latents query a much larger input.

The mechanism is a single change from self-attention: queries come from one input (x_q), keys and values from another (x_kv). The two sequence lengths, T_q and T_k, need not match. Everything else is identical.

What changes vs. self-attention

Self-attention:

q = q_proj(x); k = k_proj(x); v = v_proj(x)

Cross-attention:

q = q_proj(x_q)              # queries from one source
k = k_proj(x_kv)             # keys from another
v = v_proj(x_kv)             # values from the same other source

K and V both come from x_kv (always together — they index the same positions). Q comes from x_q. Sequence lengths can differ, so:

  • q has shape (T_q, d_model) -> (H, T_q, head_dim).
  • k, v have shape (T_k, d_model) -> (H, T_k, head_dim).
  • scores has shape (H, T_q, T_k).
  • weights has shape (H, T_q, T_k).
  • per_head has shape (H, T_q, head_dim).
  • Output has shape (T_q, d_model).

The output sequence length matches T_q, NOT T_k. Cross-attention “summarizes” the source (length T_k) into one vector per query position (length T_q).
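To make the shape bookkeeping concrete, here is a minimal sketch in plain jax.numpy that traces each intermediate. The sizes (T_q=3, T_k=5, H=2, d_model=8) and the random stand-in inputs are assumptions chosen for illustration, and the learned projections are left out so that only the shapes are in view.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions, not taken from the text).
T_q, T_k, H, d_model = 3, 5, 2, 8
head_dim = d_model // H

key_q, key_kv = jax.random.split(jax.random.PRNGKey(0))
# Stand-ins for the already-projected queries and keys/values.
q_flat = jax.random.normal(key_q, (T_q, d_model))
kv_flat = jax.random.normal(key_kv, (T_k, d_model))

def split_heads(x, T):
    # (T, d_model) -> (H, T, head_dim)
    return x.reshape(T, H, head_dim).transpose(1, 0, 2)

q = split_heads(q_flat, T_q)    # (H, T_q, head_dim)
k = split_heads(kv_flat, T_k)   # (H, T_k, head_dim)
v = split_heads(kv_flat, T_k)   # (H, T_k, head_dim)

scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(head_dim)  # (H, T_q, T_k)
weights = jax.nn.softmax(scores, axis=-1)                          # (H, T_q, T_k)
per_head = jnp.matmul(weights, v)                                  # (H, T_q, head_dim)
out = per_head.transpose(1, 0, 2).reshape(T_q, d_model)            # (T_q, d_model)

for name, arr in [("q", q), ("k", k), ("scores", scores),
                  ("per_head", per_head), ("out", out)]:
    print(name, arr.shape)
```

T_k appears only in k, v, scores and weights; it is summed away by the weights @ v product, which is why the output carries T_q rows.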

Worked sketch

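import jax
import jax.numpy as jnp
from flax import nnx

class CrossMHA(nnx.Module):
    def __init__(self, d_model, num_heads, *, rngs):
        # Same four nnx.Linear projections as self-attention.
        # (This constructor signature is an assumed sketch.)
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x_q, x_kv):
        Tq, _ = x_q.shape
        Tk, _ = x_kv.shape
        H, Dh = self.num_heads, self.head_dim
        # Split heads: (T, d_model) -> (H, T, Dh), with Tq for Q and Tk for K/V.
        q = self.q_proj(x_q).reshape(Tq, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, Tq, Tk)
        weights = jax.nn.softmax(scores, axis=-1)  # normalize over source positions
        per_head = jnp.matmul(weights, v)          # (H, Tq, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(Tq, H * Dh)
        return self.out_proj(concat)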

Two reshape sites (one per source length) instead of one. The rest is the same code path.
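The head split and the final concat are exact inverses of each other, which is easy to check in isolation. The sketch below uses small made-up sizes (T=4, H=2, Dh=3) and an arange input so the layout is visible; it is a sanity check on the reshape/transpose pattern, not part of the module itself.

```python
import jax.numpy as jnp

# Illustrative sizes (assumptions).
T, H, Dh = 4, 2, 3
x = jnp.arange(T * H * Dh, dtype=jnp.float32).reshape(T, H * Dh)

# Split: (T, H*Dh) -> (T, H, Dh) -> (H, T, Dh)
heads = x.reshape(T, H, Dh).transpose(1, 0, 2)

# Merge: (H, T, Dh) -> (T, H, Dh) -> (T, H*Dh)
merged = heads.transpose(1, 0, 2).reshape(T, H * Dh)

# Head h holds columns [h*Dh, (h+1)*Dh) of every row of x.
second_head_matches = bool(jnp.array_equal(heads[1], x[:, Dh:2 * Dh]))
round_trip_ok = bool(jnp.array_equal(merged, x))
print(heads.shape, second_head_matches, round_trip_ok)
```

Because each head is a contiguous slice of the feature axis, the same pattern works unchanged for Q (with T = T_q) and for K/V (with T = T_k).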

No causal mask in cross-attention

Standard cross-attention uses no causal mask: the whole source x_kv is available before any query is processed, so there is no "future" position to hide. (T5’s decoder uses a causal mask for self-attention over decoder tokens, then unmasked cross-attention into the encoder output. The cross layer always sees the full source.)

Padding masks (which skip pad tokens in the source) are common in practice but omitted here. Adding one would take a jnp.where over the T_k axis of the scores, applied before the softmax.
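For reference, a source padding mask could look like the sketch below. The helper name masked_weights, the boolean kv_valid vector, and the sizes are all hypothetical; the only technique shown is replacing scores at pad positions with a very negative value before the softmax.

```python
import jax
import jax.numpy as jnp

def masked_weights(scores, kv_valid):
    """Softmax over T_k that ignores padded source positions.

    scores:   (H, T_q, T_k) raw attention scores.
    kv_valid: (T_k,) bool, False where the source token is padding.
    """
    # Use the most negative finite value rather than -inf so that a row
    # with every position masked does not produce NaN.
    neg = jnp.finfo(scores.dtype).min
    # Broadcast the (T_k,) mask over heads and query positions.
    masked = jnp.where(kv_valid[None, None, :], scores, neg)
    return jax.nn.softmax(masked, axis=-1)

# Illustrative sizes (assumptions).
H, T_q, T_k = 2, 3, 5
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T_q, T_k))
kv_valid = jnp.array([True, True, True, False, False])  # last two are padding

weights = masked_weights(scores, kv_valid)
print(weights[0])
```

The mask has length T_k only. Query positions are never masked here, so every row of weights still sums to 1 over the valid source tokens.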

Common pitfalls

  • Using one sequence-length variable. (T, _) = x.shape followed by .reshape(T, H, Dh) raises a shape error whenever T_q and T_k differ. Capture Tq and Tk separately.
  • Putting K/V on x_q. Then it’s self-attention with extra steps. K and V must come from x_kv.
  • Output sequence length. Cross-attention outputs (T_q, ...), not (T_k, ...). The query sequence drives the output length.
  • Forgetting that d_model is shared. Cross-attention here assumes x_q.shape[-1] == x_kv.shape[-1] == d_model. If the widths differ (e.g., text→image), the Q projection and the K/V projections need different input dims, even though all three still project to the same d_model.

Problem

Write mha_cross(seed, x_q, x_kv, num_heads, d_model):

  1. Define CrossMHA(nnx.Module): same four nnx.Linear(d_model, d_model) projections as in pos 32.
  2. __call__(x_q, x_kv): project Q from x_q, K and V from x_kv, reshape with Tq for Q and Tk for K/V, SDPA, concat back to (Tq, d_model), out_proj.
  3. Cast num_heads, d_model to int. Build nnx.Rngs(int(seed)).
  4. Return the output flattened: (Tq * d_model,).

Inputs:

  • seed: int (passed as float).
  • x_q: 2-D (T_q, d_model).
  • x_kv: 2-D (T_k, d_model). T_q and T_k may differ.
  • num_heads, d_model: ints (passed as floats).

Output: 1-D flattened (T_q * d_model,).
