NNX Scan Layers
Why this matters
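A 96-layer transformer has 96 distinct Block Modules. If they are applied in a plain Python loop, JAX traces 96 copies of the block’s computation into the HLO graph, even though every block is structurally identical, and XLA has to compile all of them. Compile time (and the compiler’s own memory use) grows roughly linearly with depth, so it is about 96× that of a single block. Parameter and optimizer-state memory are the same either way, since every layer keeps its own weights.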
The trick: scan over layers. Stack the per-layer params along a
new leading axis (size = num_layers), scan the per-layer forward
along that axis, and the trace records ONE copy of the block.
Compile time goes from O(L) to roughly O(1) in the number of layers. Each layer still has its own weights; they’re just stored as (L, …) arrays rather than as L separate (…) arrays.
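To make the “one copy in the trace” point concrete, here is a minimal sketch in plain JAX, using jax.lax.scan rather than nnx.scan. It assumes a toy ReLU layer (x @ w + b) whose params are stored as one (L, F, F) kernel stack and one (L, F) bias stack. The names init_stacked_params, layer, unrolled_forward, scanned_forward and top_level_eqns are hypothetical helpers for illustration, and counting top-level jaxpr equations is only a rough proxy for how much XLA has to compile.

```python
import jax
import jax.numpy as jnp


def init_stacked_params(key, num_layers, features):
    # Hypothetical init: layer i's weights are row i of each stacked array.
    k_w, k_b = jax.random.split(key)
    w = jax.random.normal(k_w, (num_layers, features, features)) / jnp.sqrt(features)
    b = 0.01 * jax.random.normal(k_b, (num_layers, features))
    return {"w": w, "b": b}


def layer(x, w, b):
    # One toy layer: affine map followed by ReLU.
    return jnp.maximum(x @ w + b, 0.0)


def unrolled_forward(params, x):
    # Python loop: the trace records one copy of `layer` per layer.
    for i in range(params["w"].shape[0]):
        x = layer(x, params["w"][i], params["b"][i])
    return x


def scanned_forward(params, x):
    # lax.scan slices every leaf of params along axis 0; the body is traced once.
    def body(carry, p):
        return layer(carry, p["w"], p["b"]), None

    out, _ = jax.lax.scan(body, x, params)
    return out


def top_level_eqns(fn, params, x):
    # Number of equations in the outermost jaxpr (sub-jaxprs are not expanded).
    return len(jax.make_jaxpr(fn)(params, x).jaxpr.eqns)


x = jnp.ones((8,))
for num_layers in (2, 8):
    p = init_stacked_params(jax.random.PRNGKey(0), num_layers, 8)
    print(num_layers,
          top_level_eqns(unrolled_forward, p, x),
          top_level_eqns(scanned_forward, p, x))
```

The unrolled count grows with the number of layers while the scanned count stays fixed, and both functions compute the same activation.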
The previous problem used lax.scan over a time axis (one shared
cell, T inputs). This problem uses nnx.scan over a layer axis
(L distinct blocks, one input that flows through them).
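A small sketch of the difference, again in plain jax.lax.scan with made-up sizes F, T and L. In the time-axis scan the weights are shared (closed over) and the inputs are sliced from xs; in the layer-axis scan it is the stacked weights that get sliced, and a single input is carried through. The tanh cell and the ReLU block are illustrative assumptions, not the previous problem’s exact cell.

```python
import jax
import jax.numpy as jnp

F, T, L = 4, 5, 3
key_cell, key_xs, key_stack = jax.random.split(jax.random.PRNGKey(0), 3)

# Time-axis scan: ONE shared weight matrix (closed over), T inputs sliced from xs.
w_shared = jax.random.normal(key_cell, (F, F)) / jnp.sqrt(F)
xs = jax.random.normal(key_xs, (T, F))

def cell(h, x_t):
    h = jnp.tanh(h @ w_shared + x_t)
    return h, h  # (new carry, per-step output)

h_final, hs = jax.lax.scan(cell, jnp.zeros((F,)), xs)  # hs: (T, F)

# Layer-axis scan: L distinct weight matrices sliced from a stack, ONE input.
w_stack = jax.random.normal(key_stack, (L, F, F)) / jnp.sqrt(F)

def block(x, w_i):
    return jnp.maximum(x @ w_i, 0.0), None  # (new carry, no per-layer output)

x_final, _ = jax.lax.scan(block, jnp.ones((F,)), w_stack)  # x_final: (F,)
```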
Step 1: build a stacked block
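Same trick as in the ensemble problem: split_rngs + nnx.vmap over the constructor, so each of the N layers gets its own init key and the resulting params are stacked along axis 0: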
from flax import nnx

N, F, seed = 4, 8, 0  # example sizes; any positive ints work

@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_block(rngs):
    return nnx.Linear(F, F, rngs=rngs)

blocks = make_block(nnx.Rngs(seed))
# blocks.kernel.value.shape == (N, F, F)
# blocks.bias.value.shape == (N, F)
Step 2: scan along the layer axis
nnx.scan is the lifted jax.lax.scan. Through in_axes it gives each argument one of two roles:
- nnx.Carry: this argument flows through the loop. Its shape and dtype must stay the same across iterations. (It plays the role of the carry in lax.scan: the init value going in, the final carry coming out.)
- an integer axis index (e.g. 0): this argument is sliced along that axis on each iteration, and the slice is fed to the body.
So:
import jax

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, block):
    return jax.nn.relu(block(x))

out = forward(x, blocks)  # x: shape (F,); blocks: the stacked block from Step 1
forward is written for ONE layer: it takes x (the carry, i.e. the activation being transformed) and block (one layer’s slice of the stacked blocks). At iteration i, scan slices every parameter of blocks along axis 0 (conceptually blocks[i]), calls forward(x, blocks[i]), and uses the return value as the new x for the next iteration. After N iterations, out is the final activation.
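out_axes=nnx.Carry says “the output is the new carry”: forward returns exactly one value, used as the next x. (If we also wanted the per-layer outputs stacked, we’d write out_axes=(nnx.Carry, 0) and have forward return a (new_x, y) pair; the y values come back stacked along a new leading axis of size N.)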
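Why this compiles faster
XLA sees a single copy of the block’s HLO inside a loop (lax.scan lowers to an HLO while loop), so the compile is roughly O(1) in num_layers. At runtime, each iteration runs the same compiled loop body on that layer’s slice of the params. The saving is in compile time and grows with depth; a ~10× compile-time reduction over a Python-unrolled stack is a rough figure for a transformer-sized model, not a guarantee.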
When you can’t use this
Only blocks that are structurally identical and share the same shape contract across layers can be scanned this way. If layer i has a different out_features than layer i+1 (e.g., a stem→trunk→head architecture), the per-layer params don’t stack into a uniform pytree. For such models, group the layers manually into homogeneous segments and scan each segment separately; the odd layers out (stem, head) stay ordinary modules.
Common pitfalls
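- Forgetting split_rngs. Each layer needs its own init key; if all N layers share one RNG, they end up with identical params, so the stack is one layer applied N times with only one layer’s worth of distinct weights.
- Using in_axes=(0, nnx.Carry) instead of (nnx.Carry, 0). in_axes is positional: entry k describes argument k. For forward(x, block), the first argument x is the carry and the second, block, is the per-iteration slice.
- Mismatched out_axes and return value. out_axes=nnx.Carry (return just the carry) vs out_axes=(nnx.Carry, 0) (return a (carry, y) pair and get the per-layer ys stacked). Pick deliberately, and make the body’s return value match.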
Problem
Implement scan_layers_forward(seed, x, num_layers, features):
- Cast N = int(num_layers), F = int(features), s = int(seed).
- Build a stacked block via split_rngs(splits=N) + nnx.vmap over a constructor that returns nnx.Linear(F, F, rngs=rngs), seeded with nnx.Rngs(s).
- Define forward(x, block): return jax.nn.relu(block(x)), decorated with @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry).
- Call forward(x, blocks) to get the final activation, shape (F,).
- Return out.reshape(-1).
Inputs:
- seed: float (cast to int).
- x: 1-D (F,), the input vector. Its shape is preserved through the layers.
- num_layers: float (cast to int), N.
- features: float (cast to int), F (each layer is F → F).
Output: 1-D (F,), the final activation.