Train RNN Language Model
Implement one training step of an RNN language model, from embedding
lookup all the way through BPTT (backpropagation through time) and an SGD
parameter update. No nn.RNN, no autograd: every gradient by hand.
The model
Given a token sequence of length T, the first T-1 tokens are inputs
and the last T-1 tokens are targets (next-token prediction).
Forward pass (for t = 0, …, T-2, with h[-1] = h0):
x[t] = w_emb[tokens[t]] # embedding lookup
h[t] = tanh(concat(x[t], h[t-1]) @ w_h) # RNN cell
logits[t] = h[t] @ w_out # project to vocab
probs[t] = softmax(logits[t]) # over vocab V
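The forward pass can be sketched in plain Python as follows. This is a minimal sketch that assumes the weights are stored as nested lists of rows, so row i of w_h multiplies element i of the concatenated vector. The helper names vec_mat, softmax and forward are illustrative and are not part of the problem's interface.

```python
import math

def vec_mat(v, m):
    """Row vector times matrix: v @ m, with m stored as a list of rows."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]

def softmax(z):
    """Numerically stable softmax: subtract the max before exponentiating."""
    z_max = max(z)
    exps = [math.exp(zi - z_max) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def forward(tokens, w_emb, w_h, w_out, h0):
    """Run the RNN over tokens[:-1]; return per-step inputs, hidden states, probs."""
    ids = [int(t) for t in tokens]          # ids arrive as floats
    h_prev = list(h0)                       # h[-1] = h0
    xs, hs, probs = [], [], []
    for t in range(len(ids) - 1):           # t = 0 .. T-2
        x = list(w_emb[ids[t]])             # embedding lookup
        pre = vec_mat(x + h_prev, w_h)      # concat(x[t], h[t-1]) @ w_h
        h = [math.tanh(a) for a in pre]     # RNN cell
        logits = vec_mat(h, w_out)          # project to vocab
        xs.append(x)
        hs.append(h)
        probs.append(softmax(logits))
        h_prev = h
    return xs, hs, probs
```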
Loss: mean cross-entropy over the n_steps = T-1 predictions.
loss = -mean_t( log probs[t][tokens[t+1]] )
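For the loss, log probs[t] can be computed through the log-sum-exp identity log softmax(z)[k] = z[k] - (max(z) + log sum_j exp(z[j] - max(z))). This avoids taking the log of a probability that has underflowed to zero. The sketch below is small, and its helper names log_softmax and sequence_loss are hypothetical.

```python
import math

def log_softmax(z):
    """log(softmax(z)) computed stably with the log-sum-exp trick."""
    z_max = max(z)
    lse = z_max + math.log(sum(math.exp(zi - z_max) for zi in z))
    return [zi - lse for zi in z]

def sequence_loss(logits_per_step, tokens):
    """Mean next-token cross-entropy.

    logits_per_step[t] are the logits produced at step t; they are scored
    against the target tokens[t + 1], giving n_steps = len(tokens) - 1 terms.
    """
    ids = [int(t) for t in tokens]
    n_steps = len(ids) - 1
    total = 0.0
    for t in range(n_steps):
        total -= log_softmax(logits_per_step[t])[ids[t + 1]]
    return total / n_steps
```

With all-zero logits every prediction is uniform, so the loss equals log V.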
BPTT: the backward pass
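Walk backward through time, for t = T-2 down to 0, starting with dh_next = 0 and all parameter gradients at zero. dh_next carries the gradient flowing into h[t] from later steps: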
dlogits[t] = (probs[t] - one_hot(tokens[t+1])) / n_steps
dw_out += outer(h[t], dlogits[t]) (accumulated)
dh = w_out @ dlogits[t] + dh_next
d_pre_act = dh * (1 - h[t]**2) (tanh derivative)
dw_h += outer(concat(x[t], h[t-1]), d_pre_act)
d_concat = w_h @ d_pre_act (shape d_emb + d_h)
dx = d_concat[:d_emb]
dh_next = d_concat[d_emb:] (propagate to previous step)
dw_emb[tokens[t]] += dx
SGD update: for each parameter p in (w_emb, w_h, w_out), p -= lr * dp.
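The sketch below combines the forward pass, BPTT and the SGD update into one full training step. It uses plain Python and assumes the same nested-list storage. The names forward, backward and train_step are hypothetical. This version updates the weight lists in place, because the problem does not say whether inputs may be mutated. All gradients are computed before any weight changes.

```python
import math

def vec_mat(v, m):
    """v @ m, with m stored as a list of rows."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]

def mat_vec(m, v):
    """m @ v, with m stored as a list of rows."""
    return [sum(r * x for r, x in zip(row, v)) for row in m]

def add_outer(acc, a, b):
    """acc += outer(a, b), in place."""
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            acc[i][j] += ai * bj

def forward(ids, w_emb, w_h, w_out, h0):
    """Return per-step concat inputs, hidden states, probs and the mean loss."""
    n_steps = len(ids) - 1
    h_prev, cats, hs, probs, loss = list(h0), [], [], [], 0.0
    for t in range(n_steps):
        cat = list(w_emb[ids[t]]) + h_prev              # concat(x[t], h[t-1])
        h = [math.tanh(a) for a in vec_mat(cat, w_h)]   # RNN cell
        logits = vec_mat(h, w_out)
        m = max(logits)                                 # stable log-sum-exp
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        probs.append([math.exp(z - lse) for z in logits])
        loss += (lse - logits[ids[t + 1]]) / n_steps    # -log probs[t][target]
        cats.append(cat)
        hs.append(h)
        h_prev = h
    return cats, hs, probs, loss

def backward(ids, w_emb, w_h, w_out, cats, hs, probs):
    """BPTT: return (dw_emb, dw_h, dw_out) for the mean cross-entropy loss."""
    n_steps, d_emb, d_h = len(ids) - 1, len(w_emb[0]), len(w_out)
    dw_emb = [[0.0] * d_emb for _ in w_emb]
    dw_h = [[0.0] * d_h for _ in w_h]
    dw_out = [[0.0] * len(w_out[0]) for _ in w_out]
    dh_next = [0.0] * d_h
    for t in reversed(range(n_steps)):                  # t = T-2 down to 0
        dlogits = [p / n_steps for p in probs[t]]
        dlogits[ids[t + 1]] -= 1.0 / n_steps            # (probs - one_hot) / n_steps
        add_outer(dw_out, hs[t], dlogits)
        dh = [a + b for a, b in zip(mat_vec(w_out, dlogits), dh_next)]
        d_pre = [g * (1.0 - h * h) for g, h in zip(dh, hs[t])]  # tanh derivative
        add_outer(dw_h, cats[t], d_pre)
        d_cat = mat_vec(w_h, d_pre)                     # length d_emb + d_h
        for k in range(d_emb):
            dw_emb[ids[t]][k] += d_cat[k]               # dx into row tokens[t]
        dh_next = d_cat[d_emb:]                         # gradient into h[t-1]
    return dw_emb, dw_h, dw_out

def train_step(tokens, w_emb, w_h, w_out, h0, lr):
    """One SGD step (weights updated in place); returns w_out flattened row-major."""
    ids = [int(t) for t in tokens]                      # float ids -> int
    cats, hs, probs, _ = forward(ids, w_emb, w_h, w_out, h0)
    grads = backward(ids, w_emb, w_h, w_out, cats, hs, probs)
    for param, grad in zip((w_emb, w_h, w_out), grads):
        for row, grad_row in zip(param, grad):
            for j, g in enumerate(grad_row):
                row[j] -= lr * g                        # p -= lr * dp
    return [v for row in w_out for v in row]
```

A central finite-difference check, (loss(p + eps) - loss(p - eps)) / (2 * eps) for each weight p, is a cheap way to confirm that the hand-derived gradients match the forward pass.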
Inputs
- tokens: shape (T,), integer token ids stored as floats; cast to int inside.
- w_emb: shape (V, d_emb), embedding matrix.
- w_h: shape (d_emb + d_h, d_h), RNN cell weights (concat-linear).
- w_out: shape (d_h, V), output projection (no bias).
- h0: shape (d_h,), initial hidden state.
- lr: float, learning rate.
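The shape relations among these inputs can be checked up front. This is a minimal sketch that again assumes nested-list inputs. check_inputs is a hypothetical helper that also splits the ids into inputs (tokens[:-1]) and targets (tokens[1:]).

```python
def check_inputs(tokens, w_emb, w_h, w_out, h0):
    """Validate shapes; return (inputs, targets, V, d_emb, d_h)."""
    V, d_emb, d_h = len(w_emb), len(w_emb[0]), len(h0)
    ids = [int(t) for t in tokens]                  # ids arrive as floats
    if len(ids) < 2:
        raise ValueError("need T >= 2 so there is at least one prediction")
    if not all(0 <= i < V for i in ids):
        raise ValueError("token id outside [0, V)")
    if any(len(row) != d_emb for row in w_emb):
        raise ValueError("w_emb must have shape (V, d_emb)")
    if len(w_h) != d_emb + d_h or any(len(row) != d_h for row in w_h):
        raise ValueError("w_h must have shape (d_emb + d_h, d_h)")
    if len(w_out) != d_h or any(len(row) != V for row in w_out):
        raise ValueError("w_out must have shape (d_h, V)")
    return ids[:-1], ids[1:], V, d_emb, d_h
```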
Output
Returns an array of shape (d_h * V,): w_out flattened after the SGD step.
(w_emb and w_h are also updated but not asserted, keeping tests tractable.)
Notes
- Use a numerically stable softmax: subtract the row-wise max before exponentiating.
- lr=0: no weight change; w_out is returned unchanged.
- The vanishing gradient problem: tanh derivatives lie in (0, 1], so gradients tend to shrink as they propagate back through many timesteps. LSTM and GRU were designed to mitigate this.
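The vanishing-gradient effect can be seen numerically on a toy scalar RNN, h[t] = tanh(w * h[t-1] + x[t]). This is a simplified stand-in for illustration, not the problem's model. Each backward step multiplies the gradient by dh[t]/dh[t-1] = w * (1 - h[t]**2), so with |w| <= 1 the product over T steps is at most |w|**T in magnitude. The function name backprop_factor is illustrative.

```python
import math

def backprop_factor(w, xs, h0=0.0):
    """d h_final / d h0 for the toy scalar RNN h = tanh(w * h_prev + x)."""
    h, factor = h0, 1.0
    for x in xs:
        h = math.tanh(w * h + x)
        factor *= w * (1.0 - h * h)   # dh[t]/dh[t-1]: chain rule through one step
    return factor

for steps in (1, 10, 50):
    print(steps, backprop_factor(0.9, [0.5] * steps))
```

In the full model the per-step factor is a matrix: the bottom d_h rows of w_h (those that multiply h[t-1]) scaled by 1 - h[t]**2. As a result, large recurrent weights can make gradients grow instead of shrink.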