
NNX Scan RNN

Why this matters

Unrolling an RNN in Python is brutal:

import jax.numpy as jnp

h = jnp.zeros((H,))
for t in range(T):
    h, _ = cell(h, x[t])     # each iteration is traced separately: T copies of the cell

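For T=2048, tracing records 2048 separate copies of the cell’s operations, which blows up both compile time and the size of the generated HLO. The fix at the JAX level is jax.lax.scan: it replaces the Python loop with a single scan primitive (lowered to an XLA While loop), traces the body ONCE, and runs the iterations as a compiled loop instead of unrolling them. Unrolling is opt-in, via scan’s unroll argument.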
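To see the difference in trace size, the sketch below compares the jaxpr of a Python-unrolled loop with a lax.scan version. It assumes a hypothetical toy_cell (a plain tanh RNN step) in place of the GRU so that it runs with JAX alone; unrolled and scanned are illustrative names, and jax.make_jaxpr is used only to count the recorded equations.

```python
import jax
import jax.numpy as jnp


def toy_cell(h, x_t, W, U):
    # Hypothetical stand-in for a GRU cell: a plain tanh RNN step.
    new_h = jnp.tanh(W @ h + U @ x_t)
    return new_h, new_h


def unrolled(h0, xs, W, U):
    h = h0
    for t in range(xs.shape[0]):  # Python loop: the cell is traced once per timestep
        h, _ = toy_cell(h, xs[t], W, U)
    return h


def scanned(h0, xs, W, U):
    def step(h, x_t):
        return toy_cell(h, x_t, W, U)  # traced once, reused for every timestep

    h, _ = jax.lax.scan(step, h0, xs)
    return h


T, D, H = 64, 3, 4
h0 = jnp.zeros((H,))
xs = jnp.ones((T, D))
W = 0.1 * jnp.eye(H)
U = 0.1 * jnp.ones((H, D))

unrolled_jaxpr = jax.make_jaxpr(unrolled)(h0, xs, W, U).jaxpr
scanned_jaxpr = jax.make_jaxpr(scanned)(h0, xs, W, U).jaxpr
print(len(unrolled_jaxpr.eqns), len(scanned_jaxpr.eqns))
print([eqn.primitive.name for eqn in scanned_jaxpr.eqns])
```

The unrolled equation count grows linearly with T, while the scan version stays at a handful of top-level equations, with one copy of the cell inside the scan body.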

The interesting part for nnx: lax.scan doesn’t know about Flax Variables. It doesn’t need to, because the step function can close over arbitrary Python objects. Closures work inside JAX traces, as long as you don’t mutate the closed-over state in ways a trace can’t express; for example, writing a traced value into a Variable or a Python list created outside the scan either fails or leaks a tracer.
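A quick check that the body is traced rather than executed per step, and that closing over an array is fine: the sketch below closes over a hypothetical decay array and records a Python-side trace_log in the body. It appends only the static shape, since appending the traced x_t itself would leak a tracer out of the scan.

```python
import jax
import jax.numpy as jnp

decay = jnp.float32(0.5)  # closed-over array: read while tracing, then a constant of the loop
trace_log = []            # plain Python state, touched only when the body is traced


def step(carry, x_t):
    trace_log.append(x_t.shape)  # Python side effect: runs at trace time, not once per step
    new_carry = decay * carry + x_t
    return new_carry, new_carry


xs = jnp.arange(10.0)
final, ys = jax.lax.scan(step, jnp.float32(0.0), xs)
print(len(trace_log), len(xs))  # far fewer trace-time calls than timesteps
```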

For an nnx GRUCell, that closure trick is exactly what we want: define a step(carry, x_t) that calls the cell, and let the cell’s params (which are nnx.Params living on the closed-over instance) flow into the trace as constants for the duration of the scan.

The recipe

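import jax
import jax.numpy as jnp
from flax import nnx

D, H, T, seed = 3, 8, 16, 0                 # example sizes
x = jnp.ones((T, D))                        # time-major input: (T, D)

cell = nnx.GRUCell(in_features=D, hidden_features=H, rngs=nnx.Rngs(seed))
init_carry = jnp.zeros((H,))

def step(carry, x_t):
    new_carry, y = cell(carry, x_t)         # one GRU step; closes over cell
    return new_carry, y

final_carry, ys = jax.lax.scan(step, init_carry, x)   # x: (T, D)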

What lax.scan returns:

  • final_carry — the last hidden state. Shape (H,).
  • ys — the per-step outputs stacked along axis 0. Shape (T, H).

For this problem we return only final_carry.

Why this works in spite of “stateful” nnx

The cell object holds nnx.Param variables such as cell.dense_h.kernel. Inside step, cell(carry, x_t) reads them via attribute access, and each param’s .value is a JAX array. lax.scan traces step once; during that trace the param reads return the current arrays, which become closed-over constants of the scan body. They are not looked up again on each iteration (and, under jax.jit, not on later calls that hit the compilation cache). Nothing mutates them during the scan (we’re not training inside it), so this is safe.
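To check when closed-over values are read, the sketch below uses a hypothetical ParamHolder (a plain Python object standing in for the nnx cell) and a hypothetical make_run factory that jit-compiles a scan closing over it. Changing the attribute after the first call has no effect on the cached compiled function; only a fresh trace picks it up.

```python
import jax
import jax.numpy as jnp


class ParamHolder:
    """Hypothetical stand-in for an nnx cell: a Python object with an array attribute."""

    def __init__(self, w):
        self.w = w


def make_run(holder):
    @jax.jit
    def run(xs):
        def step(c, x_t):
            c = holder.w * c + x_t  # holder.w is read while tracing and baked in as a constant
            return c, c

        final, _ = jax.lax.scan(step, jnp.float32(0.0), xs)
        return final

    return run


holder = ParamHolder(jnp.float32(2.0))
xs = jnp.ones((3,))
run = make_run(holder)

before = run(xs)              # traces with w = 2.0: 0 -> 1 -> 3 -> 7
holder.w = jnp.float32(0.0)
cached = run(xs)              # same shapes, cache hit: still uses w = 2.0
fresh = make_run(holder)(xs)  # new function, new trace: sees w = 0.0
print(float(before), float(cached), float(fresh))  # 7.0 7.0 1.0
```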

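For training, where you DO want gradients w.r.t. the cell params, put the scan inside a loss function that takes the cell as an argument, define step inside it so that it closes over that argument, and differentiate with nnx.value_and_grad outside the scan. The scan is just the forward pass: lax.scan is differentiable, including w.r.t. values its body closes over, so gradients propagate through it like any other JAX op.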

Alternative: nnx.scan

Flax NNX also provides nnx.scan, the lifted version of jax.lax.scan that understands Modules and Variables. It handles the split/merge of module state for you and supports more nuanced patterns (e.g., the carry being a Module). For the simple “close over a stateful cell” case, plain jax.lax.scan is the smaller hammer; when the carry is itself an nnx Module that gets mutated, or the body otherwise updates module state, prefer nnx.scan.

Common pitfalls

  • Wrong layout for x. lax.scan iterates over axis 0; if your input is (D, T), transpose first or it’ll iterate over D.
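  • Building the cell INSIDE step. The body is traced once, so the cell’s construction (parameter initialization included) gets traced into the loop body: the params used inside the scan are created by the body itself, not the params you hold, and might train, outside it. Build the cell ONCE, outside step, and close over it.
  • Returning the wrong thing in the carry slot. The carry that step returns must match init_carry exactly: same pytree structure, shapes, and dtypes. Returning a Python int such as T in that slot breaks the match, and lax.scan rejects the body with a carry-type error. A Python int is fine as a static closed-over constant; it just can’t replace the carry.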
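The layout and carry pitfalls are easy to reproduce with a toy step (step and bad_step are hypothetical; any body works): scanning a (D, T) array yields D steps instead of T, and putting a Python int in the carry slot makes lax.scan reject the body.

```python
import jax
import jax.numpy as jnp

T, D = 7, 3
x_wrong = jnp.ones((D, T))  # feature-major layout: lax.scan would step over D, not T


def step(carry, x_t):
    new_carry = carry + x_t.sum()
    return new_carry, new_carry


_, ys_wrong = jax.lax.scan(step, jnp.float32(0.0), x_wrong)
_, ys_right = jax.lax.scan(step, jnp.float32(0.0), x_wrong.T)  # transpose to (T, D) first
print(ys_wrong.shape, ys_right.shape)  # (3,) (7,)


def bad_step(carry, x_t):
    return 7, x_t  # Python int in the carry slot: no longer matches init_carry (float32)


try:
    jax.lax.scan(bad_step, jnp.float32(0.0), x_wrong.T)
    carry_error = None
except TypeError as err:  # lax.scan reports carry mismatches as TypeError
    carry_error = err
print(carry_error is not None)
```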

Problem

Implement scan_rnn_forward(seed, x, hidden, T):

  1. Cast H = int(hidden), Tt = int(T), s = int(seed).
  2. Build cell = nnx.GRUCell(in_features=x.shape[-1], hidden_features=H, rngs=nnx.Rngs(s)).
  3. init_carry = jnp.zeros((H,)).
  4. Define step(carry, x_t) that calls cell(carry, x_t) and returns (new_carry, y).
  5. final_carry, _ = jax.lax.scan(step, init_carry, x[:Tt]).
  6. Return final_carry.reshape(-1).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (T_total, D) — input sequence (we slice the first Tt timesteps).
  • hidden: float (cast to int) — H.
  • T: float (cast to int) — number of timesteps to scan.

Output: 1-D (H,) — final hidden state after Tt timesteps.

Hints

flax nnx lifted-transforms scan rnn
