medium end_to_end

GRU Cell

Implement a single GRU cell step and use it to process a sequence, returning the final hidden state.

GRU equations for each time step:

$$r = \sigma(x_t \cdot W_{xr} + h_{t-1} \cdot W_{hr} + b_r)$$ (reset gate)

$$z = \sigma(x_t \cdot W_{xz} + h_{t-1} \cdot W_{hz} + b_z)$$ (update gate)

$$\tilde{h} = \tanh(x_t \cdot W_{xn} + (r \odot h_{t-1}) \cdot W_{hn} + b_n)$$ (candidate)

$$h_t = (1 - z) \odot \tilde{h} + z \odot h_{t-1}$$
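
As a quick reading of these equations, the extreme gate values show what each gate controls. These special cases are a worked illustration derived from the formulas above, not part of the original statement:

$$z = 0 \;\Rightarrow\; h_t = \tilde{h}, \qquad z = 1 \;\Rightarrow\; h_t = h_{t-1}$$

$$r = 0 \;\Rightarrow\; \tilde{h} = \tanh(x_t \cdot W_{xn} + b_n)$$

So an update gate near 1 carries the previous state forward unchanged, and a reset gate near 0 makes the candidate depend on the current input only.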

Input:

  • x: shape (seq_len, input_dim)
  • W_xr, W_xz, W_xn: shape (input_dim, hidden_dim) each
  • W_hr, W_hz, W_hn: shape (hidden_dim, hidden_dim) each
  • b_r, b_z, b_n: shape (hidden_dim,) each
  • h0: shape (hidden_dim,)

Output: Final hidden state h_T of shape (hidden_dim,).
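
The equations and shapes above can be sketched in NumPy as follows. This is a minimal illustration, not a reference solution: the names `sigmoid`, `gru_step`, and `gru_forward` and the argument order are assumptions made for this example, and the code assumes all inputs are array-like with the shapes listed under Input.

```python
import numpy as np


def sigmoid(a):
    """Logistic sigmoid, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-a))


def gru_step(x_t, h_prev, W_xr, W_xz, W_xn, W_hr, W_hz, W_hn, b_r, b_z, b_n):
    """One GRU time step: x_t is (input_dim,), h_prev is (hidden_dim,)."""
    r = sigmoid(x_t @ W_xr + h_prev @ W_hr + b_r)             # reset gate
    z = sigmoid(x_t @ W_xz + h_prev @ W_hz + b_z)             # update gate
    h_cand = np.tanh(x_t @ W_xn + (r * h_prev) @ W_hn + b_n)  # candidate
    return (1.0 - z) * h_cand + z * h_prev                    # new hidden state


def gru_forward(x, W_xr, W_xz, W_xn, W_hr, W_hz, W_hn, b_r, b_z, b_n, h0):
    """Run gru_step over every row of x (seq_len, input_dim); return h_T."""
    h = np.asarray(h0, dtype=float)
    for x_t in np.asarray(x, dtype=float):
        h = gru_step(x_t, h, W_xr, W_xz, W_xn, W_hr, W_hz, W_hn, b_r, b_z, b_n)
    return h
```

One convenient sanity check: with every weight and bias set to zero, both gates equal 0.5 and the candidate is 0, so each step halves the hidden state regardless of the input.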

gru recurrent gates sequence-model