Cross-Attention with Flax MHA

Why this matters

Self-attention: Q, K, V all from the same input. Each token attends to other tokens in the SAME sequence. Used inside encoder layers and inside decoder layers.

Cross-attention: Q from one source, K and V from another. The decoder uses cross-attention to read from the encoder’s output: “given what I’m generating now (Q), which encoder positions (K, V) matter?”

In the original Transformer (Vaswani et al., 2017), cross-attention is exactly the encoder→decoder bridge. Text-conditioned diffusion models, Perceiver-style architectures, and some retrieval-augmented LMs rely on it too.
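A minimal sketch of the self/cross difference in plain jax.numpy. It assumes a single head and no learned Q/K/V projections (real layers, Flax's included, project first), and the helper name `attend` is hypothetical. The only thing that changes between the two cases is which array supplies K and V.

```python
import jax
import jax.numpy as jnp


def attend(q_src, kv_src):
    """Hypothetical single-head attention with no learned projections.

    q_src:  (T_q, D)  -- where the queries come from
    kv_src: (T_kv, D) -- where the keys AND values come from
    """
    scores = q_src @ kv_src.T / jnp.sqrt(q_src.shape[-1])  # (T_q, T_kv)
    weights = jax.nn.softmax(scores, axis=-1)              # each row sums to 1
    return weights @ kv_src                                # (T_q, D)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_a, (5, 16))    # e.g. decoder states
ctx = jax.random.normal(key_b, (7, 16))  # e.g. encoder output

self_out = attend(x, x)     # self-attention: Q, K, V all from x
cross_out = attend(x, ctx)  # cross-attention: Q from x, K and V from ctx
print(self_out.shape, cross_out.shape)
```

Both outputs have one row per query (5 rows); the 7 context rows only determine what each query can read.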

The two-input call

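nn.MultiHeadDotProductAttention takes the query source as its first positional argument and the key/value source as its second. (Recent Flax releases name that second parameter inputs_k and add an optional inputs_v that defaults to it; older releases called it inputs_kv.)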

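```python
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
params = attn.init(rng, inputs_q, inputs_kv)   # linen modules run via init/apply
out = attn.apply(params, inputs_q, inputs_kv)
```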
  • inputs_q: shape (T_q, D_in). Source for Q.
  • inputs_kv: shape (T_kv, D_in). Source for both K and V.

T_q and T_kv can differ — that’s the whole point. Encoder length and decoder length are usually different.

Self-attention is the special case attn(x, x). Recent Flax releases also accept attn(x), defaulting K and V to the query input.

Output shape

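The output has shape (T_q, D_out): one row per query position. D_out is the module's out_features, which defaults to the last dimension of inputs_q. Inside, each head projects K and V from the T_kv key/value rows, so its score matrix is (T_q, T_kv); the softmax-weighted sum over V yields T_q rows per head, and the output projection merges the heads into D_out features.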

Worked example: encoder→decoder

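```python
import jax
import jax.numpy as jnp
import flax.linen as nn

rng = jax.random.PRNGKey(0)
encoder_out = jnp.ones((8, 64))  # Encoder output (placeholder): T_kv=8 source tokens, D=64.
decoder_h = jnp.ones((4, 64))    # Decoder hidden (placeholder): T_q=4 generated tokens, D=64.
cross = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=64)
params = cross.init(rng, decoder_h, encoder_out)
out = cross.apply(params, decoder_h, encoder_out)   # (4, 64)
```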

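Each of the 4 decoder positions gets one 64-dim vector. Per head, it is a softmax-weighted average of the projected values at the 8 encoder positions (the weights encode which encoder positions matter to that decoder position); the output projection then merges the heads.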

Common pitfalls

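  • Calling attn(x_kv, x_q). The first argument is the Q source, so swapping them turns the encoder positions into queries: the output gets T_kv rows instead of T_q, and when the feature dims match it runs without any error.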
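  • Assuming x_q and x_kv must share D_in. Flax's MHA has separate query, key and value projections, each inferring its input size from its own argument, so x_q and x_kv may differ in feature dim as well as in length. What must line up is the projected size: Q and K are both projected to qkv_features (split across heads), which is what makes their dot products well defined.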
  • Confusing T_q and T_kv. Each head's score matrix is (T_q, T_kv), which is square only when the two lengths happen to match. Softmax runs along the last axis (the K axis), as always.

Problem

Implement mha_cross(seed, x_q, x_kv, num_heads, qkv_features):

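  1. Build nn.MultiHeadDotProductAttention(num_heads=num_heads, qkv_features=qkv_features).
  2. Init with attn.init(rng, x_q, x_kv), where rng is a PRNG key derived from seed (for example jax.random.PRNGKey(seed)).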
  3. Apply with attn.apply(params, x_q, x_kv).
  4. Return out.reshape(-1).

Inputs:

  • seed: int.
  • x_q: 2-D (T_q, D_in). Query source.
  • x_kv: 2-D (T_kv, D_in). Key/value source.
  • num_heads, qkv_features: ints.

Output: 1-D, the flattened (T_q, D_in) cross-attention output (out_features defaults to x_q's last dimension, D_in).
