medium end_to_end

Cross-Attention Block

Implement a multi-head cross-attention block: the encoder-decoder fusion point of sequence-to-sequence transformers.

What is cross-attention?

In self-attention (e.g. multi-head-attention-block) every token attends to every other token in the same sequence. In cross-attention the queries come from one sequence (the decoder hidden states) while the keys and values come from a different sequence (the encoder output). This is how the decoder “reads” the encoder at every layer.

Where it appears:

  • Original Transformer (Vaswani et al., 2017) for machine translation: the decoder attends to the encoder’s final hidden states.
  • Modern multimodal models (e.g. Stable Diffusion) where image or text conditioning is injected via cross-attention.
  • Perceiver IO: cross-attention from a small latent array to a large input.

Pipeline (post-LN, GELU FFN)

Given decoder_x of shape (N, T_dec, d_model) and encoder_out of shape (N, T_enc, d_model):

  1. Q from decoder, K/V from encoder:
       Q = decoder_x @ w_q, shape (N, T_dec, d_model).
       K = encoder_out @ w_k, shape (N, T_enc, d_model).
       V = encoder_out @ w_v, shape (N, T_enc, d_model).
  2. Reshape to multi-head: split last dim into (num_heads, d_head), transpose to (N, num_heads, T_*, d_head).
  3. Scaled dot-product attention (NO causal mask; full attention over T_enc):
       scores = Q @ Kᵀ / sqrt(d_head), shape (N, num_heads, T_dec, T_enc).
       attn = softmax(scores, dim=-1).
       per_head = attn @ V, shape (N, num_heads, T_dec, d_head).
  4. Concat heads → (N, T_dec, d_model). attn_out = concat @ w_o.
  5. Post-LN residual (on decoder_x): x = LayerNorm(decoder_x + attn_out).
  6. GELU FFN: ffn_out = GELU(x @ w_mlp1) @ w_mlp2.
  7. Post-LN residual: x = LayerNorm(x + ffn_out).
  8. Return x β€” shape (N, T_dec, d_model).

Note: the residual is always over decoder_x (and then over x after the FFN), not over encoder_out. The encoder output contributes only through K and V.
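The eight steps above can be sketched in NumPy as follows. This is a minimal illustration, not a reference solution: the helper names (layer_norm, gelu_tanh, split_heads), the use of the biased (population) variance in LayerNorm, and the max-subtraction inside the softmax are assumptions made for the sketch and are not specified by the problem statement.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last axis; no learned scale or shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # assumption: biased variance
    return (x - mean) / np.sqrt(var + eps)


def gelu_tanh(t):
    # Tanh approximation of GELU.
    return 0.5 * t * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (t + 0.044715 * t**3)))


def split_heads(x, num_heads):
    # (N, T, d_model) -> (N, num_heads, T, d_head)
    n, t, d_model = x.shape
    return x.reshape(n, t, num_heads, d_model // num_heads).transpose(0, 2, 1, 3)


def cross_attention_block(decoder_x, encoder_out, w_q, w_k, w_v, w_o,
                          w_mlp1, w_mlp2, num_heads):
    n, t_dec, d_model = decoder_x.shape
    if d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    d_head = d_model // num_heads

    # Steps 1-2: queries from the decoder, keys/values from the encoder.
    q = split_heads(decoder_x @ w_q, num_heads)    # (N, H, T_dec, d_head)
    k = split_heads(encoder_out @ w_k, num_heads)  # (N, H, T_enc, d_head)
    v = split_heads(encoder_out @ w_v, num_heads)  # (N, H, T_enc, d_head)

    # Step 3: scaled dot-product attention, no causal mask.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # (N, H, T_dec, T_enc)
    scores = scores - scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    attn = weights / weights.sum(axis=-1, keepdims=True)
    per_head = attn @ v                                     # (N, H, T_dec, d_head)

    # Step 4: merge heads and apply the output projection.
    concat = per_head.transpose(0, 2, 1, 3).reshape(n, t_dec, d_model)
    attn_out = concat @ w_o

    # Steps 5-7: post-LN residual, GELU FFN, post-LN residual.
    x = layer_norm(decoder_x + attn_out)
    ffn_out = gelu_tanh(x @ w_mlp1) @ w_mlp2
    return layer_norm(x + ffn_out)  # Step 8: (N, T_dec, d_model)
```

Because there is no mask and no positional term inside the block, reordering the encoder positions should leave the output unchanged, which makes a convenient sanity check alongside the output shape.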

Hyperparameters

  • LN: eps = 1e-5, no learned γ/β.
  • GELU: standard tanh approximation, 0.5 * t * (1 + tanh(sqrt(2/π) * (t + 0.044715 * t³))).
  • d_ff = d_model (square weight matrices for w_mlp1, w_mlp2).
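To make the two hyperparameter choices concrete, here is a small standard-library sketch that evaluates them on scalars and on a single row. The function names are illustrative, gelu_exact (the erf form) is included only for comparison and is not what the problem asks for, and the biased variance in layer_norm is an assumption.

```python
import math


def gelu_tanh(t):
    # Tanh approximation given in the hyperparameter list.
    return 0.5 * t * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (t + 0.044715 * t ** 3)))


def gelu_exact(t):
    # Exact GELU via the error function, shown only for comparison.
    return 0.5 * t * (1.0 + math.erf(t / math.sqrt(2.0)))


def layer_norm(row, eps=1e-5):
    # Parameter-free LayerNorm over one feature vector (no gamma/beta).
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # assumption: biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in row]


if __name__ == "__main__":
    for t in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(t, gelu_tanh(t), gelu_exact(t))
    print(layer_norm([1.0, 2.0, 3.0, 4.0]))
```

The two GELU forms agree closely but not exactly, so a solution checked against tight tolerances should use the tanh form named above.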

Inputs / Output

  • decoder_x: shape (N, T_dec, d_model).
  • encoder_out: shape (N, T_enc, d_model).
  • w_q, w_k, w_v, w_o, w_mlp1, w_mlp2: shape (d_model, d_model).
  • num_heads: int, must divide d_model evenly.
  • Output: shape (N, T_dec, d_model).
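The shape contract above can be checked before any computation runs. The helper below is a hypothetical convenience, not part of the problem's required interface: it takes plain shape tuples, so it works with any array library, and it returns the expected output shape.

```python
WEIGHT_NAMES = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")


def check_cross_attention_shapes(decoder_shape, encoder_shape, weight_shapes, num_heads):
    """Validate input shapes and return the expected output shape.

    weight_shapes maps each name in WEIGHT_NAMES to that weight's shape tuple.
    """
    if len(decoder_shape) != 3 or len(encoder_shape) != 3:
        raise ValueError("decoder_x and encoder_out must both be 3-D")
    n, t_dec, d_model = decoder_shape
    n_enc, _t_enc, d_enc = encoder_shape  # T_enc may differ from T_dec
    if n_enc != n:
        raise ValueError("decoder_x and encoder_out must share the batch size N")
    if d_enc != d_model:
        raise ValueError("decoder_x and encoder_out must share d_model")
    for name in WEIGHT_NAMES:
        if tuple(weight_shapes[name]) != (d_model, d_model):
            raise ValueError(f"{name} must have shape (d_model, d_model)")
    if num_heads <= 0 or d_model % num_heads != 0:
        raise ValueError("num_heads must divide d_model evenly")
    return (n, t_dec, d_model)
```

For example, with decoder shape (2, 3, 8), encoder shape (2, 5, 8), six 8x8 weights, and num_heads = 2, the helper returns (2, 3, 8); with num_heads = 3 it raises ValueError because 3 does not divide 8.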

