LSTM Cell Step
Why this matters
The LSTM (Hochreiter & Schmidhuber, 1997) is the older, slightly
fancier cousin of the GRU. It tracks two persistent vectors
instead of one: a hidden state h (the output) and a cell state
c (the long-term memory). The cell state passes through each
step nearly unchanged — multiplied by a forget gate, plus a small
additive update — which is what gives LSTMs their famous
long-range memory.
LSTMs powered most state-of-the-art seq2seq models from roughly 2014 to 2017 (Google Translate, the original Show-and-Tell captioning model, early speech recognizers) until Transformers displaced them. They still appear in:
- Latency-critical streaming inference (each step is O(1) work).
- On-device speech (small param counts, pure-state recurrence).
- Tasks where data is genuinely sequential and small.
The four gates
Each LSTM step takes (x, h_prev, c_prev) and produces (h, c):
f = σ(W_f [x; h_prev] + b_f) # forget gate (H,)
i = σ(W_i [x; h_prev] + b_i) # input gate (H,)
o = σ(W_o [x; h_prev] + b_o) # output gate (H,)
g = tanh(W_g [x; h_prev] + b_g) # candidate cell (H,)
c = f * c_prev + i * g # update cell (H,)
h = o * tanh(c) # squash + gate (H,)
Read it in plain language:
- Forget what’s no longer relevant (f * c_prev).
- Add something new the input wants to write (i * g).
- Output a squashed view of the (new) cell, masked by o.
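The six equations above can be sketched as one function in plain JAX. This is a minimal illustration, not the reference solution: the params dict and its key names (W_f, b_f, and so on) are assumptions made for the example, and all-zero parameters are used only because they make the result easy to check by hand.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, x, h_prev, c_prev):
    """One LSTM step; `params` is a hypothetical dict of W_*/b_* arrays."""
    xh = jnp.concatenate([x, h_prev], axis=-1)               # (D + H,)
    f = jax.nn.sigmoid(xh @ params["W_f"] + params["b_f"])   # forget gate
    i = jax.nn.sigmoid(xh @ params["W_i"] + params["b_i"])   # input gate
    o = jax.nn.sigmoid(xh @ params["W_o"] + params["b_o"])   # output gate
    g = jnp.tanh(xh @ params["W_g"] + params["b_g"])         # candidate cell
    c = f * c_prev + i * g                                   # new cell state
    h = o * jnp.tanh(c)                                      # new hidden state
    return h, c


D, H = 3, 4
# All-zero parameters make every gate exactly 0.5 and the candidate exactly 0.
params = {f"W_{k}": jnp.zeros((D + H, H)) for k in "fiog"}
params.update({f"b_{k}": jnp.zeros((H,)) for k in "fiog"})

x = jnp.array([0.1, 0.2, 0.3])
h_prev = jnp.zeros(H)
c_prev = jnp.ones(H)
h, c = lstm_step(params, x, h_prev, c_prev)
```

With these placeholder parameters the cell state is simply halved (c = 0.5 * c_prev) and h = 0.5 * tanh(c), which matches the "forget, add, output" reading.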
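The cell-state recurrence c = f * c_prev + i * g is additive: along the
direct path, the gradient of c with respect to c_prev is just f. When
f ≈ 1, the cell state and its gradient pass through a step almost
unchanged instead of vanishing. That’s the whole point.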
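The gradient claim can be checked numerically. The sketch below holds the gate values fixed (an assumption for illustration; in a real LSTM the gates also depend on h_prev, which adds further gradient paths) and uses made-up gate values to show that the Jacobian of c with respect to c_prev is diagonal with f on the diagonal.

```python
import jax
import jax.numpy as jnp


def cell_update(c_prev, f, i, g):
    # Cell-state recurrence from the text, with the gate values held fixed.
    return f * c_prev + i * g


f = jnp.array([0.99, 0.9, 0.5, 0.1])   # illustrative forget-gate values
i = jnp.array([0.2, 0.2, 0.2, 0.2])
g = jnp.array([0.5, -0.5, 0.5, -0.5])
c_prev = jnp.array([1.0, 2.0, 3.0, 4.0])

# Jacobian of c with respect to c_prev: a diagonal matrix with f on the diagonal.
J = jax.jacobian(cell_update)(c_prev, f, i, g)

# Over T steps with a constant forget gate, the direct-path gradient scales like f ** T.
T = 50
scale_after_T = f ** T
```

Units whose forget gate stays near 1 keep a usable gradient after 50 steps, while units with f = 0.5 or lower have effectively none.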
Same [x; h] concat trick as GRU
Like the GRU, the canonical implementation concatenates x and
h_prev once, then uses one (D + H, H) weight per gate:
xh = jnp.concatenate([x, h_prev], axis=-1) # (D + H,)
f = jax.nn.sigmoid(xh @ W_f + b_f) # (H,)
Some implementations stack all four weights into one
(D + H, 4 * H) matrix and split the result; this problem keeps
them separate for clarity (8 params vs 2). Both are equivalent.
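The equivalence of the two layouts can be seen in a few lines. This sketch uses random placeholder weights and omits biases for brevity; the gate order f, i, o, g in the stacked matrix is an assumption, since any fixed order works as long as the split matches it.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
# Four separate (D + H, H) weights, in the order f, i, o, g.
Ws = [jax.random.normal(k, (D + H, H)) for k in keys[:4]]
xh = jax.random.normal(keys[4], (D + H,))

# Separate form: one matmul per gate.
separate = [xh @ W for W in Ws]

# Stacked form: one (D + H, 4 * H) matmul, then split into four (H,) chunks.
W_all = jnp.concatenate(Ws, axis=1)
stacked = jnp.split(xh @ W_all, 4, axis=-1)
```

Each chunk of the stacked result matches the corresponding separate pre-activation up to floating-point error.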
Worked walk-through
D=3, H=4, x=[0.1, 0.2, 0.3], h_prev=[0]*4, c_prev=[0]*4:
- xh = [0.1, 0.2, 0.3, 0, 0, 0, 0], shape (7,).
- f, i, o = sigmoid(xh @ W_·) → three (4,) gate vectors.
- g = tanh(xh @ W_g + b_g) → (4,).
- c = f * 0 + i * g = i * g → (4,).
- h = o * tanh(c) → (4,).
- Concatenate [h, c] → (8,), return.
The output for this problem is the concatenation [h, c]
(length 2H) so a tester can verify both pieces.
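The walk-through can be traced in JAX as follows. The weights here are random placeholders with zero biases (an assumption; the actual values depend on the seed and initializer), so only the shapes and the identity c = i * g carry over to the real problem.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
x = jnp.array([0.1, 0.2, 0.3])
h_prev = jnp.zeros(H)
c_prev = jnp.zeros(H)

# Placeholder random weights and zero biases, for shape checking only.
k_f, k_i, k_o, k_g = jax.random.split(jax.random.PRNGKey(0), 4)
W_f, W_i, W_o, W_g = (
    jax.random.normal(k, (D + H, H)) for k in (k_f, k_i, k_o, k_g)
)
b = jnp.zeros(H)

xh = jnp.concatenate([x, h_prev])    # shape (7,)
f = jax.nn.sigmoid(xh @ W_f + b)     # (4,)
i = jax.nn.sigmoid(xh @ W_i + b)     # (4,)
o = jax.nn.sigmoid(xh @ W_o + b)     # (4,)
g = jnp.tanh(xh @ W_g + b)           # (4,)
c = f * c_prev + i * g               # equals i * g because c_prev is zero
h = o * jnp.tanh(c)                  # (4,)
out = jnp.concatenate([h, c])        # (8,)
```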
Common pitfalls
- Wrong gate activations: f, i, o are SIGMOID; g is TANH. Easy to slip a tanh where a sigmoid belongs and vice versa.
- h = tanh(c) * o vs h = o * tanh(c): same thing, pointwise, but a common tab-completion mistake is to forget the tanh(c) entirely and write h = o * c. That drops the squashing, so h is no longer bounded to (-1, 1) and the output will not match.
- Sharing weights across gates: each of W_f, W_i, W_o, W_g is independent.
- Initial cell state: tests pass c_prev. Don’t assume it’s zero; use whatever’s given.
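The missing-tanh pitfall is easy to see with a few made-up values. In this small sketch, o and c are illustrative numbers rather than outputs of a trained cell.

```python
import jax.numpy as jnp

o = jnp.array([0.9, 0.9, 0.9])
c = jnp.array([0.1, 3.0, -10.0])   # cell values can grow well beyond [-1, 1]

h_correct = o * jnp.tanh(c)   # every entry stays inside (-1, 1)
h_buggy = o * c               # missing tanh: magnitude grows with c
```

The two agree closely only when c is near zero, which is why the bug can slip past a test that uses small inputs.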
Problem
Implement lstm_cell_step(seed, x, h_prev, c_prev, hidden)
using self.param:
- MyLSTMCell(nn.Module) with hidden field.
- Inside @nn.compact: declare four weights (D+H, H) and four biases (H,) for f, i, o, g.
- Compute f, i, o, g, then c, then h.
- Return concat([h, c]) flattened, length 2H.
Init weights with nn.initializers.lecun_normal(), biases with
nn.initializers.zeros.
Inputs:
- seed: int.
- x: 1-D (D,).
- h_prev: 1-D (H,).
- c_prev: 1-D (H,).
- hidden: int H.
Output: 1-D (2H,) — [h, c] concatenated.