nn.scan over an RNN cell
Why this matters
A naive RNN unroll in Python is brutal:
h = jnp.zeros((H,))
for t in range(T):
    h, _ = cell.apply(params, h, x[t])  # T separate Python ops
For T=2048 (a long context window), tracing this loop records 2048
separate copies of the cell, which balloons compile time and bloats
the compiled program.
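To make the cost concrete, here is a small sketch that counts the equations JAX records for an unrolled Python loop versus a jax.lax.scan version of the same recurrence. The toy tanh update and the names step, unrolled, scanned, and count_eqns are illustrative assumptions, not part of the original problem.

```python
import jax
import jax.numpy as jnp

H = 4  # hidden size (illustrative)

def step(h, x_t, w):
    # Toy recurrent update standing in for a real RNN cell (assumption).
    return jnp.tanh(h @ w + x_t)

def unrolled(h, xs, w):
    # Python loop: every iteration is traced separately.
    for t in range(xs.shape[0]):
        h = step(h, xs[t], w)
    return h

def scanned(h, xs, w):
    # lax.scan traces the body once, whatever the sequence length.
    def body(carry, x_t):
        return step(carry, x_t, w), None
    final_h, _ = jax.lax.scan(body, h, xs)
    return final_h

def count_eqns(fn, T):
    # Number of top-level equations in the traced program for length T.
    h0 = jnp.zeros((H,))
    xs = jnp.ones((T, H))
    w = 0.5 * jnp.eye(H)
    return len(jax.make_jaxpr(fn)(h0, xs, w).jaxpr.eqns)

unrolled_counts = [count_eqns(unrolled, T) for T in (8, 64)]
scanned_counts = [count_eqns(scanned, T) for T in (8, 64)]
print("unrolled:", unrolled_counts)  # grows with T
print("scanned: ", scanned_counts)   # independent of T
```

The exact counts depend on the JAX version, but the unrolled count should grow with T while the scanned count stays fixed.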
jax.lax.scan solves the problem at the JAX level — but it doesn’t
know about Flax params, so naively wrapping a Module’s apply in
lax.scan requires you to thread the variable dict by hand.
nn.scan is the lifted version of lax.scan for Flax Modules.
It takes a Module class, returns a new Module class whose __call__
internally lax.scans the per-step body — and does the right thing
with params, batch_stats, and PRNG keys.
The canonical incantation
ScanGRU = nn.scan(
    nn.GRUCell,
    variable_broadcast="params",
    split_rngs={"params": False},
    in_axes=0,
    out_axes=0,
)
cell = ScanGRU(features=H)
init_carry = jnp.zeros((H,))
final_carry, ys = cell(init_carry, x) # x: (T, D)
Three knobs to understand:
- variable_broadcast="params": params are SHARED across timesteps. One set of weights, applied at every t. (The other option, variable_axes={"params": 0}, would create T independent sets; see pos 80.)
- split_rngs={"params": False}: the param-init RNG is also shared, so init produces ONE param tree, not T copies.
- in_axes=0 / out_axes=0: the time axis is leading. x[0] is step 0, x[1] is step 1, etc. Set to 1 if you prefer (D, T) layout.
What nn.scan returns at apply time
Like lax.scan, the lifted call returns (final_carry, ys):
- final_carry is the last hidden state: what you'd use for classification heads or for chaining sequence chunks.
- ys stacks the per-step outputs along the same leading axis as the input (here axis 0, shape (T, H)). For a GRU, the per-step output IS the new hidden state, so ys is the full hidden trajectory.
For this problem we return only the final carry.
Init has to know the carry shape
nn.scan runs lax.scan, which requires the carry’s shape and
dtype before the loop starts — Flax can’t infer them from the
Module signature. So at init time you pass an init_carry whose
shape matches what the cell expects:
init_carry = jnp.zeros((H,))
params = cell.init(rng, init_carry, x) # x has shape (T, D)
The init runs ONE timestep under the hood (a dry-run with T=1
semantics) to figure out param shapes. Cheap.
Common pitfalls
- Forgetting variable_broadcast="params": without it, you'd get T independent param sets, which is fine for some uses (per-layer params; pos 80) but wrong for unrolling an RNN where the same cell runs at every step.
- split_rngs={"params": True} on an RNN: splits the init RNG across timesteps, producing T param sets the broadcast then has to flatten somehow. Use False for shared params, True only when you genuinely want per-step RNGs (e.g., dropout).
- Passing x with the wrong leading axis: in_axes=0 means x.shape[0] is the time dimension; if you pass (D, T), scan iterates over D, gets D slices of shape (T,), and your shapes go sideways.
Problem
Implement scan_rnn_forward(seed, x, hidden):
- Build ScanGRU = nn.scan(nn.GRUCell, variable_broadcast="params", split_rngs={"params": False}, in_axes=0, out_axes=0).
- Instantiate cell = ScanGRU(features=hidden).
- init_carry = jnp.zeros((hidden,)).
- params = cell.init(PRNGKey(seed), init_carry, x).
- Apply: final_carry, _ = cell.apply(params, init_carry, x).
- Return final_carry.reshape(-1).
Inputs:
- seed: int.
- x: 2-D (T, D), the input sequence.
- hidden: int H.
Output: 1-D (H,) — the final hidden state after T timesteps.