
NNX Fori Loop in Module

Why this matters

Sometimes you need a loop inside a Module that:

  • Has a trip count T that may be a runtime value, so the number of iterations is not necessarily known at trace time.
  • Is long enough that Python unrolling would blow up compile time.
  • Needs only the final state, not per-step outputs stacked.

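jax.lax.fori_loop(lower, upper, body, init_val) is the JAX-level primitive for this case: the body is traced once and the loop lowers to a single HLO While op that runs upper - lower times. Compared to lax.scan it offers less (no xs to scan over, no ys stacked), but the overhead is minimal. One caveat: reverse-mode differentiation through fori_loop is supported only when the trip count is static, so keep T a Python int if you need gradients with respect to W.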
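As a minimal sketch of the calling convention described above (the function name sum_of_squares is illustrative, not part of the original recipe), the loop below accumulates a running total with a trip count that is only known at runtime:

```python
import jax
import jax.numpy as jnp

def sum_of_squares(n):
    # body receives the loop index i and the carry, and must return a carry
    # with the same shape and dtype as the one it was given.
    def body(i, total):
        return total + i * i
    return jax.lax.fori_loop(0, n, body, jnp.int32(0))

# Under jit, n is a traced value, so the trip count is only known at runtime.
jitted = jax.jit(sum_of_squares)

result_small = jitted(4)    # 0 + 1 + 4 + 9
result_large = jitted(10)   # same traced function, different trip count
print(result_small, result_large)
```

Both calls reuse one trace of body; only the upper bound fed to the loop changes.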

The interesting part for nnx: how does fori_loop interact with nnx.Param? The loop body must close over the params somehow. The clean trick: extract the underlying JAX array via param.value BEFORE the loop, close over that, and run the loop on pure arrays.

The recipe

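import jax
import jax.numpy as jnp
from flax import nnx

class Recurrence(nnx.Module):
    def __init__(self, hidden, *, rngs):
        key = rngs.params()
        # Scale the random init by 0.1 so the recurrent term stays small.
        self.W = nnx.Param(jax.random.normal(key, (hidden, hidden)) * 0.1)

    def __call__(self, x, T):
        W = self.W.value     # pluck out the JAX array, plain pytree
        # body takes (index, carry) and returns the new carry; i is unused here.
        def body(i, state):
            return jnp.tanh(state @ W + x)
        return jax.lax.fori_loop(0, T, body, x)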

Why pluck out W = self.W.value?

  • self.W is an nnx.Variable wrapper, while the body of lax.fori_loop operates on a pytree carry (state) of plain arrays. Variables are not plain array leaves: they carry identity and mutation semantics.
  • The closure trick: by extracting W = self.W.value BEFORE the loop, you capture the array as a closed-over constant in the tracing scope. Inside body, only state (the carry argument) and W (the closed-over array) are accessed, and both are plain arrays.

Tanh recurrence as a concrete example

The body state ↦ tanh(state @ W + x) is a vanilla RNN cell with the input re-injected at every step: each iteration mixes the current state, through a learned recurrence matrix W, with the same externally supplied input x. Because W is scaled by 0.1, the recurrent term state @ W stays small, the update is typically a contraction, and the state converges to a fixed point determined by x and W.

Running for T=4 iterations, you should see the state walk away from the initial x toward that fixed point.
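The convergence claim can be checked on plain arrays. This is a sketch under assumptions rather than the Module itself: W and x are drawn directly with jax.random instead of coming from an nnx.Param, and hidden = 4, the seed, and the helper names step, run, and residual are all illustrative choices.

```python
import jax
import jax.numpy as jnp

hidden = 4  # illustrative size, not prescribed by the text
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (hidden, hidden)) * 0.1  # same 0.1 scaling as the recipe
x = jax.random.normal(key_x, (hidden,))

def step(state):
    # One application of the recurrence body.
    return jnp.tanh(state @ W + x)

def run(T):
    # T is a Python int here, so the trip count is static.
    return jax.lax.fori_loop(0, T, lambda i, s: step(s), x)

def residual(state):
    # Distance between a state and its successor; zero exactly at a fixed point.
    return float(jnp.linalg.norm(step(state) - state))

state_4 = run(4)
state_50 = run(50)
print(residual(x), residual(state_4), residual(state_50))
```

The residual is a convenient proxy for "distance to the fixed point": it should shrink as T grows, without having to know the fixed point in closed form.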

Why fori_loop, not Python for?

A Python for _ in range(T) would unroll the body T times in the trace. Fine for T=4, problematic for T=1000. fori_loop traces the body once and uses an HLO While to repeat it at runtime.

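Why not lax.scan? scan also traces the body once, but it is built around per-step inputs xs and stacked per-step outputs ys. Stacking ys costs memory you don’t need here (you can return None as y to skip it), so fori_loop is the smaller hammer when you only want the final carry.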
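To make the comparison concrete, here is a small sketch that runs the same recurrence both ways. The 2x2 matrix W, the input x, T = 6, and the helper names cell and scan_body are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical small values chosen only for this comparison.
W = jnp.array([[0.1, -0.2], [0.05, 0.1]])
x = jnp.array([0.5, -0.3])
T = 6

def cell(state):
    return jnp.tanh(state @ W + x)

# fori_loop: the body returns only the new carry.
final_fori = jax.lax.fori_loop(0, T, lambda i, s: cell(s), x)

# scan: the body returns (carry, y), and scan stacks every y
# along a new leading axis of length T.
def scan_body(state, _):
    new_state = cell(state)
    return new_state, new_state

final_scan, history = jax.lax.scan(scan_body, x, xs=None, length=T)
print(final_fori, final_scan, history.shape)
```

Both loops end at the same final carry; scan additionally materializes a (T, 2) history array, which is the memory cost fori_loop avoids.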

Common pitfalls

  • Closing over self.W directly inside body. The Variable wrapper has identity semantics; lax.fori_loop traces under the assumption that the carry’s pytree structure is fixed across iterations. Reading self.W.value once outside the loop avoids ambiguity.
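  • Mutating self.W inside body. jax.lax.fori_loop is a plain JAX primitive, not an NNX lifted transform, so it does not propagate Variable updates out of the loop; an assignment inside body would try to write a tracer into the Variable. Don’t do parameter updates inside fori_loop.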
  • Off-by-one in lower/upper. fori_loop(0, T, ...) runs T times; fori_loop(1, T, ...) runs T-1 times. The i index passed to body starts at lower and ends at upper - 1.
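The off-by-one behaviour in the last bullet can be checked directly. In this sketch (the helper name visit and T = 5 are illustrative), the carry counts iterations and marks each index that body receives:

```python
import jax
import jax.numpy as jnp

T = 5

def visit(lower, upper):
    # Carry is (iteration count, per-index marker array).
    def body(i, carry):
        count, visited = carry
        return count + 1, visited.at[i].set(1)
    init = (jnp.int32(0), jnp.zeros(upper, dtype=jnp.int32))
    return jax.lax.fori_loop(lower, upper, body, init)

count_from_0, visited_from_0 = visit(0, T)
count_from_1, visited_from_1 = visit(1, T)
print(count_from_0, visited_from_0)
print(count_from_1, visited_from_1)
```

Starting at lower = 1 skips index 0 and runs one iteration fewer, while index upper itself is never visited in either case.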

Problem

Implement fori_loop_recurrence(seed, x, hidden, T):

  1. Define a Recurrence Module with:

    class Recurrence(nnx.Module):
        def __init__(self, hidden, *, rngs):
            key = rngs.params()
            self.W = nnx.Param(jax.random.normal(key, (hidden, hidden)) * 0.1)
            self.hidden = hidden
        def __call__(self, x, T):
            W = self.W.value
            def body(i, state):
                return jnp.tanh(state @ W + x)
            return jax.lax.fori_loop(0, T, body, x)

  2. Build the model with int(hidden) and nnx.Rngs(int(seed)).

  3. Call model(x, int(T)) and return out.reshape(-1).

Inputs:

  • seed: float (cast to int).
  • x: 1-D (hidden,).
  • hidden: float (cast to int) — H.
  • T: float (cast to int) — number of fori_loop iterations.

Output: 1-D (H,) — final state.

