Train Encoder-Decoder Seq2Seq Step
Implement one full seq2seq training step with teacher forcing: the standard training procedure for the original encoder-decoder Transformer (Vaswani et al., 2017), which was introduced for machine translation.
Pipeline
1. Forward pass (same as in encoder-decoder-forward-pass), with tgt_ids of shape (N, T_tgt + 1):
dec_input = tgt_ids[:, :T_tgt]     # decoder input: starts with the start token, drops the last token
targets = tgt_ids[:, 1:T_tgt+1]    # targets: shifted left by one, so position t predicts token t+1
logits = encoder_decoder_forward(src_ids, dec_input, ...)   # (N, T_tgt, vocab_tgt)
2. Loss — cross-entropy at every decoder position:
loss = mean CE over all (N, T_tgt) positions
dlogits = (softmax(logits) - one_hot(targets)) / (N * T_tgt)
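3. Backward by hand (manual chain rule):
- CE loss → dlogits at every (batch, decoder position).
- LM head: dw_head = x_dec_final.T @ dlogits, with x_dec_final and dlogits flattened to (N * T_tgt, d_model) and (N * T_tgt, vocab_tgt). The gradient flowing back into the decoder is dx_dec_final = dlogits @ w_head.T.
- Decoder blocks in reverse (FFN → cross-attn → causal self-attn). Cross-attn backward yields BOTH dx_dec and denc_out. Sum the denc_out contributions from every decoder block, since all of them attend to the same encoder output.
- Encoder blocks in reverse (FFN → bidirectional self-attn), starting from the summed denc_out.
- Embeddings: denc_pos_embed, ddec_pos_embed, dsrc_emb (via index_add_), dtgt_emb (via index_add_, so gradients for repeated token ids accumulate).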
4. SGD update on all parameters.
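The four steps can be exercised end to end on a deliberately tiny model. The sketch below is a hypothetical simplification, not the reference solution: it uses NumPy instead of PyTorch and treats the encoder and every decoder block as the identity, so what remains is the shift, the CE gradient, the LM-head backward, the embedding scatter-add (np.add.at standing in for index_add_), and the SGD update. The names toy_train_step and softmax are illustrative.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def toy_train_step(tgt_ids, tgt_emb, dec_pos_embed, w_head, lr):
    """One teacher-forced SGD step for a toy model with no attention blocks.

    Hypothetical simplification: the encoder and every decoder block are the
    identity, so x_dec_final = tgt_emb[dec_input] + dec_pos_embed[:T_tgt].
    Returns (loss, new_tgt_emb, new_dec_pos_embed, new_w_head).
    """
    N, T_tgt = tgt_ids.shape[0], tgt_ids.shape[1] - 1

    # 1. Shift: decoder input keeps the start token, targets are the next tokens.
    dec_input = tgt_ids[:, :T_tgt]
    targets = tgt_ids[:, 1:T_tgt + 1]
    x = tgt_emb[dec_input] + dec_pos_embed[:T_tgt]     # (N, T_tgt, d_model)
    logits = x @ w_head                                 # (N, T_tgt, vocab_tgt)

    # 2. Mean cross-entropy over all N * T_tgt positions and its gradient.
    probs = softmax(logits)
    n_idx, t_idx = np.meshgrid(np.arange(N), np.arange(T_tgt), indexing="ij")
    loss = -np.log(probs[n_idx, t_idx, targets]).mean()
    one_hot = np.zeros_like(probs)
    one_hot[n_idx, t_idx, targets] = 1.0
    dlogits = (probs - one_hot) / (N * T_tgt)

    # 3. Backward. LM head: flatten (N, T_tgt) so both matmuls are 2-D.
    d_model, vocab = w_head.shape
    x2 = x.reshape(-1, d_model)
    dlogits2 = dlogits.reshape(-1, vocab)
    dw_head = x2.T @ dlogits2                           # (d_model, vocab_tgt)
    dx = (dlogits2 @ w_head.T).reshape(N, T_tgt, d_model)

    # Positional embedding: each row is shared across the batch, so sum over N.
    ddec_pos = np.zeros_like(dec_pos_embed)
    ddec_pos[:T_tgt] = dx.sum(axis=0)
    # Token embedding: scatter-add so repeated token ids accumulate
    # (np.add.at plays the role of torch's index_add_ here).
    dtgt_emb = np.zeros_like(tgt_emb)
    np.add.at(dtgt_emb, dec_input.ravel(), dx.reshape(-1, d_model))

    # 4. Plain SGD; returns new arrays and leaves the inputs untouched.
    return (loss,
            tgt_emb - lr * dtgt_emb,
            dec_pos_embed - lr * ddec_pos,
            w_head - lr * dw_head)
```

A central finite-difference check on a few parameter entries is a cheap way to validate a hand-written backward pass like this one before the attention and FFN blocks are added back.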
Teacher forcing
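During training, the decoder receives the ground-truth target prefix as input rather than its own predictions. This is called teacher forcing. It keeps training stable, because an early wrong prediction cannot corrupt the inputs at later positions, and, combined with the causal mask, it lets all T_tgt positions be trained in a single parallel forward pass. At inference time, the decoder's own output is fed back as the next input (autoregressive decoding).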
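A minimal sketch contrasting the two regimes. It assumes a hypothetical callable next_token_logits(prefix) that stands in for a full encoder-decoder forward pass on a fixed source sentence and returns logits for the next position; greedy argmax is used only for simplicity (beam search is a common alternative). The names teacher_forced_io, greedy_decode and toy_model are illustrative.

```python
import numpy as np


def teacher_forced_io(tgt_ids):
    # Training: decoder input is the ground-truth prefix, target is the next token.
    return tgt_ids[:, :-1], tgt_ids[:, 1:]


def greedy_decode(next_token_logits, bos_id, eos_id, max_new_tokens):
    """Inference: each predicted token is appended and fed back as input.

    next_token_logits is a hypothetical callable: list of ids -> 1-D array of
    vocabulary logits for the next position.
    """
    prefix = [bos_id]
    for _ in range(max_new_tokens):
        nxt = int(np.argmax(next_token_logits(prefix)))
        prefix.append(nxt)
        if nxt == eos_id:
            break
    return prefix


def toy_model(prefix, vocab=6):
    # Hypothetical stand-in model: always predicts (last token + 1) mod vocab.
    logits = np.zeros(vocab)
    logits[(prefix[-1] + 1) % vocab] = 1.0
    return logits


print(greedy_decode(toy_model, bos_id=0, eos_id=4, max_new_tokens=10))  # [0, 1, 2, 3, 4]
```

The only difference between the two functions is where the decoder input comes from: the ground truth during training, the model's own previous prediction at inference.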
Output
Returns a single flat (1-D) tensor of all updated weights, concatenated in this order:
[src_emb_flat, tgt_emb_flat, enc_pos_embed_flat, dec_pos_embed_flat,
enc_blocks_flat, dec_blocks_flat, w_head_flat]
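One way to produce this layout is sketched below in NumPy. It assumes enc_blocks and dec_blocks are lists of blocks, each block a list of arrays in some fixed per-block order; that inner order is a hypothetical choice here and must match whatever the checker expects. flatten_params and unflatten_params are illustrative names.

```python
import numpy as np


def flatten_params(src_emb, tgt_emb, enc_pos_embed, dec_pos_embed,
                   enc_blocks, dec_blocks, w_head):
    """Concatenate every parameter into one 1-D array.

    Order: src_emb, tgt_emb, enc_pos_embed, dec_pos_embed, all encoder
    blocks, all decoder blocks, w_head. The array order inside each block
    is an assumption.
    """
    parts = [src_emb, tgt_emb, enc_pos_embed, dec_pos_embed]
    parts += [p for block in enc_blocks for p in block]
    parts += [p for block in dec_blocks for p in block]
    parts.append(w_head)
    return np.concatenate([np.asarray(p).ravel() for p in parts])


def unflatten_params(flat, shapes):
    # Inverse of flatten_params, given the same ordered list of shapes.
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[start:start + size].reshape(shape))
        start += size
    if start != flat.size:
        raise ValueError("shapes do not exactly cover the flat vector")
    return out
```

Keeping an explicit, ordered list of shapes makes the round trip unambiguous and catches size mismatches early.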
References
- Vaswani et al., “Attention Is All You Need”, NeurIPS 2017.
- Luong et al., “Effective Approaches to Attention-based Neural Machine Translation”, EMNLP 2015.