NNX Cross-Attention
Why this matters
Cross-attention is the bridge between two sequences. It’s how the decoder of a T5 / BART encoder-decoder model lets every generated token look at every encoded source token. It’s how DALL-E and Stable Diffusion let image patches look at text embeddings. It’s how Perceiver IO lets a small set of latents query a much larger input.
The mechanism is a single change from self-attention: queries come
from one input (x_q), while keys and values come from another
(x_kv). That introduces two sequence lengths, T_q and T_k, which
need not match. Everything else is identical.
What changes vs. self-attention
Self-attention:
q = q_proj(x); k = k_proj(x); v = v_proj(x)
Cross-attention:
q = q_proj(x_q) # queries from one source
k = k_proj(x_kv) # keys from another
v = v_proj(x_kv) # values from the same other source
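Here is a minimal single-head sketch of that one change in plain jax.numpy. The weight matrices Wq, Wk and Wv are hypothetical random stand-ins for the nnx.Linear projections, with biases omitted. Self-attention falls out as the special case x_kv = x_q.

```python
import jax
import jax.numpy as jnp

def cross_attention(x_q, x_kv, Wq, Wk, Wv):
    """Single-head cross-attention: Q from x_q, K and V from x_kv."""
    q = x_q @ Wq    # (T_q, d)
    k = x_kv @ Wk   # (T_k, d)
    v = x_kv @ Wv   # (T_k, d), same positions as k
    scores = q @ k.T / jnp.sqrt(q.shape[-1])    # (T_q, T_k)
    return jax.nn.softmax(scores, axis=-1) @ v  # (T_q, d)

def self_attention(x, Wq, Wk, Wv):
    # Self-attention is the special case where both inputs are the same sequence.
    return cross_attention(x, x, Wq, Wk, Wv)

d = 8
keys = jax.random.split(jax.random.PRNGKey(0), 5)
# Hypothetical random projection weights standing in for nnx.Linear layers.
Wq, Wk, Wv = (jax.random.normal(keys[i], (d, d)) for i in range(3))
x_q = jax.random.normal(keys[3], (3, d))   # 3 query positions
x_kv = jax.random.normal(keys[4], (5, d))  # 5 source positions

print(cross_attention(x_q, x_kv, Wq, Wk, Wv).shape)  # (3, 8)
print(self_attention(x_kv, Wq, Wk, Wv).shape)        # (5, 8)
```

The cross-attention output has one row per query (3), even though the source has 5 positions.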
K and V both come from x_kv (always together — they index the
same positions). Q comes from x_q. Sequence lengths can differ,
so:
- q has shape (T_q, d_model) -> (H, T_q, head_dim).
- k, v have shape (T_k, d_model) -> (H, T_k, head_dim).
- scores has shape (H, T_q, T_k).
- weights has shape (H, T_q, T_k).
- per_head has shape (H, T_q, head_dim).
- Output has shape (T_q, d_model).
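The shape list can be checked mechanically. This sketch treats the learned projections as identity, an assumption made only to isolate the reshapes. It then traces each intermediate for T_q = 3, T_k = 7, d_model = 16 and H = 4.

```python
import jax
import jax.numpy as jnp

T_q, T_k, d_model, H = 3, 7, 16, 4
head_dim = d_model // H

key_q, key_kv = jax.random.split(jax.random.PRNGKey(0))
x_q = jax.random.normal(key_q, (T_q, d_model))
x_kv = jax.random.normal(key_kv, (T_k, d_model))

def split_heads(x, T):
    # (T, d_model) -> (T, H, head_dim) -> (H, T, head_dim)
    return x.reshape(T, H, head_dim).transpose(1, 0, 2)

# Projections omitted (identity) so only the shape bookkeeping is shown.
q = split_heads(x_q, T_q)
k = split_heads(x_kv, T_k)
v = split_heads(x_kv, T_k)
scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(head_dim)
weights = jax.nn.softmax(scores, axis=-1)
per_head = weights @ v
out = per_head.transpose(1, 0, 2).reshape(T_q, d_model)

for name, arr in [("q", q), ("k", k), ("v", v), ("scores", scores),
                  ("weights", weights), ("per_head", per_head), ("out", out)]:
    print(name, arr.shape)
```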
The output sequence length matches T_q, NOT T_k. Cross-attention
“summarizes” the source (length T_k) into one vector per query
position (length T_q).
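The "summary" view can be seen directly. Each row of the weights is a probability distribution over the T_k source positions, so each output row is a weighted average of value rows. The sketch below uses a single head and no projections, and takes keys and values straight from the source for simplicity. It varies T_k and shows that the output length never moves. With T_k = 1 the softmax is exactly 1, so every query simply receives that one value row.

```python
import jax
import jax.numpy as jnp

def sdpa(q, k, v):
    # q: (T_q, d); k, v: (T_k, d). Returns output (T_q, d) and weights (T_q, T_k).
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return weights @ v, weights

d, T_q = 4, 2
q = jax.random.normal(jax.random.PRNGKey(0), (T_q, d))
for T_k in (1, 5, 50):
    src = jax.random.normal(jax.random.PRNGKey(T_k), (T_k, d))
    out, w = sdpa(q, src, src)  # keys and values both taken from the source
    # Output length follows T_q; each weight row sums to 1 over the T_k positions.
    print(T_k, out.shape, w.shape, bool(jnp.allclose(w.sum(axis=-1), 1.0)))
```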
Worked sketch
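import jax
import jax.numpy as jnp
from flax import nnx

class CrossMHA(nnx.Module):
    # Constructor signature (d_model, num_heads, rngs) is assumed for this sketch.
    def __init__(self, d_model, num_heads, *, rngs: nnx.Rngs):
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # assumes d_model % num_heads == 0
        # Same four nnx.Linear projections as self-attention.
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x_q, x_kv):
        Tq, _ = x_q.shape   # query length
        Tk, _ = x_kv.shape  # source length; may differ from Tq
        H, Dh = self.num_heads, self.head_dim
        # (T, d_model) -> (T, H, Dh) -> (H, T, Dh)
        q = self.q_proj(x_q).reshape(Tq, H, Dh).transpose(1, 0, 2)
        k = self.k_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        v = self.v_proj(x_kv).reshape(Tk, H, Dh).transpose(1, 0, 2)
        scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(Dh)  # (H, Tq, Tk)
        weights = jax.nn.softmax(scores, axis=-1)  # normalize over the Tk source positions
        per_head = jnp.matmul(weights, v)  # (H, Tq, Dh)
        concat = per_head.transpose(1, 0, 2).reshape(Tq, H * Dh)  # (Tq, d_model)
        return self.out_proj(concat)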
The reshapes now use two lengths (Tq for Q, Tk for K and V) instead of one. The rest is the same code path.
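Shape bugs usually hide in the transposes. One alternative, offered here as a suggestion rather than the original's code, is jnp.einsum, which names the query axis (q) and key axis (k) explicitly. The check below compares it against the matmul/transpose path on random per-head tensors shaped as in the sketch.

```python
import jax
import jax.numpy as jnp

def heads_matmul(q, k, v):
    # q: (H, Tq, Dh), k and v: (H, Tk, Dh), as in the worked sketch.
    scores = jnp.matmul(q, k.transpose(0, 2, 1)) / jnp.sqrt(q.shape[-1])
    return jnp.matmul(jax.nn.softmax(scores, axis=-1), v)

def heads_einsum(q, k, v):
    # Same computation with named axes: h=head, q=query pos, k=key pos, d=head dim.
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

H, Tq, Tk, Dh = 2, 3, 5, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (H, Tq, Dh))
k = jax.random.normal(k2, (H, Tk, Dh))
v = jax.random.normal(k3, (H, Tk, Dh))
print(bool(jnp.allclose(heads_matmul(q, k, v), heads_einsum(q, k, v), atol=1e-4)))
```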
No causal mask in cross-attention
A causal mask in cross-attention makes no sense — there is no
sequential dependency between x_q and x_kv. (T5’s decoder uses
a causal mask for self-attention over decoder tokens, then
unmasked cross-attention into the encoder output. The cross
layer always sees the full source.)
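The T5-style pattern described above can be sketched with single-head attention weights and no projections, which is a simplification. The decoder's self-attention gets a square (T_dec, T_dec) causal mask. Cross-attention into the encoder output is rectangular (T_dec, T_enc) and unmasked. The helper name attention_weights is hypothetical.

```python
import jax
import jax.numpy as jnp

def attention_weights(q, k, mask=None):
    # Single head, no projections: softmax(q k^T / sqrt(d)) with an optional boolean mask.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, -jnp.inf)  # masked entries get zero weight
    return jax.nn.softmax(scores, axis=-1)

T_dec, T_enc, d = 4, 6, 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
dec = jax.random.normal(k1, (T_dec, d))  # decoder token states
enc = jax.random.normal(k2, (T_enc, d))  # encoder output

# Decoder self-attention: square causal mask over decoder positions.
causal = jnp.tril(jnp.ones((T_dec, T_dec), dtype=bool))
w_self = attention_weights(dec, dec, mask=causal)  # (T_dec, T_dec)
# Cross-attention into the encoder: rectangular and unmasked.
w_cross = attention_weights(dec, enc)              # (T_dec, T_enc)

print(w_self.shape, w_cross.shape)  # (4, 4) (4, 6)
print(bool((w_cross > 0).all()))    # True: every query sees every source position
```

A triangular mask would not even match the (T_dec, T_enc) score shape unless T_enc happened to equal T_dec.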
Padding masks (which ignore pad tokens in the source) are common in
practice but skipped here; adding one would be a jnp.where over the
T_k axis of the scores, applied before the softmax.
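Here is a sketch of that padding mask. It assumes a boolean vector kv_valid of length T_k, a hypothetical name that is True for real tokens and False for pads. The mask broadcasts over heads and query rows, and pad columns are pushed to a large negative value before the softmax.

```python
import jax
import jax.numpy as jnp

def masked_cross_weights(scores, kv_valid):
    """scores: (H, T_q, T_k); kv_valid: (T_k,) bool, False at pad positions."""
    # Broadcast the (T_k,) mask to (1, 1, T_k), then replace pad columns
    # with a large negative value so they get (numerically) zero weight.
    scores = jnp.where(kv_valid[None, None, :], scores, -1e9)
    return jax.nn.softmax(scores, axis=-1)

H, T_q, T_k = 2, 3, 5
scores = jax.random.normal(jax.random.PRNGKey(0), (H, T_q, T_k))
kv_valid = jnp.array([True, True, True, False, False])  # last two source tokens are padding
w = masked_cross_weights(scores, kv_valid)
print(w[..., 3:].max())  # weight on the pad columns
print(w.sum(axis=-1))    # each row still sums to 1 over the real tokens
```

A large finite value such as -1e9 is used instead of -inf. That way, a source consisting entirely of padding yields uniform weights rather than NaNs.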
Common pitfalls
- Using one sequence-length variable. Writing (T, _) = x.shape and then .reshape(T, H, Dh) raises a shape error whenever x_q and x_kv have different lengths. Capture Tq and Tk separately.
- Putting K/V on x_q. Then it’s self-attention with extra steps. K and V must come from x_kv.
- Output sequence length. Cross-attention outputs (T_q, ...), not (T_k, ...). The query sequence drives the output length.
- Forgetting that d_model is shared. Cross-attention here assumes x_q.shape[-1] == x_kv.shape[-1] == d_model. If they differ (e.g., text→image), the projections need separate input and output dims.
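When the feature widths differ, the fix is to give K and V projections an input dim of d_kv while all three project into a common d_model. Here is a single-head sketch with hypothetical sizes d_q = 6 and d_kv = 10, using random matrices in place of layers. Mapping the output back to d_q is an assumed design choice so the result can be added to x_q. In NNX terms this would correspond to nnx.Linear(d_kv, d_model, rngs=rngs) for k_proj and v_proj.

```python
import jax
import jax.numpy as jnp

d_q, d_kv, d_model = 6, 10, 8  # hypothetical: query and source feature widths differ
T_q, T_k = 3, 5
keys = jax.random.split(jax.random.PRNGKey(0), 6)
Wq = jax.random.normal(keys[0], (d_q, d_model))   # in-dim d_q
Wk = jax.random.normal(keys[1], (d_kv, d_model))  # in-dim d_kv
Wv = jax.random.normal(keys[2], (d_kv, d_model))  # in-dim d_kv
Wo = jax.random.normal(keys[3], (d_model, d_q))   # assumed: map back to the query width

x_q = jax.random.normal(keys[4], (T_q, d_q))
x_kv = jax.random.normal(keys[5], (T_k, d_kv))

q, k, v = x_q @ Wq, x_kv @ Wk, x_kv @ Wv                        # all now (T, d_model)
weights = jax.nn.softmax(q @ k.T / jnp.sqrt(d_model), axis=-1)  # (T_q, T_k)
out = (weights @ v) @ Wo                                        # (T_q, d_q)
print(out.shape)                                                # (3, 6)
```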
Problem
Write mha_cross(seed, x_q, x_kv, num_heads, d_model):
- Define CrossMHA(nnx.Module) with the same four nnx.Linear(d_model, d_model) projections as in pos 32.
- __call__(x_q, x_kv): project Q from x_q and K, V from x_kv; reshape with Tq for Q and Tk for K/V; apply scaled dot-product attention (SDPA); concat back to (Tq, d_model); apply out_proj.
- Cast num_heads and d_model to int. Build nnx.Rngs(int(seed)).
- Return the output flattened: (Tq * d_model,).
Inputs:
- seed: int (passed as float).
- x_q: 2-D (T_q, d_model).
- x_kv: 2-D (T_k, d_model). T_q and T_k may differ.
- num_heads, d_model: ints (passed as floats).
Output: 1-D flattened (T_q * d_model,).