NNX Scan RNN
Why this matters
Unrolling an RNN in Python is brutal:
    h = jnp.zeros((H,))
    for t in range(T):
        h, _ = cell(h, x[t])  # T separate Python-level cell calls
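For T=2048, the JAX trace records 2048 distinct cell calls,
blowing up compile time and the size of the generated HLO. The fix at the JAX
level is jax.lax.scan: it replaces the Python loop with a single
scan primitive that lowers to one XLA While loop. The body is traced ONCE,
and the loop runs inside XLA without being unrolled (unless you ask for
unrolling through the unroll argument).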
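To make the difference in trace size concrete, here is a minimal sketch that counts the top-level equations in the jaxpr for both versions. It uses a toy tanh cell with made-up fixed weights (W_h and W_x are assumptions for illustration, not the GRU used on this page), and the exact counts will vary by JAX version, so none are quoted here.

```python
import jax
import jax.numpy as jnp

T, D, H = 64, 3, 4  # illustrative sizes

# Hypothetical fixed weights for a toy tanh cell (not the page's GRUCell).
W_h = jnp.eye(H) * 0.5
W_x = jnp.ones((D, H)) * 0.1

def cell(h, x_t):
    new_h = jnp.tanh(h @ W_h + x_t @ W_x)
    return new_h, new_h

def python_loop(x):
    h = jnp.zeros((H,))
    for t in range(T):  # every iteration is traced separately
        h, _ = cell(h, x[t])
    return h

def scanned(x):
    h, _ = jax.lax.scan(cell, jnp.zeros((H,)), x)  # body traced once
    return h

x = jnp.ones((T, D))
loop_eqns = len(jax.make_jaxpr(python_loop)(x).jaxpr.eqns)
scan_eqns = len(jax.make_jaxpr(scanned)(x).jaxpr.eqns)
print("python loop equations:", loop_eqns)
print("lax.scan equations:   ", scan_eqns)
```

The loop version grows with T, while the scan version stays at a handful of top-level equations regardless of T.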
The interesting part for nnx: lax.scan doesn’t know about Flax
Variables. But it does know that you can close over arbitrary
Python objects in the step function — Python closure works inside
JAX traces, as long as you don’t try to mutate the closed-over state
in ways jit can’t handle.
For an nnx GRUCell, that closure trick is exactly what we want:
define a step(carry, x_t) that calls the cell, and let the cell’s
params (which are nnx.Params living on the closed-over instance)
flow into the trace as constants for the duration of the scan.
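The closure idea does not depend on nnx at all. As a rough illustration, the sketch below uses a hypothetical TinyCell class (a plain Python object holding its weights as attributes, invented here for the example) and closes over an instance of it inside the step function.

```python
import jax
import jax.numpy as jnp

class TinyCell:
    """Hypothetical stand-in for a stateful cell; weights live on the instance."""

    def __init__(self, key, D, H):
        k1, k2 = jax.random.split(key)
        self.w_x = jax.random.normal(k1, (D, H)) * 0.1
        self.w_h = jax.random.normal(k2, (H, H)) * 0.1

    def __call__(self, carry, x_t):
        new_carry = jnp.tanh(x_t @ self.w_x + carry @ self.w_h)
        return new_carry, new_carry

T, D, H = 5, 3, 4  # illustrative sizes
cell = TinyCell(jax.random.PRNGKey(0), D, H)
x = jnp.ones((T, D))

def step(carry, x_t):
    # `cell` is not an argument: it is captured from the enclosing scope,
    # and its weight arrays enter the traced body as constants.
    return cell(carry, x_t)

final_carry, ys = jax.lax.scan(step, jnp.zeros((H,)), x)
print(final_carry.shape, ys.shape)
```

The step function only reads from the captured object; nothing on the instance is reassigned during the scan.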
The recipe
    cell = nnx.GRUCell(in_features=D, hidden_features=H, rngs=nnx.Rngs(seed))
    init_carry = jnp.zeros((H,))

    def step(carry, x_t):
        new_carry, y = cell(carry, x_t)
        return new_carry, y

    final_carry, ys = jax.lax.scan(step, init_carry, x)  # x: (T, D)
What lax.scan returns:
- final_carry: the last hidden state. Shape (H,).
- ys: the per-step outputs stacked along axis 0. Shape (T, H).
For this problem we return only final_carry.
Why this works in spite of “stateful” nnx
The cell object holds nnx.Params like cell.dense_h.kernel.
Inside step, cell(carry, x_t) reads those params via attribute
access; param.value is a JAX array. lax.scan traces step once,
and the arrays read from the closed-over cell during that trace
are captured as constants of the scan body. They don’t change
during the scan (we’re not training inside the scan), so this is
safe.
For training cases where you DO want grads w.r.t. the cell params,
you’d use nnx.value_and_grad outside the scan: the scan is the
forward, and grad propagates through it like any other JAX op.
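To show that differentiation really does pass through the scan, here is a small sketch in plain JAX. It deliberately swaps in jax.value_and_grad over a hypothetical params dict (w_x and w_h are made up for the example) instead of nnx.value_and_grad over a cell; the placement is the same, with the grad transform wrapped around a loss whose forward pass is the scan.

```python
import jax
import jax.numpy as jnp

T, D, H = 5, 3, 4  # illustrative sizes
params = {  # hypothetical toy parameters standing in for the cell's params
    "w_x": jnp.full((D, H), 0.1),
    "w_h": jnp.full((H, H), 0.05),
}
x = jnp.ones((T, D))

def loss_fn(params, x):
    def step(carry, x_t):
        # params are closed over here, exactly like the cell in the recipe
        new_carry = jnp.tanh(x_t @ params["w_x"] + carry @ params["w_h"])
        return new_carry, new_carry

    final_carry, _ = jax.lax.scan(step, jnp.zeros((H,)), x)  # the forward pass
    return jnp.sum(final_carry ** 2)

# The grad transform sits OUTSIDE the scan and differentiates through it.
loss, grads = jax.value_and_grad(loss_fn)(params, x)
print(loss, grads["w_x"].shape, grads["w_h"].shape)
```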
Alternative: nnx.scan
Flax NNX also provides nnx.scan, the lifted version. It handles
the split/merge for you and supports more nuanced patterns (e.g.,
the carry being a Module). For the simple “close over a stateful
cell” case, plain jax.lax.scan is the smaller hammer; for cases
where the carry is itself an nnx Module that gets mutated, prefer
nnx.scan.
Common pitfalls
- Wrong layout for x. lax.scan iterates over axis 0; if your input is (D, T), transpose it first or the scan will iterate over D.
- Building the cell INSIDE step. The parameter initialization would then be traced into the scan body instead of running once up front, and the cell would not be available outside the scan. Build the cell ONCE outside.
- Returning a Python int T from step. The body must return pytrees of arrays; T (a Python int) is fine as a static closed-over constant, but as a return value it confuses the carry’s pytree shape between iterations.
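The layout pitfall is easy to check directly. The toy sketch below counts scan steps in the carry and echoes each slice, so the number of iterations and the shape of each slice are visible; the step body is made up purely for the demonstration.

```python
import jax
import jax.numpy as jnp

T, D = 7, 3
x_wrong = jnp.ones((D, T))  # features first: not the layout lax.scan expects

def step(carry, x_t):
    # Toy body: count the steps in the carry and echo the slice as output.
    return carry + 1, x_t

count0 = jnp.zeros((), dtype=jnp.int32)
n_wrong, ys_wrong = jax.lax.scan(step, count0, x_wrong)    # iterates over D
n_right, ys_right = jax.lax.scan(step, count0, x_wrong.T)  # iterates over T

print(int(n_wrong), ys_wrong.shape)
print(int(n_right), ys_right.shape)
```

With the untransposed input the scan runs D steps and each slice has length T, which a real cell built with in_features=D would reject or silently mishandle depending on the shapes.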
Problem
Implement scan_rnn_forward(seed, x, hidden, T):
- Cast H = int(hidden), Tt = int(T), s = int(seed).
- Build cell = nnx.GRUCell(in_features=x.shape[-1], hidden_features=H, rngs=nnx.Rngs(s)).
- init_carry = jnp.zeros((H,)).
- Define step(carry, x_t) that calls cell(carry, x_t) and returns (new_carry, y).
- final_carry, _ = jax.lax.scan(step, init_carry, x[:Tt]).
- Return final_carry.reshape(-1).
Inputs:
- seed: float (cast to int).
- x: 2-D (T_total, D), the input sequence (we slice the first Tt timesteps).
- hidden: float (cast to int), H.
- T: float (cast to int), the number of timesteps to scan.
Output: 1-D (H,) — final hidden state after Tt timesteps.