Cross-Attention Block
Implement a multi-head cross-attention block, the encoder-decoder fusion point of sequence-to-sequence transformers.
What is cross-attention?
In self-attention (e.g. multi-head-attention-block) every token attends to
all tokens in the same sequence. In cross-attention the queries
come from one sequence (the decoder hidden states) while the keys and values
come from a different sequence (the encoder output). This is how the decoder
"reads" the encoder at every layer.
Where it appears:
- Original Transformer (Vaswani et al., 2017) for machine translation: the decoder attends to the encoder's final hidden states.
- Modern multimodal models (e.g. Stable Diffusion), where image or text conditioning is injected via cross-attention.
- Perceiver IO: cross-attention from a small latent array to a large input.
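To make the self- versus cross-attention distinction concrete, here is a minimal single-head NumPy sketch. It omits the learned projections, and the sizes and the helper name attention_weights are illustrative assumptions, not part of the problem statement. The only difference between the two calls is where the keys come from, which shows up in the shape of the weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 5, 4  # illustrative sizes (assumed)

decoder_h = rng.standard_normal((T_dec, d))
encoder_h = rng.standard_normal((T_enc, d))

def attention_weights(queries, keys):
    """Row-wise softmax of scaled dot-product scores (hypothetical helper)."""
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

# Self-attention: queries and keys from the same sequence.
self_attn = attention_weights(decoder_h, decoder_h)
# Cross-attention: queries from the decoder, keys from the encoder.
cross_attn = attention_weights(decoder_h, encoder_h)

print(self_attn.shape, cross_attn.shape)  # (3, 3) (3, 5)
```

Each row of cross_attn is a distribution over encoder positions for one decoder position.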
Pipeline (post-LN, GELU FFN)
Given decoder_x of shape (N, T_dec, d_model) and encoder_out of shape
(N, T_enc, d_model):
- Q from decoder, K/V from encoder:
  Q = decoder_x @ w_q, shape (N, T_dec, d_model).
  K = encoder_out @ w_k, shape (N, T_enc, d_model).
  V = encoder_out @ w_v, shape (N, T_enc, d_model).
- Reshape to multi-head: split the last dim into (num_heads, d_head), where d_head = d_model / num_heads, then transpose to (N, num_heads, T_*, d_head).
- Scaled dot-product (NO causal mask; full attention over T_enc):
  scores = Q @ Kᵀ / sqrt(d_head), shape (N, num_heads, T_dec, T_enc).
  attn = softmax(scores, dim=-1).
  per_head = attn @ V, shape (N, num_heads, T_dec, d_head).
- Concat heads back to (N, T_dec, d_model), then attn_out = concat @ w_o.
- Post-LN residual (on decoder_x): x = LayerNorm(decoder_x + attn_out).
- GELU FFN: ffn_out = GELU(x @ w_mlp1) @ w_mlp2.
- Post-LN residual: x = LayerNorm(x + ffn_out).
- Return x, shape (N, T_dec, d_model).
Note: the residual is always over decoder_x (and then over x after the
FFN), not over encoder_out. The encoder output contributes only through
K and V.
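The pipeline above can be sketched end to end as follows. This is one possible NumPy rendering, not the reference solution: the choice of NumPy, the helper names (layer_norm, gelu, split_heads), the use of biased variance in the layer norm, and the max-subtraction inside the softmax are all assumptions made for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis; no learned scale or shift.
    # Biased (population) variance is assumed here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(t):
    # tanh approximation of GELU.
    return 0.5 * t * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (t + 0.044715 * t**3)))

def split_heads(x, num_heads):
    # (N, T, d_model) -> (N, num_heads, T, d_head)
    n, t, d_model = x.shape
    return x.reshape(n, t, num_heads, d_model // num_heads).transpose(0, 2, 1, 3)

def cross_attention_block(decoder_x, encoder_out,
                          w_q, w_k, w_v, w_o, w_mlp1, w_mlp2, num_heads):
    n, t_dec, d_model = decoder_x.shape
    assert d_model % num_heads == 0, "num_heads must divide d_model"
    d_head = d_model // num_heads

    # Queries from the decoder, keys and values from the encoder.
    q = split_heads(decoder_x @ w_q, num_heads)
    k = split_heads(encoder_out @ w_k, num_heads)
    v = split_heads(encoder_out @ w_v, num_heads)

    # Scaled dot-product, no mask: (N, num_heads, T_dec, T_enc).
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability only
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    per_head = attn @ v  # (N, num_heads, T_dec, d_head)

    # Merge heads back to (N, T_dec, d_model) and project.
    concat = per_head.transpose(0, 2, 1, 3).reshape(n, t_dec, d_model)
    attn_out = concat @ w_o

    # Post-LN residuals: first over decoder_x, then over x.
    x = layer_norm(decoder_x + attn_out)
    ffn_out = gelu(x @ w_mlp1) @ w_mlp2
    return layer_norm(x + ffn_out)

# Usage with small random inputs (sizes are arbitrary).
rng = np.random.default_rng(0)
N, T_dec, T_enc, d_model, H = 2, 3, 5, 8, 2
dec = rng.standard_normal((N, T_dec, d_model))
enc = rng.standard_normal((N, T_enc, d_model))
ws = [0.1 * rng.standard_normal((d_model, d_model)) for _ in range(6)]
out = cross_attention_block(dec, enc, *ws, num_heads=H)
print(out.shape)  # (2, 3, 8)
```

Because there is no mask and no positional term inside the block, reordering the encoder tokens should leave the output unchanged, which makes a convenient sanity check.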
Hyperparameters
- LN: eps = 1e-5, no learned γ/β.
- GELU: standard approximate (tanh) formula, 0.5 * t * (1 + tanh(sqrt(2/π) * (t + 0.044715 * t³))).
- d_ff = d_model (square weight matrices for w_mlp1, w_mlp2).
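The two hyperparameter choices above are easy to check in isolation. The sketch below is an illustration under assumptions: the function names are hypothetical, the layer norm uses biased variance over the last axis, and the exact erf-based GELU is included only as a comparison point, since the problem asks for the tanh form.

```python
import math
import numpy as np

def layer_norm(x, eps=1e-5):
    # No learned gamma/beta: plain normalization over the last axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance (assumed)
    return (x - mean) / np.sqrt(var + eps)

def gelu_tanh(t):
    # The approximate formula from the hyperparameter list.
    return 0.5 * t * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (t + 0.044715 * t**3)))

def gelu_exact(t):
    # Exact GELU, t * Phi(t), shown for comparison only.
    erf = np.vectorize(math.erf)
    return 0.5 * t * (1.0 + erf(t / math.sqrt(2.0)))

t = np.linspace(-4.0, 4.0, 81)
max_gap = np.max(np.abs(gelu_tanh(t) - gelu_exact(t)))
print("max |tanh form - exact| on [-4, 4]:", max_gap)

row = layer_norm(np.array([[1.0, 2.0, 3.0, 4.0]]))
print("normalized row:", row, "mean:", row.mean(), "var:", row.var())
```

The normalized row has mean zero and variance slightly below one, because eps is added inside the square root.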
Inputs / Output
- decoder_x: shape (N, T_dec, d_model).
- encoder_out: shape (N, T_enc, d_model).
- w_q, w_k, w_v, w_o, w_mlp1, w_mlp2: shape (d_model, d_model).
- num_heads: int, must divide d_model evenly.
- Output: shape (N, T_dec, d_model).
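The shape contract above can be checked before any computation. The helper below, check_inputs, is a hypothetical name and not part of the problem's required interface; it is a sketch that assumes NumPy arrays and a dict of weights keyed by name.

```python
import numpy as np

def check_inputs(decoder_x, encoder_out, weights, num_heads):
    """Validate shapes against the spec and return d_head (hypothetical helper)."""
    if decoder_x.ndim != 3 or encoder_out.ndim != 3:
        raise ValueError("decoder_x and encoder_out must be rank-3 arrays")
    n, _, d_model = decoder_x.shape
    # T_dec and T_enc may differ; N and d_model must match.
    if encoder_out.shape[0] != n or encoder_out.shape[2] != d_model:
        raise ValueError("encoder_out must share N and d_model with decoder_x")
    for name, w in weights.items():
        if w.shape != (d_model, d_model):
            raise ValueError(f"{name} must have shape ({d_model}, {d_model})")
    if d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    return d_model // num_heads

d_model = 8
names = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")
weights = {name: np.zeros((d_model, d_model)) for name in names}
d_head = check_inputs(np.zeros((2, 3, d_model)), np.zeros((2, 5, d_model)),
                      weights, num_heads=2)
print(d_head)  # 4
```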