
Train RNN Language Model

Implement one training step of an RNN language model: from embedding lookup all the way through BPTT (backpropagation through time) and an SGD parameter update. No nn.RNN, no autograd. Every gradient by hand.

The model

Given a token sequence of length T, the first T-1 tokens are inputs and the last T-1 tokens are targets (next-token prediction).
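
For example, the split can be sketched in NumPy like this (the token values are made up for illustration):

```python
import numpy as np

# Token ids arrive as floats; cast once to int so they can index w_emb.
tokens = np.array([5.0, 2.0, 7.0, 1.0]).astype(int)

inputs = tokens[:-1]   # positions 0 .. T-2 are fed to the RNN
targets = tokens[1:]   # positions 1 .. T-1 are the next-token labels

print(inputs.tolist(), targets.tolist())  # [5, 2, 7] [2, 7, 1]
```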

Forward pass (for t = 0 … T-2, with h[-1] = h0):

x[t]      = w_emb[tokens[t]]               # embedding lookup
h[t]      = tanh(concat(x[t], h[t-1]) @ w_h)   # RNN cell
logits[t] = h[t] @ w_out                   # project to vocab
probs[t]  = softmax(logits[t])             # over vocab V
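
The forward equations above can be sketched in NumPy as follows. This is a minimal sketch, assuming float weight arrays with the shapes listed under Inputs and taking h[-1] = h0; the function name `rnn_forward` and its choice to cache the concat vectors are illustrative, not a required interface.

```python
import numpy as np


def rnn_forward(tokens, w_emb, w_h, w_out, h0):
    """Forward pass over the T-1 input positions.

    Returns (zs, hs, probs) with shapes (T-1, d_emb + d_h), (T-1, d_h), (T-1, V),
    where zs[t] = concat(x[t], h[t-1]) is cached for the backward pass.
    """
    tokens = np.asarray(tokens).astype(int)   # ids arrive as floats
    h_prev = h0                               # h[-1] = h0
    zs, hs, probs = [], [], []
    for t in range(len(tokens) - 1):
        x = w_emb[tokens[t]]                  # embedding lookup, (d_emb,)
        z = np.concatenate([x, h_prev])       # (d_emb + d_h,)
        h = np.tanh(z @ w_h)                  # RNN cell, (d_h,)
        logits = h @ w_out                    # project to vocab, (V,)
        e = np.exp(logits - logits.max())     # stable softmax: subtract the max
        zs.append(z)
        hs.append(h)
        probs.append(e / e.sum())
        h_prev = h
    return np.array(zs), np.array(hs), np.array(probs)
```

Caching z rather than x alone lets the backward pass form dw_h directly as outer(zs[t], d_pre_act).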

Loss: mean cross-entropy over the n_steps = T-1 predictions.

loss = -mean_t( log probs[t][tokens[t+1]] )
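
A minimal NumPy sketch of this loss, assuming the probabilities are already stacked into a (T-1, V) array; the helper name `mean_cross_entropy` is hypothetical. The sanity check uses the fact that uniform predictions over V tokens give a loss of exactly log V.

```python
import numpy as np


def mean_cross_entropy(probs, tokens):
    """Mean negative log-probability assigned to each next token.

    probs: shape (T-1, V), row t is softmax(logits[t]).
    tokens: shape (T,), token ids (floats allowed).
    """
    targets = np.asarray(tokens).astype(int)[1:]
    n_steps = len(targets)
    picked = probs[np.arange(n_steps), targets]   # probs[t][tokens[t+1]]
    return -np.mean(np.log(picked))


# Sanity check: uniform predictions over V = 4 tokens give loss = log(4).
V = 4
uniform = np.full((3, V), 1.0 / V)
print(mean_cross_entropy(uniform, [0.0, 1.0, 2.0, 3.0]))  # ≈ 1.3863
```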

BPTT: the backward pass

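Walk backward through time (t = T-2 down to 0), starting with dh_next = 0 and every gradient accumulator at zero; dh_next carries the gradient arriving from future steps: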

dlogits[t] = (probs[t] - one_hot(tokens[t+1])) / n_steps
dw_out    += outer(h[t], dlogits[t])          (accumulated)
dh         = w_out @ dlogits[t] + dh_next
d_pre_act  = dh * (1 - h[t]**2)              (tanh derivative)
dw_h      += outer(concat(x[t], h[t-1]), d_pre_act)
d_concat   = w_h @ d_pre_act                 (shape d_emb + d_h)
dx         = d_concat[:d_emb]
dh_next    = d_concat[d_emb:]                (propagate to previous step)
dw_emb[tokens[t]] += dx
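
The per-step rules above can be sketched as one NumPy function. This is a minimal sketch, assuming the caller cached z = concat(x[t], h[t-1]) during the forward pass, loops t from T-2 down to 0 starting with dh_next = 0, sums the returned contributions into dw_out and dw_h, and adds dx into dw_emb[tokens[t]]; the name `bptt_step` is hypothetical.

```python
import numpy as np


def bptt_step(z, h, p, target, dh_next, w_h, w_out, d_emb, n_steps):
    """Backward through one timestep t.

    z: concat(x[t], h[t-1]), shape (d_emb + d_h,)
    h: h[t], shape (d_h,); p: probs[t], shape (V,); target: tokens[t+1]
    Returns (dw_out_t, dw_h_t, dx, dh_next for step t-1).
    """
    dlogits = p.copy()
    dlogits[target] -= 1.0           # probs - one_hot(target)
    dlogits /= n_steps               # loss is a mean over n_steps predictions
    dw_out_t = np.outer(h, dlogits)
    dh = w_out @ dlogits + dh_next   # output gradient plus gradient from step t+1
    d_pre = dh * (1.0 - h ** 2)      # tanh'(a) = 1 - tanh(a)^2
    dw_h_t = np.outer(z, d_pre)
    d_concat = w_h @ d_pre           # shape (d_emb + d_h,)
    return dw_out_t, dw_h_t, d_concat[:d_emb], d_concat[d_emb:]
```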

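SGD update: after the backward pass completes, update each parameter (w_emb, w_h, w_out) with p -= lr * dp, so every gradient is computed with the pre-update weights.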

Inputs

  • tokens: shape (T,), int token ids stored as floats; cast to int inside.
  • w_emb: shape (V, d_emb), embedding matrix.
  • w_h: shape (d_emb + d_h, d_h), RNN cell weights (concat-linear).
  • w_out: shape (d_h, V), output projection (no bias).
  • h0: shape (d_h,), initial hidden state.
  • lr: float, learning rate.
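
The shape contract above can be checked up front. This is a minimal sketch with a hypothetical helper `check_inputs`; the T >= 2 requirement is an assumption (with T = 1 there are no predictions to average), not something the problem states.

```python
import numpy as np


def check_inputs(tokens, w_emb, w_h, w_out, h0):
    """Hypothetical helper: verify the listed shapes and return int token ids."""
    tok = np.asarray(tokens).astype(int)       # ids arrive as floats
    V, d_emb = w_emb.shape
    d_h = h0.shape[0]
    if tok.ndim != 1 or len(tok) < 2:          # assumption: need T >= 2
        raise ValueError("tokens must be 1-D with T >= 2")
    if w_h.shape != (d_emb + d_h, d_h):
        raise ValueError(f"w_h must be {(d_emb + d_h, d_h)}, got {w_h.shape}")
    if w_out.shape != (d_h, V):
        raise ValueError(f"w_out must be {(d_h, V)}, got {w_out.shape}")
    if tok.min() < 0 or tok.max() >= V:
        raise ValueError("token id out of range [0, V)")
    return tok
```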

Output

Returns shape (d_h * V,): w_out flattened after the SGD step. (w_emb and w_h are also updated but not asserted, which keeps the tests tractable.)
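
Putting the forward pass, the BPTT loop, and the SGD update together gives a sketch of the whole step. This is one possible NumPy implementation, not the reference solution: it assumes float weight arrays and T >= 2, and it updates w_emb, w_h, and w_out in place (one reading of "also updated") before returning a flattened copy of w_out. The name `train_step` is hypothetical.

```python
import numpy as np


def train_step(tokens, w_emb, w_h, w_out, h0, lr):
    """Forward + BPTT + SGD; mutates the weights in place, returns w_out flattened."""
    tokens = np.asarray(tokens).astype(int)
    d_emb = w_emb.shape[1]
    n_steps = len(tokens) - 1

    # Forward pass, caching what BPTT needs.
    zs, hs, ps = [], [], []
    h_prev = h0
    for t in range(n_steps):
        z = np.concatenate([w_emb[tokens[t]], h_prev])
        h = np.tanh(z @ w_h)
        logits = h @ w_out
        e = np.exp(logits - logits.max())        # numerically stable softmax
        zs.append(z)
        hs.append(h)
        ps.append(e / e.sum())
        h_prev = h

    # Backward pass through time; all gradients use the pre-update weights.
    dw_emb = np.zeros_like(w_emb)
    dw_h = np.zeros_like(w_h)
    dw_out = np.zeros_like(w_out)
    dh_next = np.zeros_like(h0)
    for t in reversed(range(n_steps)):
        dlogits = ps[t].copy()
        dlogits[tokens[t + 1]] -= 1.0
        dlogits /= n_steps
        dw_out += np.outer(hs[t], dlogits)
        dh = w_out @ dlogits + dh_next
        d_pre = dh * (1.0 - hs[t] ** 2)
        dw_h += np.outer(zs[t], d_pre)
        d_concat = w_h @ d_pre
        dw_emb[tokens[t]] += d_concat[:d_emb]
        dh_next = d_concat[d_emb:]

    # SGD update, in place, only after every gradient has been computed.
    w_emb -= lr * dw_emb
    w_h -= lr * dw_h
    w_out -= lr * dw_out
    return w_out.flatten()                       # copy, shape (d_h * V,)
```

With lr = 0 the returned vector equals the original w_out flattened, matching the note below. Because the update is in place, pass copies if the original weights are still needed.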

Notes

  • Use a numerically stable softmax: subtract the row-wise max before exponentiating.
  • lr=0: no weight change; w_out is returned unchanged.
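  • The vanishing gradient problem: at each step back through time, the gradient is multiplied by the tanh derivative, which lies in (0, 1], and by the recurrent block of w_h (its last d_h rows). When these factors are small, gradients shrink roughly geometrically as they propagate back through many timesteps, so early tokens receive little learning signal. LSTM and GRU were designed to mitigate this.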
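
The shrinking effect can be illustrated by pushing a gradient back through the recurrent path alone (dh_next = w_h[d_emb:] @ d_pre_act, with no new output gradient added). This is a minimal sketch under loud assumptions: the recurrent block is random and rescaled to spectral norm 0.9, and the hidden states are random stand-ins rather than outputs of a real forward pass. Since |tanh'| <= 1, each step can scale the norm by at most 0.9, so after 20 steps it is at most 0.9**20 (about 0.12) of its starting value. The helper `backprop_norms` is hypothetical.

```python
import numpy as np


def backprop_norms(w_hh, hs, dh_last):
    """Propagate dh backward through tanh cells; record the norm at each step.

    w_hh: recurrent block w_h[d_emb:], shape (d_h, d_h)
    hs: hidden states h[t] for the steps traversed, latest last
    """
    norms = [np.linalg.norm(dh_last)]
    dh = dh_last
    for h in reversed(hs):
        dh = w_hh @ (dh * (1.0 - h ** 2))   # d_concat[d_emb:] without output gradient
        norms.append(np.linalg.norm(dh))
    return norms


rng = np.random.default_rng(0)
d_h, steps = 8, 20
w_hh = rng.normal(size=(d_h, d_h))
w_hh *= 0.9 / np.linalg.norm(w_hh, 2)        # rescale: largest singular value = 0.9
hs = np.tanh(rng.normal(size=(steps, d_h)))  # stand-in hidden states (assumption)
norms = backprop_norms(w_hh, hs, np.ones(d_h))
print(norms[0], norms[-1])
```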

