hard primitives

jax.lax.scan inside a Flax Module

Why this matters

Most of the time, when you want to roll a Flax Module over a time axis, you reach for nn.scan (pos 79) — it lifts a Module’s __call__ to a scanned version while handling param/RNG bookkeeping.

But sometimes you DON’T need the lift. If your loop body:

  • doesn’t allocate new Flax variables per step,
  • just consumes a single weight matrix already declared as a self.param,
  • and produces a pure-JAX computation (tanh, matmul, etc.),

then you can use plain jax.lax.scan inside the Module’s __call__, with the body closing over self.param("W", ...).

Pros:

  • Slightly less ceremony — no nn.scan wrapper, no thinking about variable_axes or split_rngs.
  • Cleaner mental model when the body is “pure JAX with one closed-over weight matrix”.

Cons:

  • You’re responsible for shape consistency yourself. nn.scan catches a few classes of bugs at lift time; raw lax.scan happily runs into them at trace time.
  • The body cannot allocate nn.Dense(...), self.param(...), etc. Anything Flax-variable-aware MUST live OUTSIDE the lax.scan body.

Pattern

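import jax
import jax.numpy as jnp
import flax.linen as nn

class LaxScanModule(nn.Module):
    hidden: int
    T: int

    @nn.compact
    def __call__(self, x):
        H = self.hidden
        # Allocate the single weight matrix once, outside the scan body.
        W = self.param("W", nn.initializers.lecun_normal(), (H, H))
        init_carry = jnp.tanh(x @ W)

        def body(carry, _):
            # Pure JAX: W is a closed-over array, not a Flax variable lookup.
            new_carry = jnp.tanh(carry @ W)
            return new_carry, None  # (carry, per-step output)

        # jnp.arange(self.T) only sets the number of steps; its values are unused.
        final, _ = jax.lax.scan(body, init_carry, jnp.arange(self.T))
        return final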

Three things to notice:

  1. self.param("W", ...) is allocated ONCE, outside lax.scan. The body closes over W; lax.scan sees a constant from JAX’s perspective.
  2. The body returns (carry, output_or_None). With output_or_None = None, lax.scan collects an empty pytree of outputs.
  3. xs = jnp.arange(self.T) is a length-T dummy “input” that just drives the loop count. We don’t read its values; lax.scan needs SOMETHING with a leading axis of size T (or, alternatively, xs=None together with length=T).
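
As a minimal sketch of point 3, the two ways of setting the step count can be compared in plain JAX, outside any Module. The sizes H and T, the random W, and the starting vector h0 are illustrative assumptions, not values from the text above.

```python
import jax
import jax.numpy as jnp

H, T = 4, 3  # illustrative sizes (assumption)
# Stand-in for the Module's W; any (H, H) array works for this comparison.
W = jax.random.normal(jax.random.PRNGKey(0), (H, H)) / jnp.sqrt(H)
h0 = jnp.ones((H,))

def body(carry, _):
    # The second argument is ignored in both variants.
    return jnp.tanh(carry @ W), None

# Variant 1: a dummy xs whose leading axis has size T.
final_xs, ys_xs = jax.lax.scan(body, h0, jnp.arange(T))
# Variant 2: no xs at all, just an explicit step count.
final_len, ys_len = jax.lax.scan(body, h0, None, length=T)

print(jnp.allclose(final_xs, final_len), ys_xs, ys_len)
```

Both calls run the body T times. Since the body returns None as its per-step output, the second element of each result is None as well.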

Why no variable_broadcast / split_rngs?

Those knobs are how nn.scan decides whether to give the body per-step OR shared params. jax.lax.scan doesn’t know what a Flax variable is; it just sees a pure function with closed-over constants. So the question doesn’t arise: there’s exactly one W, used at every step, because the body references the same array at each of the T steps.

Carry shape rules

lax.scan requires the body’s output carry to have the SAME pytree structure / shape / dtype as the input carry. In our case:

  • init_carry = jnp.tanh(x @ W) — shape (H,) since x is (H,) and W is (H, H).
  • body’s new_carry = jnp.tanh(carry @ W) — also (H,).

Match. If the input x were a different shape, you’d want to project it to (H,) first (here we sidestep that by requiring x.shape[-1] == hidden).
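
A cheap way to check the carry rule before running the scan is jax.eval_shape, which traces a function abstractly and reports only shapes and dtypes. The sketch below uses an identity matrix as a stand-in for W and an illustrative H; both are assumptions.

```python
import jax
import jax.numpy as jnp

H = 4                 # illustrative size (assumption)
W = jnp.eye(H)        # stand-in for the Module's (H, H) weight
x = jnp.zeros((H,))   # input with x.shape[-1] == H

def body(carry, _):
    return jnp.tanh(carry @ W), None

init_carry = jnp.tanh(x @ W)
# eval_shape evaluates body abstractly: no real compute, just shapes and dtypes.
out_carry, _ = jax.eval_shape(body, init_carry, jnp.int32(0))

print(init_carry.shape, init_carry.dtype)
print(out_carry.shape, out_carry.dtype)
```

If the two printed lines differ in shape or dtype, lax.scan would reject the body for the same reason.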

Common pitfalls

  • Allocating an nn.Dense inside body: the submodule’s parameters would be created under the lax.scan trace, and Flax rejects this (typically with a JaxTransformError saying that JAX transforms and Flax models cannot be mixed). Move all Flax variable allocation OUTSIDE the body.
  • Mismatched carry shapes: body returns a different shape than init_carry. JAX raises a TypeError along the lines of “carry input and carry output must have equal types”. Print shapes during dev.
  • Forgetting the (carry, output) tuple in the body: lax.scan requires the body to return exactly this pair. return new_carry alone is wrong; use return new_carry, None.
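
The last two pitfalls can be reproduced in a few lines of plain JAX. This sketch assumes JAX reports both as TypeError, which matches recent releases; the exact wording of the messages varies by version. W_bad and the two body functions are hypothetical names made up for the demonstration.

```python
import jax
import jax.numpy as jnp

H, T = 4, 3  # illustrative sizes (assumption)
W = jnp.eye(H)
W_bad = jnp.ones((H, H + 1))  # deliberately wrong: maps (H,) to (H + 1,)
h0 = jnp.zeros((H,))

def body_shape_mismatch(carry, _):
    return jnp.tanh(carry @ W_bad), None  # carry changes shape

def body_no_tuple(carry, _):
    return jnp.tanh(carry @ W)  # missing the (carry, output) pair

errors = {}
for name, body in [("shape_mismatch", body_shape_mismatch),
                   ("no_tuple", body_no_tuple)]:
    try:
        jax.lax.scan(body, h0, jnp.arange(T))
    except TypeError as err:
        # Keep only the first line; message wording differs across JAX versions.
        errors[name] = str(err).splitlines()[0]

for name, msg in errors.items():
    print(f"{name}: {msg}")
```

Both failures surface while lax.scan traces the body, before any step of the loop actually runs.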

Problem

Implement lax_scan_in_module(seed, x, hidden, T):

  1. Define LaxScanModule(nn.Module) with fields hidden: int and T: int.
  2. In @nn.compact __call__(self, x):
    • Allocate W = self.param("W", nn.initializers.lecun_normal(), (hidden, hidden)).
    • init_carry = jnp.tanh(x @ W).
    • Define body(carry, _) that returns (jnp.tanh(carry @ W), None).
    • Run final, _ = jax.lax.scan(body, init_carry, jnp.arange(T)).
    • Return final.
  3. In lax_scan_in_module: build the model, init with PRNGKey(seed) and x, apply, and return the result.

Assume x.shape[-1] == hidden.

Inputs:

  • seed: int.
  • x: 1-D (hidden,) initial input.
  • hidden: int H.
  • T: int — number of recurrence steps.

Output: 1-D (H,) final hidden state.

Hints

flax jax-lax-scan recurrence primitives
