hard end_to_end

Encoder-Decoder Transformer Forward Pass

Implement the full encoder-decoder transformer forward pass — the original architecture from Vaswani et al. 2017 that powers machine translation, summarization, and other sequence-to-sequence tasks.

Architecture overview

Unlike decoder-only models (GPT) or encoder-only models (BERT), the original Transformer has two separate stacks:

  • Encoder: processes the source sequence (e.g. the French sentence) using bidirectional self-attention — every source token can attend to every other source token.
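  • Decoder: generates the target sequence (e.g. the English translation) one token at a time at inference. During training, teacher forcing feeds the ground-truth target in all at once, and the causal mask keeps each position from seeing later tokens. Each decoder block has three sub-layers: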
    1. Causal self-attention — each target token attends only to itself and earlier target tokens (lower-triangular mask).
    2. Cross-attention — queries come from the decoder hidden state; keys and values come from the encoder output. This is how the decoder “reads” the source.
    3. FFN — position-wise GELU feed-forward.
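
The two attention variants above differ only in where keys and values come from and whether a mask is applied, so one routine can cover both. The following is a minimal NumPy sketch, not the reference solution. The function name mha, the n_heads argument, the 1/sqrt(d_head) scaling, and the absence of biases are assumptions; the problem statement does not fix the number of heads.

```python
import numpy as np

def mha(x_q, x_kv, w_q, w_k, w_v, w_o, n_heads, causal=False):
    """Multi-head attention (hypothetical helper).

    x_q:  (N, T_q, d)  source of the queries
    x_kv: (N, T_kv, d) source of the keys and values
    Self-attention passes the same tensor twice; cross-attention passes
    the decoder state as x_q and the encoder output as x_kv.
    """
    N, T_q, d = x_q.shape
    T_kv = x_kv.shape[1]
    d_head = d // n_heads  # assumes d is divisible by n_heads

    def split(t, T):
        # (N, T, d) -> (N, n_heads, T, d_head)
        return t.reshape(N, T, n_heads, d_head).transpose(0, 2, 1, 3)

    q = split(x_q @ w_q, T_q)
    k = split(x_kv @ w_k, T_kv)
    v = split(x_kv @ w_v, T_kv)

    # Scaled dot-product scores: (N, n_heads, T_q, T_kv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    if causal:
        # Lower-triangular mask: position i sees positions 0..i only.
        mask = np.tril(np.ones((T_q, T_kv), dtype=bool))
        scores = np.where(mask, scores, -np.inf)

    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Merge heads back to (N, T_q, d) and apply the output projection.
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(N, T_q, d)
    return out @ w_o
```

Under these assumptions, encoder self-attention would be mha(x, x, ...), decoder self-attention mha(x, x, ..., causal=True), and cross-attention mha(x_dec, enc_out, ...).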

Pipeline

  1. Encoder side (enc_blocks, shape (E, 6, d, d)): x_enc = src_emb[src_ids] + enc_pos_embed. For each encoder block, apply post-LN bidirectional MHA, then post-LN FFN. Each block holds 6 weight matrices: [w_q, w_k, w_v, w_o, w_mlp1, w_mlp2].

  2. Decoder side (dec_blocks, shape (D, 12, d, d)): x_dec = tgt_emb[tgt_ids] + dec_pos_embed. For each decoder block, apply three post-LN sub-layers in order:

    • Slots 0–3: causal self-attn [w_q_self, w_k_self, w_v_self, w_o_self]
    • Slots 4–7: cross-attn [w_q_cross, w_k_cross, w_v_cross, w_o_cross]
    • Slots 8–9: FFN [w_mlp1, w_mlp2]
    • Slots 10–11: zero-padded / unused (reserved for future extensions; ignore them — shape is kept uniform across all blocks).
  3. LM head: logits = x_dec @ w_head, shape (N, T_tgt, vocab_tgt).
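
The data layout in the pipeline above can be sketched as follows. This covers only the bookkeeping (slot unpacking, embedding lookup, LM head), not the blocks themselves. The helper names unpack, embed, and lm_head are hypothetical, and the positional embedding is assumed to have shape (T_max, d) with T_max at least the sequence length, which the problem statement does not spell out.

```python
import numpy as np

# Slot order taken from the pipeline description.
ENC_SLOTS = ("w_q", "w_k", "w_v", "w_o", "w_mlp1", "w_mlp2")
DEC_SLOTS = (
    "w_q_self", "w_k_self", "w_v_self", "w_o_self",      # slots 0-3
    "w_q_cross", "w_k_cross", "w_v_cross", "w_o_cross",  # slots 4-7
    "w_mlp1", "w_mlp2",                                  # slots 8-9
)  # slots 10-11 are padding and are never read

def unpack(block, names):
    """Map one (num_slots, d, d) block to named (d, d) matrices."""
    return {name: block[i] for i, name in enumerate(names)}

def embed(emb, ids, pos_embed):
    """Token lookup plus positions: (N, T) integer ids -> (N, T, d).

    Assumes pos_embed has shape (T_max, d) with T_max >= T.
    """
    return emb[ids] + pos_embed[: ids.shape[1]]

def lm_head(x_dec, w_head):
    """Project decoder states (N, T_tgt, d) to logits (N, T_tgt, vocab_tgt)."""
    return x_dec @ w_head
```

For example, unpack(dec_blocks[0], DEC_SLOTS)["w_q_cross"] would return slot 4 of the first decoder block.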

Post-LN convention

Every sub-layer uses x = LN(x + sub_out). Apply LN twice per encoder block and three times per decoder block.

  • LN eps = 1e-5, no learned γ/β.
  • GELU: 0.5 * t * (1 + tanh(sqrt(2/π) * (t + 0.044715 * t³))).
  • d_ff = d_model — all weight matrices are square.
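
The conventions above translate into a few small helpers. This is a sketch under stated assumptions: normalization runs over the last (feature) axis with the biased variance, which is the usual LayerNorm definition but is not stated explicitly here, and the names layer_norm, gelu, and ffn_sublayer are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, no learned scale or shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance (assumption)
    return (x - mean) / np.sqrt(var + eps)

def gelu(t):
    """Tanh approximation of GELU, as given in the problem statement."""
    return 0.5 * t * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (t + 0.044715 * t**3)))

def ffn_sublayer(x, w_mlp1, w_mlp2):
    """Post-LN feed-forward sub-layer: x = LN(x + FFN(x))."""
    sub_out = gelu(x @ w_mlp1) @ w_mlp2
    return layer_norm(x + sub_out)
```

The attention sub-layers follow the same pattern, with the attention output taking the place of sub_out.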

Reference

Vaswani et al. “Attention Is All You Need”, NeurIPS 2017.

Hints

seq2seq transformer encoder-decoder
