NNX Fori Loop in Module
Why this matters
Sometimes you need a loop inside a Module that:
- Is bounded by a run-time value T, so the number of iterations may not be known at trace time.
- Runs long enough that unrolling it in Python would blow up compile time.
- Only needs the final state, not per-step outputs stacked together.
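jax.lax.fori_loop(lower, upper, body, init_val) is the JAX-level
primitive for this case: the body is traced once and repeated
upper - lower times inside a single compiled loop (an HLO While op).
When lower and upper are known at trace time, JAX implements
fori_loop on top of lax.scan, so reverse-mode autodiff works; with a
traced bound it uses lax.while_loop, which does not support
reverse-mode differentiation. Compared to lax.scan it offers fewer
features (no xs to scan over, no ys stacked), but the overhead is
minimal.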
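A minimal sketch of the fori_loop contract, assuming only jax is installed: body receives the loop index and the carry, and must return a carry with the same shape and dtype. The helper names body and sum_below are hypothetical and exist only for illustration; sum_below shows a bound that is only known at run time, because jit traces its argument.

```python
import jax
import jax.numpy as jnp

def body(i, acc):
    # Hypothetical body: i runs over lower, lower + 1, ..., upper - 1; acc is the carry.
    return acc + i

# Static bounds (Python ints): the trip count is known at trace time.
total_static = jax.lax.fori_loop(0, 5, body, jnp.int32(0))

@jax.jit
def sum_below(upper):
    # Hypothetical helper: `upper` is a tracer here, so the trip count is a run-time value.
    return jax.lax.fori_loop(0, upper, body, jnp.int32(0))

total_traced = sum_below(5)
```

Both calls compute 0 + 1 + 2 + 3 + 4; only the second can change its trip count without retracing.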
The interesting part for nnx: how does fori_loop interact with
nnx.Param? The loop body must close over the params somehow.
The clean trick: extract the underlying JAX array via param.value
BEFORE the loop, close over that, and run the loop on pure arrays.
The recipe
```python
import jax
import jax.numpy as jnp
from flax import nnx

class Recurrence(nnx.Module):
    def __init__(self, hidden, *, rngs):
        key = rngs.params()
        self.W = nnx.Param(jax.random.normal(key, (hidden, hidden)) * 0.1)

    def __call__(self, x, T):
        W = self.W.value  # pluck out the JAX array, plain pytree
        def body(i, state):
            # i is the loop index (unused here); state is the carry
            return jnp.tanh(state @ W + x)
        return jax.lax.fori_loop(0, T, body, x)
```
Why pluck out W = self.W.value?
- self.W is an nnx.Variable wrapper (here an nnx.Param). lax.fori_loop's body operates on a pytree carry (state). Variables are not plain arrays: they have identity, mutation semantics, etc.
- The closure trick: by extracting W = self.W.value BEFORE the loop, you capture the array as a closed-over constant in the tracing scope. Inside body, only state (the input arg) and W (the closed-over array) are accessed. Both are plain arrays.
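Once W has been plucked out, the loop is an ordinary pure function of arrays. The sketch below makes that explicit with a hypothetical recurrence helper (standing in for Recurrence.__call__) that takes W as an argument, so it can be wrapped in jax.jit with T traced. The random W is an assumption standing in for self.W.value.

```python
import jax
import jax.numpy as jnp

def recurrence(W, x, T):
    # Hypothetical pure-function version of Recurrence.__call__.
    # W and x are plain jax.Arrays closed over by body; only state is carried.
    def body(i, state):
        return jnp.tanh(state @ W + x)
    return jax.lax.fori_loop(0, T, body, x)

# T is traced under jit, so the trip count is a run-time value.
recurrence_jit = jax.jit(recurrence)

hidden = 4
W = jax.random.normal(jax.random.PRNGKey(0), (hidden, hidden)) * 0.1  # stand-in for self.W.value
x = jnp.ones((hidden,))
out = recurrence_jit(W, x, 4)
```

Calling recurrence_jit(W, x, 8) afterwards reuses the same compiled loop, because only the value of T changes, not its shape or dtype.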
Tanh recurrence as a concrete example
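The body state ↦ tanh(state @ W + x) is a vanilla RNN cell with
an input skip: at every step, it mixes the current state with the
same externally supplied input x through a learned recurrence
matrix W.
Because W is scaled by 0.1, the recurrent term state @ W stays
small. Since tanh is 1-Lipschitz, one step scales the distance
between two states by at most ||W||_2 (the spectral norm of W).
For modest hidden sizes ||W||_2 is typically below 1, so the update
is a contraction and the state converges to the fixed point
s* = tanh(s* @ W + x) determined by x and W.
Running for T=4 iterations, you should see the state walk away
from the initial x toward that fixed point.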
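To watch that walk toward the fixed point, a sketch can measure how far the state is from satisfying s = tanh(s @ W + x) after T steps. The 4x4 W (scaled by 0.1 as in the recipe) and the x values are arbitrary stand-ins, and run and residual are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

hidden = 4
W = jax.random.normal(jax.random.PRNGKey(0), (hidden, hidden)) * 0.1  # stand-in weights
x = jnp.linspace(-1.0, 1.0, hidden)  # stand-in input

def run(T):
    # Hypothetical helper: T steps of the recurrence, starting from x.
    body = lambda i, state: jnp.tanh(state @ W + x)
    return jax.lax.fori_loop(0, T, body, x)

def residual(state):
    # Hypothetical helper: distance from being a fixed point of state -> tanh(state @ W + x).
    return float(jnp.linalg.norm(state - jnp.tanh(state @ W + x)))

for T in (1, 4, 16, 64):
    print(T, residual(run(T)))

# Spectral norm of W; a value below 1 is enough for the update to be a contraction.
print("||W||_2 =", float(jnp.linalg.norm(W, ord=2)))
```

The residual should shrink roughly geometrically with T, at a rate no worse than ||W||_2.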
Why fori_loop, not Python for?
A Python for _ in range(T) would unroll the body T times in the
trace, and it needs T as a concrete Python int. Fine for T=4,
problematic for T=1000. fori_loop traces the body once and uses an
HLO While to repeat it at run time.
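One way to see the difference in trace size is to count the equations in each jaxpr with jax.make_jaxpr. The equation count is only a proxy for trace and compile cost, not a measurement of compile time. In the sketch below, W, x and T=100 are arbitrary stand-ins, and step, unrolled and looped are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

hidden, T = 4, 100
W = jax.random.normal(jax.random.PRNGKey(0), (hidden, hidden)) * 0.1  # stand-in weights
x = jnp.ones((hidden,))

def step(state):
    # Hypothetical helper: one step of the recurrence.
    return jnp.tanh(state @ W + x)

def unrolled(state):
    # Python loop: the body is traced T times.
    for _ in range(T):
        state = step(state)
    return state

def looped(state):
    # fori_loop: the body is traced once.
    return jax.lax.fori_loop(0, T, lambda i, s: step(s), state)

n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_looped = len(jax.make_jaxpr(looped)(x).jaxpr.eqns)
print(n_unrolled, n_looped)
```

The unrolled count grows linearly with T, while the looped count stays constant.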
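Why not lax.scan? scan also traces once, but it is built around
scanning over xs and stacking per-iteration ys outputs, which costs
extra memory when you emit them. fori_loop is the smaller hammer
when you only want the final carry (with static bounds, JAX in fact
implements it via scan with no ys).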
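For comparison, here is a sketch of the same recurrence written with lax.scan. Emitting the state as y stacks all T intermediate states into ys, while returning None as y stacks nothing and gives the same final carry as fori_loop. W, x and T are arbitrary stand-ins, and step and scan_body are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

hidden, T = 4, 8
W = jax.random.normal(jax.random.PRNGKey(0), (hidden, hidden)) * 0.1  # stand-in weights
x = jnp.ones((hidden,))

def step(state):
    # Hypothetical helper: one step of the recurrence.
    return jnp.tanh(state @ W + x)

# fori_loop: only the final carry comes back.
final_fori = jax.lax.fori_loop(0, T, lambda i, s: step(s), x)

def scan_body(s, _):
    # Hypothetical helper: carry the new state and also emit it as y.
    s_new = step(s)
    return s_new, s_new

# scan with per-step outputs: ys has shape (T, hidden).
final_scan, ys = jax.lax.scan(scan_body, x, None, length=T)

# scan without per-step outputs: returning None as y stacks nothing.
final_scan_no_ys, none_ys = jax.lax.scan(lambda s, _: (step(s), None), x, None, length=T)
```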
Common pitfalls
- Closing over self.W directly inside body. The Variable wrapper has identity semantics, while lax.fori_loop expects body to be a pure function of its carry, whose pytree structure must stay fixed across iterations. Reading self.W.value once outside the loop makes it explicit that W is a loop-invariant constant and avoids ambiguity.
- Mutating self.W inside body. A raw lax.fori_loop body is a pure function, so don't update Variables (such as parameters) inside it; anything that changes from step to step must be returned through the carry. Don't do parameter updates inside fori_loop.
- Off-by-one in lower/upper. fori_loop(0, T, ...) runs T times; fori_loop(1, T, ...) runs T-1 times. The i index passed to body starts at lower and ends at upper - 1.
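The index rules from the off-by-one pitfall can be checked directly with a counter and an index sum. T=5 is an arbitrary choice, and count and sum_indices are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

T = 5

def count(i, c):
    # Hypothetical helper: one increment per iteration.
    return c + 1

def sum_indices(i, c):
    # Hypothetical helper: accumulate the index values that body actually sees.
    return c + i

n_from_0 = jax.lax.fori_loop(0, T, count, jnp.int32(0))       # T iterations
n_from_1 = jax.lax.fori_loop(1, T, count, jnp.int32(0))       # T - 1 iterations
idx_sum = jax.lax.fori_loop(2, T, sum_indices, jnp.int32(0))  # 2 + 3 + 4
```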
Problem
Implement fori_loop_recurrence(seed, x, hidden, T):
- Define a Recurrence Module with:

```python
class Recurrence(nnx.Module):
    def __init__(self, hidden, *, rngs):
        key = rngs.params()
        self.W = nnx.Param(jax.random.normal(key, (hidden, hidden)) * 0.1)
        self.hidden = hidden

    def __call__(self, x, T):
        W = self.W.value
        def body(i, state):
            return jnp.tanh(state @ W + x)
        return jax.lax.fori_loop(0, T, body, x)
```

- Build the model with int(hidden) and nnx.Rngs(int(seed)).
- Compute out = model(x, int(T)) and return out.reshape(-1).
Inputs:
- seed: float (cast to int).
- x: 1-D array of shape (hidden,).
- hidden: float (cast to int), denoted H.
- T: float (cast to int), the number of fori_loop iterations.
Output: 1-D array of shape (H,), the final state.