jax.lax.scan inside a Flax Module
Why this matters
Most of the time, when you want to roll a Flax Module over a time
axis, you reach for nn.scan (pos 79) — it lifts a Module’s
__call__ to a scanned version while handling param/RNG bookkeeping.
But sometimes you DON’T need the lift. If your loop body:
- doesn’t allocate new Flax variables per step,
- just consumes a single weight matrix already declared with self.param, and
- produces a pure-JAX computation (tanh, matmul, etc.),
then you can use plain jax.lax.scan inside the Module’s __call__, with the body closing over the array returned by self.param("W", ...).
Pros:
- Slightly less ceremony: no nn.scan wrapper, no thinking about variable_axes or split_rngs.
- A cleaner mental model when the body is “pure JAX with one closed-over weight matrix”.
Cons:
- You’re responsible for shape consistency yourself. nn.scan catches a few classes of bugs at lift time; raw lax.scan only runs into them at trace time.
- The body cannot allocate nn.Dense(...), self.param(...), etc. Anything Flax-variable-aware MUST live OUTSIDE the lax.scan body.
Pattern
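```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class LaxScanModule(nn.Module):
    hidden: int
    T: int

    @nn.compact
    def __call__(self, x):
        H = self.hidden
        # Allocated once, outside the scan body.
        W = self.param("W", nn.initializers.lecun_normal(), (H, H))
        init_carry = jnp.tanh(x @ W)  # (H,) @ (H, H) -> (H,)

        def body(carry, _):
            # Pure JAX: closes over W, allocates no Flax variables.
            new_carry = jnp.tanh(carry @ W)
            return new_carry, None

        final, _ = jax.lax.scan(body, init_carry, jnp.arange(self.T))
        return final
```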
Three things to notice:
- self.param("W", ...) is called ONCE, outside lax.scan. The body closes over W, so from JAX’s perspective lax.scan sees a constant captured by the body rather than a per-step input.
- The body returns (carry, output_or_None). With output_or_None = None, lax.scan collects an empty pytree of outputs, so its second return value is also None.
- xs = jnp.arange(self.T) is a length-T dummy “input” that only drives the loop count. We never read its values; lax.scan needs SOMETHING with a leading axis of size T (or, alternatively, length=T).
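The second and third points can be checked with plain JAX, no Flax needed. This sketch (with an arbitrary random W standing in for the Flax param) runs the same body with a dummy xs and with length=T, and also shows what lax.scan returns when the body emits a per-step output instead of None.

```python
import jax
import jax.numpy as jnp

H, T = 3, 4
W = jax.random.normal(jax.random.PRNGKey(0), (H, H)) / jnp.sqrt(H)
init_carry = jnp.tanh(jnp.ones((H,)) @ W)


def body(carry, _):
    # W is closed over: the same array is reused at every step.
    return jnp.tanh(carry @ W), None


# Variant 1: a dummy xs whose leading axis (size T) sets the trip count.
final_a, ys_a = jax.lax.scan(body, init_carry, jnp.arange(T))

# Variant 2: no xs at all; length=T sets the trip count instead.
final_b, ys_b = jax.lax.scan(body, init_carry, None, length=T)

print(ys_a, ys_b)                      # None None
print(jnp.allclose(final_a, final_b))  # True


# If the body emits a per-step output instead of None,
# lax.scan stacks those outputs along a new leading axis of size T.
def body_with_trace(carry, _):
    new_carry = jnp.tanh(carry @ W)
    return new_carry, new_carry


final_c, trace = jax.lax.scan(body_with_trace, init_carry, None, length=T)
print(trace.shape)                     # (4, 3)
```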
Why no variable_broadcast / split_rngs?
Those knobs are how nn.scan decides whether the body gets per-step or shared params (and per-step or shared RNG streams). jax.lax.scan doesn’t know what a Flax variable is; it just sees a pure function with closed-over constants. So the question doesn’t arise: there’s exactly one W, used at every step, because the body references the same array on each of the T iterations.
Carry shape rules
lax.scan requires the body’s output carry to have the SAME pytree structure, shape, and dtype as the input carry. In our case:
- init_carry = jnp.tanh(x @ W) has shape (H,), since x is (H,) and W is (H, H).
- The body’s new_carry = jnp.tanh(carry @ W) is also (H,).
The shapes match. If x had a different last dimension, you’d want to project it to (H,) first; here we sidestep that by requiring x.shape[-1] == hidden.
Common pitfalls
- Allocating an nn.Dense inside body: Flax refuses to create variables inside a raw JAX transform such as the lax.scan-traced body and raises an error. Move all Flax variable allocation OUTSIDE the body.
- Mismatched carry shapes: the body returns a different shape (or dtype) than init_carry, and JAX raises an error stating that the carry input and output must have equal types. Print shapes during development.
- Forgetting the (carry, output) tuple in the body: lax.scan requires the body to return exactly this pair. return new_carry alone is wrong; use return new_carry, None.
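One cheap way to check shapes during development is to ask jax.eval_shape what the body would return before handing it to lax.scan. carry_types_match is a hypothetical helper; it compares pytree structure, shapes, and dtypes, which approximates the check lax.scan itself performs on the carry.

```python
import jax
import jax.numpy as jnp


def carry_types_match(body, init_carry, x_example):
    """Hypothetical dev helper: True if body's output carry has the same pytree
    structure, shapes, and dtypes as init_carry. jax.eval_shape only traces
    the body abstractly; no array math is executed."""
    out_carry, _ = jax.eval_shape(body, init_carry, x_example)
    if jax.tree_util.tree_structure(init_carry) != jax.tree_util.tree_structure(out_carry):
        return False
    return all(
        a.shape == b.shape and a.dtype == b.dtype
        for a, b in zip(jax.tree_util.tree_leaves(init_carry),
                        jax.tree_util.tree_leaves(out_carry))
    )


H = 3
W = jnp.eye(H)
init_carry = jnp.zeros((H,))
step_input = jnp.zeros((), jnp.int32)  # stands in for one element of jnp.arange(T)


def good_body(carry, _):
    return jnp.tanh(carry @ W), None


def bad_body(carry, _):
    # Bug: the carry grows by one element per step.
    return jnp.concatenate([carry, carry[:1]]), None


print(carry_types_match(good_body, init_carry, step_input))  # True
print(carry_types_match(bad_body, init_carry, step_input))   # False
```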
Problem
Implement lax_scan_in_module(seed, x, hidden, T):
- Define LaxScanModule(nn.Module) with fields hidden: int and T: int.
- In @nn.compact __call__(self, x):
  - Allocate W = self.param("W", nn.initializers.lecun_normal(), (self.hidden, self.hidden)).
  - Compute init_carry = jnp.tanh(x @ W).
  - Define body(carry, _) that returns (jnp.tanh(carry @ W), None).
  - Run final, _ = jax.lax.scan(body, init_carry, jnp.arange(self.T)).
  - Return final.
- In lax_scan_in_module: build the model, initialize it with PRNGKey(seed) and x, apply it to x, and return the result.
Assume x.shape[-1] == hidden.
Inputs:
- seed: int.
- x: 1-D (hidden,) initial input.
- hidden: int H.
- T: int, the number of lax.scan steps applied after the initial jnp.tanh(x @ W).
Output: 1-D (H,) final hidden state.