
Train Encoder-Decoder Seq2Seq Step

Implement one full seq2seq training step with teacher forcing. This is the standard training procedure for the original encoder-decoder Transformer (Vaswani et al., 2017), which was introduced for and evaluated on machine translation.

Pipeline

1. Forward pass (same as encoder-decoder-forward-pass):

dec_input = tgt_ids[:, :T_tgt]       # tgt_ids is (N, T_tgt + 1) and begins with the start token
targets   = tgt_ids[:, 1:T_tgt+1]    # same sequence shifted left by one: the next token at each position
logits = encoder_decoder_forward(src_ids, dec_input, ...)  # (N, T_tgt, vocab_tgt)
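A minimal sketch of the input/target shift, assuming a toy batch in which token 1 is the start token and token 2 is the end token (both IDs, and the helper name `shift_for_teacher_forcing`, are hypothetical). It shows that position t of `dec_input` is paired with the token that follows it in `tgt_ids`.

```python
import numpy as np

def shift_for_teacher_forcing(tgt_ids):
    """Split (N, T_tgt + 1) target ids into decoder inputs and next-token targets."""
    T_tgt = tgt_ids.shape[1] - 1
    dec_input = tgt_ids[:, :T_tgt]        # begins with the start token, drops the final token
    targets = tgt_ids[:, 1:T_tgt + 1]     # shifted left by one: the token each position must predict
    return dec_input, targets

# Hypothetical toy batch: token 1 = start token, token 2 = end token.
tgt_ids = np.array([[1, 7, 8, 9, 2],
                    [1, 5, 6, 4, 2]])
dec_input, targets = shift_for_teacher_forcing(tgt_ids)
```

Both outputs have shape (N, T_tgt), so every decoder position contributes exactly one term to the loss.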

2. Loss (cross-entropy at every decoder position):

loss = mean over all (N, T_tgt) positions of -log softmax(logits)[n, t, targets[n, t]]
dlogits = (softmax(logits) - one_hot(targets)) / (N * T_tgt)
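The loss and its gradient can be sketched in NumPy as follows (the function name `ce_loss_and_grad` is illustrative, not part of the problem's API). A central finite difference on one logit checks that `(softmax - one_hot) / (N * T_tgt)` really is the gradient of the mean cross-entropy.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ce_loss_and_grad(logits, targets):
    """Mean cross-entropy over all (N, T) positions and its gradient w.r.t. logits."""
    N, T, V = logits.shape
    probs = softmax(logits)                                              # (N, T, V)
    p_target = np.take_along_axis(probs, targets[..., None], axis=-1)[..., 0]  # (N, T)
    loss = -np.log(p_target).mean()
    one_hot = np.eye(V)[targets]                                         # (N, T, V)
    dlogits = (probs - one_hot) / (N * T)
    return loss, dlogits

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))
targets = rng.integers(0, 5, size=(2, 3))
loss, dlogits = ce_loss_and_grad(logits, targets)

# Central finite difference on a single logit as a sanity check.
eps = 1e-6
bumped = logits.copy(); bumped[1, 2, 3] += eps
dipped = logits.copy(); dipped[1, 2, 3] -= eps
numeric = (ce_loss_and_grad(bumped, targets)[0] - ce_loss_and_grad(dipped, targets)[0]) / (2 * eps)
```

`numeric` and `dlogits[1, 2, 3]` should agree closely. Each row of `dlogits` also sums to zero over the vocabulary, because both the softmax row and the one-hot row sum to one.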

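3. Backward by hand (manual chain rule):

  • CE loss → dlogits at every (batch, decoder position)
  • LM head: with batch and time flattened, dw_head = x_dec_final.reshape(-1, d_model).T @ dlogits.reshape(-1, vocab_tgt); the gradient flowing back into the decoder is dx_dec_final = dlogits @ w_head.T
  • Decoder blocks in reverse (FFN → cross-attn → causal self-attn). Cross-attn backward yields BOTH dx_dec and denc_out. Every decoder block attends to the same encoder output, so sum the denc_out contributions from all decoder blocks.
  • Encoder blocks in reverse (FFN → bidirectional self-attn), starting from the summed denc_out.
  • Embeddings: denc_pos_embed and ddec_pos_embed (each summed over the batch), dsrc_emb (via index_add_), dtgt_emb (via index_add_).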
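The first and last backward bullets (LM head and embeddings) can be sketched in NumPy, assuming logits = x_dec_final @ w_head with no bias and additive learned positional embeddings of shape (max_len, d_model). `np.add.at` plays the role of PyTorch's `index_add_` here: it scatter-adds rows and accumulates correctly when the same token ID appears more than once. The function names are illustrative.

```python
import numpy as np

def lm_head_backward(x_dec_final, w_head, dlogits):
    """Gradients of logits = x_dec_final @ w_head (no bias assumed)."""
    d_model = x_dec_final.shape[-1]
    vocab = w_head.shape[1]
    # Flatten batch and time so the weight gradient is a single matrix product.
    dw_head = x_dec_final.reshape(-1, d_model).T @ dlogits.reshape(-1, vocab)  # (d_model, vocab)
    dx_dec_final = dlogits @ w_head.T                                        # (N, T, d_model)
    return dw_head, dx_dec_final

def embedding_backward(token_ids, dx0, vocab_size, max_len):
    """Gradients for x0 = emb[token_ids] + pos_embed[:T] (additive learned positions assumed)."""
    N, T, d_model = dx0.shape
    demb = np.zeros((vocab_size, d_model))
    np.add.at(demb, token_ids, dx0)     # scatter-add; repeated token IDs accumulate
    dpos = np.zeros((max_len, d_model))
    dpos[:T] = dx0.sum(axis=0)          # each position row is shared by every sequence in the batch
    return demb, dpos
```

The same `embedding_backward` sketch applies to both the source side (src_ids, enc_pos_embed) and the target side (dec_input, dec_pos_embed).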
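To show why cross-attention backward produces two gradients, here is a minimal single-head sketch for one example. It assumes projections Wq, Wk, Wv, Wo without biases, and it omits the residual connection and any normalization around the sublayer. Queries come from the decoder stream; keys and values come from the encoder output, so `dK` and `dV` both flow into `denc_out`.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_attn_forward(x_dec, enc_out, Wq, Wk, Wv, Wo):
    """Single-head cross-attention: x_dec is (T, d_model), enc_out is (S, d_model)."""
    scale = 1.0 / np.sqrt(Wq.shape[1])
    Q, K, V = x_dec @ Wq, enc_out @ Wk, enc_out @ Wv   # (T, dk), (S, dk), (S, dk)
    A = softmax(Q @ K.T * scale)                       # (T, S); no causal mask in cross-attention
    O = A @ V                                          # (T, dk)
    cache = (x_dec, enc_out, Q, K, V, A, O, scale)
    return O @ Wo, cache                               # (T, d_model)

def cross_attn_backward(dout, cache, Wq, Wk, Wv, Wo):
    x_dec, enc_out, Q, K, V, A, O, scale = cache
    dO = dout @ Wo.T
    dA = dO @ V.T                                      # (T, S)
    dV = A.T @ dO                                      # (S, dk)
    # Row-wise softmax Jacobian-vector product.
    dS = A * (dA - (dA * A).sum(axis=-1, keepdims=True))
    dQ = dS @ K * scale
    dK = dS.T @ Q * scale
    grads = {"Wq": x_dec.T @ dQ, "Wk": enc_out.T @ dK,
             "Wv": enc_out.T @ dV, "Wo": O.T @ dout}
    dx_dec = dQ @ Wq.T                                 # back into the decoder stream
    denc_out = dK @ Wk.T + dV @ Wv.T                   # back into the encoder output
    return dx_dec, denc_out, grads
```

In a full decoder stack, each block's `denc_out` would be added into one running total (for example `denc_total += denc_out`) before the encoder backward pass begins.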

4. SGD update on all parameters: param ← param - lr * dparam.
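A minimal sketch of the update step, assuming parameters and gradients are kept in two dicts with matching keys and that `lr` is a given learning rate (the problem does not state its value).

```python
import numpy as np

def sgd_update(params, grads, lr):
    """Plain SGD: p <- p - lr * dp for every parameter array, done in place."""
    for name, p in params.items():
        p -= lr * grads[name]   # in-place, so any other references to p see the update
    return params
```

Updating in place means that embedding tables and block weights shared by reference elsewhere stay consistent after the step.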

Teacher forcing

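During training the decoder receives the ground-truth target tokens as input, not its own predictions. This is called teacher forcing. Because every input token is known in advance, all T_tgt positions are computed in one parallel forward pass (the causal mask keeps position t from seeing later tokens), and an early wrong prediction cannot derail the rest of the sequence, which makes training faster and more stable. At inference time the target is unknown, so each predicted token is fed back as the next decoder input (autoregressive decoding).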
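For contrast with teacher forcing, a minimal greedy autoregressive decoding loop is sketched below. `next_token_logits` is a hypothetical stand-in for running the decoder on the current prefix (with the encoder output held fixed) and reading the logits at the last position; the toy version used here simply favours (last token + 1) mod vocab_size.

```python
import numpy as np

def greedy_decode(next_token_logits, start_id, end_id, max_len):
    """Autoregressive decoding: each predicted token becomes the next decoder input."""
    prefix = [start_id]
    for _ in range(max_len):
        logits = next_token_logits(prefix)   # scores for the token that follows `prefix`
        nxt = int(np.argmax(logits))
        prefix.append(nxt)
        if nxt == end_id:
            break
    return prefix[1:]                        # drop the start token

def toy_next_token_logits(prefix, vocab_size=6):
    # Hypothetical stand-in for a trained decoder.
    return np.eye(vocab_size)[(prefix[-1] + 1) % vocab_size]

tokens = greedy_decode(toy_next_token_logits, start_id=0, end_id=4, max_len=10)
```

Decoding needs one decoder call per output token, whereas the teacher-forced training step processes all T_tgt positions in a single call.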

Output

Returns a single flat tensor of all updated weights, concatenated in this order:

[src_emb_flat, tgt_emb_flat, enc_pos_embed_flat, dec_pos_embed_flat,
 enc_blocks_flat, dec_blocks_flat, w_head_flat]
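One way to produce that flat output is sketched below. The dict keys and the representation of each block as a list of arrays are assumptions made for illustration; the problem statement fixes only the top-level order, not the order of tensors inside each block.

```python
import numpy as np

def flatten_params(p):
    """Concatenate updated weights, in the order listed above, into one 1-D array."""
    order = [p["src_emb"], p["tgt_emb"], p["enc_pos_embed"], p["dec_pos_embed"]]
    order += [w for block in p["enc_blocks"] for w in block]   # assumed: each block is a list of arrays
    order += [w for block in p["dec_blocks"] for w in block]
    order.append(p["w_head"])
    return np.concatenate([w.ravel() for w in order])
```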

References

  • Vaswani et al., “Attention Is All You Need”, NeurIPS 2017.
  • Luong et al., “Effective Approaches to Attention-based Neural Machine Translation”, EMNLP 2015.
