
LSTM Cell Step

Why this matters

The LSTM (Hochreiter & Schmidhuber, 1997) is the older, slightly fancier cousin of the GRU. It tracks two persistent vectors instead of one: a hidden state h (the output) and a cell state c (the long-term memory). The cell state passes through each step nearly unchanged — multiplied by a forget gate, plus a small additive update — which is what gives LSTMs their famous long-range memory.

LSTMs powered most state-of-the-art seq2seq models from roughly 2014 to 2017 (Google Translate's neural system, the original Show-and-Tell captioning model, early speech recognizers) until Transformers displaced them. They still appear in:

  • Latency-critical streaming inference (each step costs a fixed amount of work, independent of how many steps came before).
  • On-device speech (small param counts, pure-state recurrence).
  • Tasks where data is genuinely sequential and small.

The four gates

Each LSTM step takes (x, h_prev, c_prev) and produces (h, c):

f = σ(W_f [x; h_prev] + b_f)    # forget gate    (H,)
i = σ(W_i [x; h_prev] + b_i)    # input gate     (H,)
o = σ(W_o [x; h_prev] + b_o)    # output gate    (H,)
g = tanh(W_g [x; h_prev] + b_g) # candidate cell (H,)

c = f * c_prev + i * g          # update cell    (H,)
h = o * tanh(c)                 # squash + gate  (H,)

Read it in plain language:

  1. Forget what’s no longer relevant (f * c_prev).
  2. Add something new the input wants to write (i * g).
  3. Output a squashed view of the (new) cell, masked by o.
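
The equations above map almost line for line onto jax.numpy. The following is a minimal sketch, not a reference implementation: the `params` dict, its key names, and the 0.1 weight scale are assumptions made only for illustration.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, x, h_prev, c_prev):
    """One LSTM step. `params` is a hypothetical dict holding W_f, b_f, ... arrays."""
    xh = jnp.concatenate([x, h_prev], axis=-1)                # (D + H,)
    f = jax.nn.sigmoid(xh @ params["W_f"] + params["b_f"])    # forget gate
    i = jax.nn.sigmoid(xh @ params["W_i"] + params["b_i"])    # input gate
    o = jax.nn.sigmoid(xh @ params["W_o"] + params["b_o"])    # output gate
    g = jnp.tanh(xh @ params["W_g"] + params["b_g"])          # candidate cell
    c = f * c_prev + i * g                                    # update cell
    h = o * jnp.tanh(c)                                       # squash + gate
    return h, c

# Illustrative sizes and random weights (assumed, not from the problem).
D, H = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = {}
for name, key in zip("fiog", keys):
    params[f"W_{name}"] = 0.1 * jax.random.normal(key, (D + H, H))
    params[f"b_{name}"] = jnp.zeros((H,))

x = jnp.array([0.1, 0.2, 0.3])
h, c = lstm_step(params, x, jnp.zeros(H), jnp.ones(H))
```

Because h is a sigmoid output times a tanh output, every entry of h lies strictly between -1 and 1, whatever the weights are.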

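The cell-state recurrence c = f * c_prev + i * g is additive, and the direct path from c_prev to c is scaled only by f. When f ≈ 1 the old cell state is carried forward almost unchanged, so gradients along that path flow back without vanishing. That’s the whole point.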
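
The gradient claim can be checked numerically. The sketch below holds the gates fixed at made-up values (an assumption: in a real LSTM the gates also depend on h_prev, which adds further gradient paths) and looks only at the direct path from c_prev to c.

```python
import jax
import jax.numpy as jnp

def cell_update(c_prev, f, i, g):
    # Direct cell-state recurrence with the gates treated as constants.
    return f * c_prev + i * g

# Illustrative gate values: open, half-open, and nearly closed forget gates.
f = jnp.array([0.99, 0.5, 0.01])
i = jnp.array([0.1, 0.5, 0.9])
g = jnp.array([0.3, -0.2, 0.7])
c_prev = jnp.array([1.0, -1.0, 2.0])

# Jacobian of c with respect to c_prev (argument 0): a diagonal matrix diag(f).
jac = jax.jacobian(cell_update)(c_prev, f, i, g)

# Over T steps with the same gates, the direct-path gradient scales like f ** T.
T = 50
direct_path_scale = f ** T
```

With these numbers the first unit (f = 0.99) keeps a usable gradient after 50 steps, while the units with smaller f shrink it towards zero, which is the behaviour the forget gate is meant to control.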

Same [x; h] concat trick as GRU

Like the GRU, the canonical implementation concatenates x and h_prev once, then uses one (D + H, H) weight per gate:

xh = jnp.concatenate([x, h_prev], axis=-1)   # (D + H,)
f  = sigmoid(xh @ W_f + b_f)

Some implementations stack all four weights into one (D + H, 4 * H) matrix and split the result; this problem keeps them separate for clarity (eight parameter arrays instead of two). The two layouts compute the same thing.
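
To see that the stacked and separate layouts agree, here is a small sketch with randomly drawn weights. The sizes, the seed, and the f, i, o, g stacking order are assumptions for illustration; real libraries pick their own gate order.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)
Ws = [jax.random.normal(k, (D + H, H)) for k in keys]   # one weight per gate
bs = [jnp.zeros((H,)) for _ in range(4)]                # one bias per gate
xh = jnp.concatenate([jnp.array([0.1, 0.2, 0.3]), jnp.zeros(H)])  # (D + H,)

# Separate layout: four small matmuls.
separate = [xh @ W + b for W, b in zip(Ws, bs)]

# Stacked layout: one (D + H, 4 * H) matmul, then split into four (H,) pieces.
W_all = jnp.concatenate(Ws, axis=-1)
b_all = jnp.concatenate(bs)
stacked = jnp.split(xh @ W_all + b_all, 4, axis=-1)

max_diff = max(float(jnp.max(jnp.abs(a - b))) for a, b in zip(separate, stacked))
```

The pre-activations match up to floating-point rounding. The stacked form mainly saves kernel launches; it does not change the math.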

Worked walk-through

D=3, H=4, x=[0.1, 0.2, 0.3], h_prev=[0]*4, c_prev=[0]*4:

  1. xh = [0.1, 0.2, 0.3, 0, 0, 0, 0], shape (7,).
  2. f, i, o = sigmoid(xh @ W_· + b_·) → three (4,) gate vectors.
  3. g = tanh(xh @ W_g + b_g) → (4,).
  4. c = f * 0 + i * g = i * g → (4,).
  5. h = o * tanh(c) → (4,).
  6. Concatenate [h, c] → (8,), return.

The output for this problem is the concatenation [h, c] (length 2H) so a tester can verify both pieces.
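
The walk-through can be traced in code. This is a sketch under assumptions: the weights are drawn with jax.nn.initializers.lecun_normal() from an arbitrary seed and stored in plain dicts, so the numbers will not match any particular reference solution, only the shapes and the structure.

```python
import jax
import jax.numpy as jnp

D, H = 3, 4
x = jnp.array([0.1, 0.2, 0.3])
h_prev = jnp.zeros(H)
c_prev = jnp.zeros(H)

# Hypothetical parameters: LeCun-normal weights, zero biases.
init_w = jax.nn.initializers.lecun_normal()
keys = jax.random.split(jax.random.PRNGKey(0), 4)
W = {name: init_w(key, (D + H, H)) for name, key in zip("fiog", keys)}
b = {name: jnp.zeros(H) for name in "fiog"}

xh = jnp.concatenate([x, h_prev])               # step 1: shape (7,)
f = jax.nn.sigmoid(xh @ W["f"] + b["f"])        # step 2: three gate vectors
i = jax.nn.sigmoid(xh @ W["i"] + b["i"])
o = jax.nn.sigmoid(xh @ W["o"] + b["o"])
g = jnp.tanh(xh @ W["g"] + b["g"])              # step 3: candidate cell
c = f * c_prev + i * g                          # step 4: equals i * g since c_prev = 0
h = o * jnp.tanh(c)                             # step 5
out = jnp.concatenate([h, c])                   # step 6: shape (8,)
```

Since c_prev is all zeros here, the forget gate has no effect on c; a test with a non-zero c_prev is needed to exercise it.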

Common pitfalls

  • Wrong gate activations: f, i, o are SIGMOID; g is TANH. Easy to slip a tanh where a sigmoid belongs and vice versa.
  • h = tanh(c) * o vs h = o * tanh(c): these are the same thing, since the product is pointwise. The real mistake, common with tab-completion, is to drop tanh(c) entirely and write h = o * c. Then h is no longer bounded in (-1, 1) and will not match the expected output.
  • Sharing weights across gates: each of W_f, W_i, W_o, W_g is independent.
  • Initial cell state: tests pass c_prev. Don’t assume it’s zero — use whatever’s given.
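
The missing-tanh pitfall is easy to spot numerically. A small sketch with made-up values for c and o (both assumptions, chosen so that the cell state is large):

```python
import jax.numpy as jnp

c = jnp.array([5.0, -3.0, 0.5])   # illustrative cell state with large entries
o = jnp.array([0.9, 0.9, 0.9])    # illustrative output gate

h_correct = o * jnp.tanh(c)       # every entry stays inside (-1, 1)
h_buggy = o * c                   # grows with |c|; no squashing

max_correct = float(jnp.max(jnp.abs(h_correct)))
max_buggy = float(jnp.max(jnp.abs(h_buggy)))
```

A quick sanity check along these lines (is every entry of h inside (-1, 1)?) catches this bug even before comparing against expected values.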

Problem

Implement lstm_cell_step(seed, x, h_prev, c_prev, hidden) using self.param:

  1. MyLSTMCell(nn.Module) with hidden field.
  2. Inside @nn.compact: declare four weights (D+H, H) and four biases (H,) for f, i, o, g.
  3. Compute f, i, o, g, then c, then h.
  4. Return concat([h, c]) flattened — length 2H.

Init weights with nn.initializers.lecun_normal(), biases with nn.initializers.zeros.

Inputs:

  • seed: int.
  • x: 1-D (D,).
  • h_prev: 1-D (H,).
  • c_prev: 1-D (H,).
  • hidden: int H.

Output: 1-D (2H,), [h, c] concatenated.

