hard primitives

nn.scan over an RNN cell

Why this matters

A naive RNN unroll in Python is brutal:

h = jnp.zeros((H,))
for t in range(T):
    h, _ = cell.apply(params, h, x[t])  # traced T times: one copy of the cell per step

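For T=2048 (a long context window), tracing this loop under jax.jit records 2048 copies of the cell body, so compile time and program size grow linearly with T. jax.lax.scan solves the problem at the JAX level: it traces the step body once and runs it in a compiled loop. But lax.scan works on plain functions and knows nothing about Flax variable collections, so wrapping a Module's apply in lax.scan means threading the variable dict (and any RNG keys) through the loop by hand.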

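nn.scan is the lifted version of lax.scan for Flax Modules. It takes a Module class and returns a new Module class whose __call__ runs the per-step body under lax.scan. You tell it how to treat each variable collection (params, batch_stats, ...) and each PRNG stream: a collection can be broadcast (one copy shared by every step) or scanned along an axis (one slice per step), and an RNG stream can be split per step or reused.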

The canonical incantation

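import jax
import jax.numpy as jnp
import flax.linen as nn

H, T, D = 32, 128, 16   # hidden size, sequence length, input size
x = jnp.ones((T, D))    # time-major input

ScanGRU = nn.scan(
    nn.GRUCell,
    variable_broadcast="params",
    split_rngs={"params": False},
    in_axes=0,
    out_axes=0,
)
cell = ScanGRU(features=H)
init_carry = jnp.zeros((H,))
params = cell.init(jax.random.PRNGKey(0), init_carry, x)   # x: (T, D)
final_carry, ys = cell.apply(params, init_carry, x)        # ys: (T, H)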

Three knobs to understand:

  • variable_broadcast="params": params are SHARED across timesteps. One set of weights, applied at every t. (The other option, variable_axes={"params": 0}, would create T independent sets; see pos 80.)
  • split_rngs={"params": False}: the param-init RNG is also shared. Every step sees the same "params" key, which is what lets Flax create the broadcast params once, outside the loop, as ONE param tree rather than T copies.
  • in_axes=0 / out_axes=0: the time axis is leading. x[0] is step 0, x[1] is step 1, etc., and per-step outputs are stacked along axis 0. Set in_axes=1 if your input is laid out as (D, T).

What nn.scan returns at apply time

Like lax.scan, the lifted call returns (final_carry, ys):

  • final_carry is the last hidden state: what you'd use for classification heads or for chaining sequence chunks.
  • ys stacks the per-step outputs along out_axes (here axis 0, so shape (T, H)). For a GRU, the per-step output IS the new hidden state, so ys is the full hidden trajectory and ys[-1] equals final_carry.

For this problem we return only the final carry.

Init has to know the carry shape

nn.scan runs lax.scan, which requires the carry’s shape and dtype before the loop starts — Flax can’t infer them from the Module signature. So at init time you pass an init_carry whose shape matches what the cell expects:

init_carry = jnp.zeros((H,))
params = cell.init(rng, init_carry, x)  # x has shape (T, D)

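At init, Flax traces the step body to work out param shapes and hoists the broadcast param initialization out of the loop, so you get a single param tree no matter how long T is. The scan itself still runs over the whole sequence, because init also computes the call's outputs.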

Common pitfalls

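  • Forgetting variable_broadcast="params", or using variable_axes={"params": 0} in its place. The latter gives T independent param sets, which is fine for some uses (per-layer params; pos 80) but wrong for unrolling an RNN, where the same cell runs at every step.
  • split_rngs={"params": True} together with variable_broadcast="params". Each step would get its own init key, so the shared params would depend on a per-step value and Flax can't compute them once outside the loop; expect an error, not T param sets. Keep "params": False for shared params, and split only the streams that should differ per step, e.g. split_rngs={"params": False, "dropout": True} for per-step dropout masks in a cell that uses dropout.
  • Passing x with the wrong leading axis. in_axes=0 means x.shape[0] is the time dimension; if you pass (D, T), scan runs D steps over (T,) slices. If the params were initialized on the (T, D) layout you get a shape error (when T != D); if you also initialized on the wrong layout, it runs silently and computes the wrong thing.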

Problem

Implement scan_rnn_forward(seed, x, hidden):

  1. Build ScanGRU = nn.scan(nn.GRUCell, variable_broadcast="params", split_rngs={"params": False}, in_axes=0, out_axes=0).
  2. Instantiate cell = ScanGRU(features=hidden).
  3. init_carry = jnp.zeros((hidden,)).
  4. params = cell.init(jax.random.PRNGKey(seed), init_carry, x).
  5. Apply: final_carry, _ = cell.apply(params, init_carry, x).
  6. Return final_carry.reshape(-1).

Inputs:

  • seed: int.
  • x: 2-D (T, D) — input sequence.
  • hidden: int H.

Output: 1-D (H,) — the final hidden state after T timesteps.

Hints

flax nn-scan rnn lifted-transforms
