Cross-Attention with Flax MHA
Why this matters
Self-attention: Q, K, V all come from the same input. Each token attends to tokens in the SAME sequence, itself included. Used inside encoder layers and inside decoder layers.
Cross-attention: Q from one source, K and V from another. The decoder uses cross-attention to read from the encoder’s output: “given what I’m generating now (Q), which encoder positions (K, V) matter?”
Original Transformer (Vaswani et al., 2017): the encoder→decoder bridge is exactly this. Many text-conditioned diffusion models, Perceiver-style architectures, and some retrieval-augmented LMs also rely on cross-attention.
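The only difference between self- and cross-attention is which array feeds the query projection. A minimal sketch in plain JAX, assuming a single head and no output projection (both simplifications; Flax's module uses several heads plus a final projection). The names attention, w_q, w_k, w_v, src, and tgt are hypothetical.

```python
import jax
import jax.numpy as jnp


def attention(x_q, x_kv, w_q, w_k, w_v):
    """Single-head scaled dot-product attention (no mask, no output projection)."""
    q = x_q @ w_q    # (T_q, d): queries come from x_q
    k = x_kv @ w_k   # (T_kv, d): keys come from x_kv
    v = x_kv @ w_v   # (T_kv, d): values come from x_kv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (T_q, T_kv)
    weights = jax.nn.softmax(scores, axis=-1)  # each query row sums to 1 over K positions
    return weights @ v                         # (T_q, d)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
d_in, d = 16, 8
w_q, w_k, w_v = (jax.random.normal(key, (d_in, d)) for key in keys[:3])
src = jax.random.normal(keys[3], (7, d_in))  # e.g. encoder output, T = 7
tgt = jax.random.normal(keys[4], (3, d_in))  # e.g. decoder states, T = 3

self_out = attention(src, src, w_q, w_k, w_v)   # self-attention: Q, K, V all from src
cross_out = attention(tgt, src, w_q, w_k, w_v)  # cross-attention: Q from tgt, K/V from src
print(self_out.shape, cross_out.shape)  # (7, 8) (3, 8)
```

The weights are shared between the two calls; only the row count of the output changes, and it always follows the query source.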
The two-input call
nn.MultiHeadDotProductAttention accepts two positional args:
attn = nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D)
out = attn(inputs_q, inputs_kv)
- inputs_q: shape (T_q, D_in). Source for Q.
- inputs_kv: shape (T_kv, D_in). Source for both K and V.
T_q and T_kv can differ; that’s the whole point, since encoder length and decoder length are usually different.
Self-attention is the special case attn(x, x); recent Flax versions also accept the shorthand attn(x).
Output shape
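The output has shape (T_q, D_out): one row per query position. D_out is the module's out_features, which defaults to the last dim of inputs_q (D_in above). K and V both have length T_kv, so each head's score matrix is (T_q, T_kv); weighting the T_kv value rows by those scores gives T_q rows per head, and a final output projection merges the heads into D_out features.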
Worked example: encoder→decoder
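import jax
import jax.numpy as jnp
import flax.linen as nn
rng = jax.random.PRNGKey(0)
encoder_out = jnp.ones((8, 64))  # Encoder output: (T_kv=8 source tokens, D=64); placeholder values.
decoder_h = jnp.ones((4, 64))    # Decoder hidden: (T_q=4 generated tokens, D=64); placeholder values.
cross = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=64)
params = cross.init(rng, decoder_h, encoder_out)
out = cross.apply(params, decoder_h, encoder_out)  # (4, 64)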
Each of the 4 decoder positions produces one 64-dim vector that summarizes “which 8 encoder positions matter to me, weighted”.
Common pitfalls
- Calling attn(x_kv, x_q): the first arg is the Q source. Order matters.
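- Assuming x_q and x_kv must share D_in. They can differ in T, and Flax's MHA also lets their last (feature) dims differ: the query, key, and value projections are separate Dense layers, each sized from its own input at init, and all three project to qkv_features (which must be divisible by num_heads). The output's last dim follows x_q, since out_features defaults to x_q.shape[-1]. The problem below simply uses one shared D_in.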
- Confusing T_q and T_kv: the score matrix is (T_q, T_kv), which is square only when the two lengths happen to match. Softmax is along the last axis (the K axis), as always.
Problem
Implement mha_cross(seed, x_q, x_kv, num_heads, qkv_features):
- Build nn.MultiHeadDotProductAttention(num_heads=H, qkv_features=D).
- Init with attn.init(rng, x_q, x_kv).
- Apply with attn.apply(params, x_q, x_kv).
- Return out.reshape(-1).
Inputs:
- seed: int.
- x_q: 2-D (T_q, D_in). Query source.
- x_kv: 2-D (T_kv, D_in). Key/value source.
- num_heads, qkv_features: ints.
Output: 1-D, the flattened (T_q, D_in) cross-attention output.