medium end_to_end

Simple RNN Cell

Implement a vanilla RNN cell and apply it step by step across an input sequence.

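The RNN update equation, applied at every time step: $$h_t = \tanh(x_t \cdot W_{xh} + h_{t-1} \cdot W_{hh} + b_h)$$ Here $h_{t-1}$ is the hidden state from the previous step; for the first input it is the initial state h0.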

Process a sequence of inputs and return all hidden states.

Input:

  • x: input sequence of shape (seq_len, input_dim)
  • W_xh: input-to-hidden weights of shape (input_dim, hidden_dim)
  • W_hh: hidden-to-hidden weights of shape (hidden_dim, hidden_dim)
  • b_h: hidden bias of shape (hidden_dim,)
  • h0: initial hidden state of shape (hidden_dim,)

Output: All hidden states, shape (seq_len, hidden_dim).
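
The specification above can be sketched in NumPy as follows. This is one possible reference implementation, not an official solution: the function names `rnn_step` and `rnn_forward` are assumptions chosen for illustration, and the code assumes the inputs are array-like with the shapes listed above.

```python
import numpy as np


def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """Single update: h_t = tanh(x_t @ W_xh + h_prev @ W_hh + b_h)."""
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)


def rnn_forward(x, W_xh, W_hh, b_h, h0):
    """Run the cell over the whole sequence and collect every hidden state."""
    # Convert inputs so plain nested lists are accepted as well as arrays.
    x = np.asarray(x, dtype=float)
    W_xh = np.asarray(W_xh, dtype=float)
    W_hh = np.asarray(W_hh, dtype=float)
    b_h = np.asarray(b_h, dtype=float)
    h = np.asarray(h0, dtype=float)

    seq_len = x.shape[0]
    hidden_dim = W_hh.shape[0]
    hidden_states = np.empty((seq_len, hidden_dim))

    # The recurrence is inherently sequential: each step needs the previous h.
    for t in range(seq_len):
        h = rnn_step(x[t], h, W_xh, W_hh, b_h)
        hidden_states[t] = h

    return hidden_states
```

The initial state h0 only seeds the recurrence and is not part of the returned array, so the result has exactly seq_len rows, one per input step.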

Hints

rnn recurrent sequence-model tanh