GRU Cell Step
Why this matters
Before Transformers ate the world, gated recurrent units (GRUs) and LSTMs were the workhorses of sequence modeling. They still show up in:
- Edge / on-device speech models (cheap, sequential).
- Time-series forecasting (small models, long horizons).
- Differential equation solvers (NeuralODE-style).
- Encoder-decoder bottlenecks where you want a single state vector.
Implementing one cell from scratch with self.param shows you
exactly how the gating math works — and reinforces the
“concatenate inputs and hidden, then linear” pattern that’s
everywhere in RNN-land.
Reset, update, candidate
A GRU step takes the previous hidden state h_prev ∈ R^H plus a
new input x ∈ R^D and produces a new hidden state h ∈ R^H.
There are three pieces:
r = σ(W_r [x; h_prev] + b_r) # reset gate (H,)
z = σ(W_z [x; h_prev] + b_z) # update gate (H,)
h_tilde = tanh(W_h [x; r * h_prev] + b_h) # candidate (H,)
h = (1 - z) * h_prev + z * h_tilde # new hidden (H,)
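The four equations map almost line for line onto jax.numpy. The sketch below is a minimal, unbatched version; the function name gru_step, the params dict layout, and the small random weights are illustrative assumptions, not part of the problem's API. Weights are stored as (D + H, H) so the product is written xh @ W.

```python
import jax
import jax.numpy as jnp

def gru_step(params, x, h_prev):
    """One unbatched GRU step; `params` is a hypothetical dict of arrays."""
    xh = jnp.concatenate([x, h_prev], axis=-1)                # (D + H,)
    r = jax.nn.sigmoid(xh @ params["W_r"] + params["b_r"])    # reset gate (H,)
    z = jax.nn.sigmoid(xh @ params["W_z"] + params["b_z"])    # update gate (H,)
    xrh = jnp.concatenate([x, r * h_prev], axis=-1)           # reset applied before concat
    h_tilde = jnp.tanh(xrh @ params["W_h"] + params["b_h"])   # candidate (H,)
    return (1.0 - z) * h_prev + z * h_tilde                   # new hidden (H,)

# Illustrative sizes and random weights (assumed, for demonstration only).
D, H = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {f"W_{n}": 0.1 * jax.random.normal(k, (D + H, H)) for n, k in zip("rzh", keys)}
params.update({f"b_{n}": jnp.zeros(H) for n in "rzh"})

h_new = gru_step(params, jnp.array([0.1, 0.2, 0.3]), jnp.zeros(H))
```

Because h is an elementwise convex combination of h_prev and a tanh output, every entry of h_new stays inside (-1, 1) whenever h_prev does.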
The update gate z controls how much of the candidate to MIX
INTO the previous state. z=0 keeps h_prev unchanged
(perfect long-term memory); z=1 fully replaces it.
The reset gate r controls how much of h_prev is allowed
INTO the candidate computation. r=0 lets the candidate ignore
history (start fresh); r=1 uses it normally.
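The extreme gate values are easy to check numerically. The snippet below is a small sketch with made-up values for h_prev and h_tilde (they are assumptions, not outputs of a trained cell); it exercises only the mixing equation and the reset masking.

```python
import jax.numpy as jnp

h_prev = jnp.array([0.5, -0.5, 0.25, 0.0])
h_tilde = jnp.array([0.9, 0.1, -0.3, 0.7])   # made-up candidate values

def mix(z):
    # Update-gate mixing: (1 - z) weights the old state, z weights the candidate.
    return (1.0 - z) * h_prev + z * h_tilde

keep = mix(jnp.zeros(4))       # z = 0: result equals h_prev
replace = mix(jnp.ones(4))     # z = 1: result equals h_tilde
half = mix(jnp.full(4, 0.5))   # z = 0.5: elementwise average of the two

# Reset gate at r = 0: the hidden half of the candidate input is zeroed,
# so the candidate sees only x.
x = jnp.array([0.1, 0.2, 0.3])
r = jnp.zeros(4)
xrh = jnp.concatenate([x, r * h_prev])
```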
The [x; h] concat trick
Instead of two separate weight matrices W_x and W_h, GRU
implementations almost always concatenate x and h and use a
single weight matrix per gate of shape (D + H, H):
xh = jnp.concatenate([x, h_prev], axis=-1) # (D + H,)
r = sigmoid(xh @ W_r + b_r) # (H,)
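If W_xr is the first D rows of W_r and W_hr is the last H rows, this is
mathematically equivalent to x @ W_xr + h_prev @ W_hr + b_r.
Implementation-wise it's one matmul with simpler indexing, and it is a
common way to write GRU code, even though the original GRU papers
write the two matrices separately.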
NOTE: The candidate uses [x; r * h_prev] — the reset gate
masks the hidden BEFORE concatenation.
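The equivalence between the fused and the two-matrix form can be verified by slicing one weight matrix row-wise. This is a sketch with random values; the names W_xr and W_hr follow the text, and the sizes D=3, H=4 are assumed for illustration.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
k_w, k_x, k_h = jax.random.split(jax.random.PRNGKey(0), 3)
W_r = jax.random.normal(k_w, (D + H, H))
b_r = jnp.zeros(H)
x = jax.random.normal(k_x, (D,))
h_prev = jax.random.normal(k_h, (H,))

# Split the fused matrix row-wise: the first D rows act on x, the last H on h_prev.
W_xr, W_hr = W_r[:D], W_r[D:]          # shapes (D, H) and (H, H)

fused = jnp.concatenate([x, h_prev]) @ W_r + b_r
split = x @ W_xr + h_prev @ W_hr + b_r
max_diff = jnp.max(jnp.abs(fused - split))   # only float rounding differences
```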
Worked walk-through
D=3, H=4, x=[0.1, 0.2, 0.3], h_prev=[0, 0, 0, 0]:
- xh = [0.1, 0.2, 0.3, 0, 0, 0, 0], shape (7,).
- r = sigmoid(xh @ W_r + b_r) → (4,).
- z = sigmoid(xh @ W_z + b_z) → (4,).
- xrh = [0.1, 0.2, 0.3, 0, 0, 0, 0] (since r * 0 = 0), shape (7,).
- h_tilde = tanh(xrh @ W_h + b_h) → (4,).
- h = (1 - z) * h_prev + z * h_tilde = z * h_tilde (since h_prev = 0).
Three weight matrices (D+H, H) and three biases (H,): six
parameter arrays in the cell. With D=3 and H=4 that is
3 * (7 * 4) + 3 * 4 = 96 scalars.
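The walk-through can be replayed with concrete arrays. The sketch below assumes jax.nn.initializers.lecun_normal() for the weights and zero biases, and stores the six arrays in a plain dict whose key names are illustrative.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
w_init = jax.nn.initializers.lecun_normal()
keys = jax.random.split(jax.random.PRNGKey(0), 3)

# Six parameter arrays: one (D + H, H) matrix and one (H,) bias per piece.
params = {}
for name, k in zip(("r", "z", "h"), keys):
    params[f"W_{name}"] = w_init(k, (D + H, H))
    params[f"b_{name}"] = jnp.zeros((H,))

n_arrays = len(params)                               # 6
n_scalars = sum(p.size for p in params.values())     # 3*7*4 + 3*4 = 96

x = jnp.array([0.1, 0.2, 0.3])
h_prev = jnp.zeros(H)
xh = jnp.concatenate([x, h_prev])                              # (7,)
r = jax.nn.sigmoid(xh @ params["W_r"] + params["b_r"])         # (4,)
z = jax.nn.sigmoid(xh @ params["W_z"] + params["b_z"])         # (4,)
xrh = jnp.concatenate([x, r * h_prev])                         # same as xh here
h_tilde = jnp.tanh(xrh @ params["W_h"] + params["b_h"])        # (4,)
h = (1.0 - z) * h_prev + z * h_tilde                           # reduces to z * h_tilde
```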
Common pitfalls
- Writing (1-z) * h_tilde + z * h_prev instead of (1-z) * h_prev + z * h_tilde: it's only a sign convention, but the form used here is (1-z) * h_prev + z * h_tilde. Doing it backwards makes z=0 mean “fully replace,” which inverts the gate semantics described above.
- Forgetting r * h_prev in the candidate: passes the raw hidden through, losing the reset gate’s role.
- Wrong activations: gates are SIGMOID ((0, 1) mixing coefficients); candidate is TANH ((-1, 1) proposed values).
- Single shared weight matrix for all three pieces: each gate / candidate has its own params. Never share.
Problem
Implement gru_cell_step(seed, x, h_prev, hidden) as a single
GRU step from scratch using self.param:
- MyGRUCell(nn.Module) with hidden field.
- Inside @nn.compact: declare three weight matrices (D+H, H) and three biases (H,) for r, z, h_tilde.
- Compute the four equations above.
- Init/apply, return new h flattened.
Init weights with nn.initializers.lecun_normal(), biases with
nn.initializers.zeros.
Inputs:
- seed: int.
- x: 1-D (D,).
- h_prev: 1-D (H,).
- hidden: int H.
Output: 1-D (H,) — the new hidden state.