GRU Cell Step

Why this matters

Before Transformers ate the world, gated recurrent units (GRUs) and LSTMs were the workhorse of sequence modeling. They still show up in:

  • Edge / on-device speech models (cheap, sequential).
  • Time-series forecasting (small models, long horizons).
  • Neural ODE-style continuous-time models, where a GRU update is paired with an ODE solver.
  • Encoder-decoder bottlenecks where you want a single state vector.

Implementing one cell from scratch with self.param shows you exactly how the gating math works — and reinforces the “concatenate inputs and hidden, then linear” pattern that’s everywhere in RNN-land.

Reset, update, candidate

A GRU step takes the previous hidden state h_prev ∈ R^H plus a new input x ∈ R^D and produces a new hidden state h ∈ R^H. There are three pieces:

r       = σ(W_r [x; h_prev] + b_r)        # reset gate    (H,)
z       = σ(W_z [x; h_prev] + b_z)        # update gate   (H,)
h_tilde = tanh(W_h [x; r * h_prev] + b_h) # candidate     (H,)
h       = (1 - z) * h_prev + z * h_tilde  # new hidden    (H,)
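A minimal sketch of the four equations as a pure JAX function. It assumes the parameters live in a plain dict with the hypothetical keys W_r, W_z, W_h, b_r, b_z, b_h. Each weight is stored as (D+H, H), so the code reads xh @ W, which is the transpose of the W [x; h] notation above. The random init helper is only there for trying the function out. Its scaling is an arbitrary choice, not the initializer the problem asks for.

```python
import jax
import jax.numpy as jnp


def gru_step(params, x, h_prev):
    """One GRU step: reset gate, update gate, candidate, mix."""
    xh = jnp.concatenate([x, h_prev], axis=-1)                # (D + H,)
    r = jax.nn.sigmoid(xh @ params["W_r"] + params["b_r"])    # reset gate (H,)
    z = jax.nn.sigmoid(xh @ params["W_z"] + params["b_z"])    # update gate (H,)
    xrh = jnp.concatenate([x, r * h_prev], axis=-1)           # reset applied before concat
    h_tilde = jnp.tanh(xrh @ params["W_h"] + params["b_h"])   # candidate (H,)
    return (1.0 - z) * h_prev + z * h_tilde                   # new hidden (H,)


def init_params(key, D, H):
    """Hypothetical test helper: scaled-normal weights, zero biases."""
    kr, kz, kh = jax.random.split(key, 3)
    scale = (D + H) ** -0.5  # arbitrary scaling, for illustration only
    return {
        "W_r": jax.random.normal(kr, (D + H, H)) * scale,
        "W_z": jax.random.normal(kz, (D + H, H)) * scale,
        "W_h": jax.random.normal(kh, (D + H, H)) * scale,
        "b_r": jnp.zeros(H),
        "b_z": jnp.zeros(H),
        "b_h": jnp.zeros(H),
    }


params = init_params(jax.random.PRNGKey(0), D=3, H=4)
h = gru_step(params, jnp.array([0.1, 0.2, 0.3]), jnp.zeros(4))
print(h.shape)  # (4,)
```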

The update gate z controls how much of the candidate to MIX INTO the previous state. z=0 keeps h_prev unchanged (perfect long-term memory); z=1 fully replaces it.

The reset gate r controls how much of h_prev is allowed INTO the candidate computation. r=0 lets the candidate ignore history (start fresh); r=1 uses it normally.
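You can check the gate semantics directly by feeding in hand-picked gate values instead of computed ones. This sketch uses a hypothetical helper, blend, for the update-gate mix. It shows that z=0 returns h_prev, z=1 returns h_tilde, and r=0 zeroes the hidden part of the candidate's input.

```python
import jax.numpy as jnp


def blend(h_prev, h_tilde, z):
    # Update-gate mix used in this section: z=0 keeps h_prev, z=1 takes h_tilde.
    return (1.0 - z) * h_prev + z * h_tilde


h_prev = jnp.array([0.5, -0.2, 0.9])
h_tilde = jnp.array([-0.7, 0.3, 0.1])

keep = blend(h_prev, h_tilde, jnp.zeros(3))      # equals h_prev
replace = blend(h_prev, h_tilde, jnp.ones(3))    # equals h_tilde
half = blend(h_prev, h_tilde, jnp.full(3, 0.5))  # elementwise average of the two

# Reset gate: r = 0 zeroes the hidden slice of the candidate's input,
# so the candidate is computed from x alone ("start fresh").
x = jnp.array([0.1, 0.2])
r = jnp.zeros(3)
xrh = jnp.concatenate([x, r * h_prev])           # hidden slice is all zeros
```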

The [x; h] concat trick

Instead of two separate weight matrices per gate (one for x, one for h_prev), many from-scratch GRU implementations concatenate x and h_prev and use a single weight matrix per gate of shape (D + H, H):

import jax
import jax.numpy as jnp
xh = jnp.concatenate([x, h_prev], axis=-1)  # (D + H,)
r  = jax.nn.sigmoid(xh @ W_r + b_r)         # (H,)

Mathematically this equals x @ W_xr + h_prev @ W_hr + b_r, where W_xr is the first D rows of W_r and W_hr is the last H rows. The concatenated form computes the same thing in one matmul with one weight parameter per gate, which keeps the indexing simple.
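Here is a quick numerical check of that equivalence with random weights. The first D rows of W_r multiply x and the last H rows multiply h_prev, so slicing W_r reproduces the two-matrix form. W_xr and W_hr here are only slices for illustration, not separate parameters.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
kx, kh, kw, kb = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (D,))
h_prev = jax.random.normal(kh, (H,))
W_r = jax.random.normal(kw, (D + H, H))   # one concatenated weight for the gate
b_r = jax.random.normal(kb, (H,))

# Concatenated form: a single matmul.
pre_concat = jnp.concatenate([x, h_prev]) @ W_r + b_r

# Split form: first D rows act on x, last H rows act on h_prev.
W_xr, W_hr = W_r[:D], W_r[D:]
pre_split = x @ W_xr + h_prev @ W_hr + b_r

print(jnp.allclose(pre_concat, pre_split, atol=1e-4))
```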

NOTE: The candidate uses [x; r * h_prev] — the reset gate masks the hidden BEFORE concatenation.
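To see why the masking order matters, this sketch compares a candidate that applies r before concatenation with a buggy version that ignores r. The function names candidate and candidate_buggy are hypothetical, and the weights are random. With r = 0, the correct candidate reduces to a function of x alone, while the buggy one still sees h_prev.

```python
import jax
import jax.numpy as jnp

D, H = 2, 3
key_w, key_h = jax.random.split(jax.random.PRNGKey(1))
W_h = jax.random.normal(key_w, (D + H, H))
b_h = jnp.zeros(H)
x = jnp.array([0.4, -0.1])
h_prev = jax.random.normal(key_h, (H,))


def candidate(x, h_prev, r):
    # Reset gate masks h_prev BEFORE it is concatenated with x.
    xrh = jnp.concatenate([x, r * h_prev])
    return jnp.tanh(xrh @ W_h + b_h)


def candidate_buggy(x, h_prev, r):
    # Bug: r is never used, so the raw hidden state always reaches the candidate.
    xh = jnp.concatenate([x, h_prev])
    return jnp.tanh(xh @ W_h + b_h)


r_zero = jnp.zeros(H)
fresh = candidate(x, h_prev, r_zero)          # depends only on x
only_x = jnp.tanh(x @ W_h[:D] + b_h)          # same value, computed from x directly
stale = candidate_buggy(x, h_prev, r_zero)    # still influenced by h_prev
```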

Worked walk-through

D=3, H=4, x=[0.1, 0.2, 0.3], h_prev=[0, 0, 0, 0]:

  1. xh = concat([x, h_prev]) = [0.1, 0.2, 0.3, 0, 0, 0, 0], shape (7,).
  2. r = sigmoid(xh @ W_r + b_r), shape (4,).
  3. z = sigmoid(xh @ W_z + b_z), shape (4,).
  4. xrh = concat([x, r * h_prev]) = [0.1, 0.2, 0.3, 0, 0, 0, 0], shape (7,). Since r * 0 = 0, the reset gate has no effect on this first step.
  5. h_tilde = tanh(xrh @ W_h + b_h), shape (4,).
  6. h = (1 - z) * h_prev + z * h_tilde = z * h_tilde (since h_prev = 0).
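The walk-through can be reproduced numerically. This sketch assumes lecun_normal weights, zero biases, and PRNG key 0. The exact numbers depend on that choice, so the sketch only checks the structural facts: xrh equals xh because h_prev is zero, and h equals z * h_tilde.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
init = jax.nn.initializers.lecun_normal()
k_r, k_z, k_h = jax.random.split(jax.random.PRNGKey(0), 3)
W_r, W_z, W_h = (init(k, (D + H, H)) for k in (k_r, k_z, k_h))
b_r = b_z = b_h = jnp.zeros(H)

x = jnp.array([0.1, 0.2, 0.3])
h_prev = jnp.zeros(H)

xh = jnp.concatenate([x, h_prev])          # step 1: (7,)
r = jax.nn.sigmoid(xh @ W_r + b_r)         # step 2: (4,)
z = jax.nn.sigmoid(xh @ W_z + b_z)         # step 3: (4,)
xrh = jnp.concatenate([x, r * h_prev])     # step 4: same values as xh here
h_tilde = jnp.tanh(xrh @ W_h + b_h)        # step 5: (4,)
h = (1.0 - z) * h_prev + z * h_tilde       # step 6: reduces to z * h_tilde
```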

Three weight matrices of shape (D+H, H) and three biases of shape (H,): six parameter arrays in the cell, or 3(D+H)H + 3H scalars in total (96 for D=3, H=4).
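The same arithmetic fits in a tiny helper. The name gru_param_count is hypothetical.

```python
def gru_param_count(D: int, H: int) -> int:
    """Scalar parameters in the cell: three (D+H, H) weights plus three (H,) biases."""
    return 3 * (D + H) * H + 3 * H


print(gru_param_count(3, 4))  # 96
```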

Common pitfalls

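  • Mixing up the update convention. Both forms are in common use. PyTorch's nn.GRU and Flax's GRUCell, for example, compute z * h_prev + (1 - z) * h_tilde, while this problem uses (1 - z) * h_prev + z * h_tilde. The two differ only by swapping z for 1 - z. Writing one while the rest of your code (or a grader) assumes the other inverts the gate semantics: z=0 would then mean “fully replace” instead of “keep h_prev.”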
  • Forgetting r * h_prev in the candidate: passes the raw hidden through, losing the reset gate’s role.
  • Wrong activations: gates are SIGMOID ((0, 1) mixing coefficients); candidate is TANH ((-1, 1) proposed values).
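  • Sharing one weight matrix across all three pieces: each gate and the candidate needs its own weights and bias. With shared parameters, r and z would get identical pre-activations and therefore be identical.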
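To make the convention pitfall concrete, this sketch writes both update rules side by side. The function names are hypothetical. The second rule is the form PyTorch's nn.GRU and Flax's GRUCell use, z * h_prev + (1 - z) * h_tilde. Evaluating it at 1 - z reproduces this section's rule, and evaluating it at z does not.

```python
import jax.numpy as jnp


def update_this_doc(h_prev, h_tilde, z):
    # Convention used here: z=1 means "take the candidate".
    return (1.0 - z) * h_prev + z * h_tilde


def update_torch_flax_style(h_prev, h_tilde, z):
    # Convention in PyTorch nn.GRU / Flax GRUCell: z=1 means "keep h_prev".
    return z * h_prev + (1.0 - z) * h_tilde


h_prev = jnp.array([1.0, -1.0])
h_tilde = jnp.array([0.0, 0.5])
z = jnp.array([0.2, 0.9])

a = update_this_doc(h_prev, h_tilde, z)
b = update_torch_flax_style(h_prev, h_tilde, 1.0 - z)   # matches a
wrong = update_torch_flax_style(h_prev, h_tilde, z)     # same z, opposite meaning
```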

Problem

Implement gru_cell_step(seed, x, h_prev, hidden) as a single GRU step from scratch using self.param:

  1. MyGRUCell(nn.Module) with a hidden: int field (H).
  2. Inside @nn.compact: declare three weight matrices (D+H, H) and three biases (H,) for r, z, h_tilde.
  3. Compute the four equations above.
  4. Initialize the module with a PRNG key built from seed, apply it to (x, h_prev), and return the new h flattened to 1-D.

Init weights with nn.initializers.lecun_normal(), biases with nn.initializers.zeros.

Inputs:

  • seed: int.
  • x: 1-D (D,).
  • h_prev: 1-D (H,).
  • hidden: int H.

Output: 1-D (H,) — the new hidden state.
