
NNX Scan Layers

Why this matters

A 96-layer transformer has 96 distinct Block Modules. Applying them in a Python loop means JAX traces 96 copies of the block's computation into the HLO graph, even though every block is structurally identical. Trace and compile time therefore grow roughly linearly with depth. Parameter and optimizer-state memory are the same either way, because every layer still owns its weights.

The trick: scan over layers. Stack the per-layer params along a new leading axis (size = num_layers), scan the per-layer forward along that axis, and the trace records ONE copy of the block. The size of the traced program, and with it compile time, becomes roughly O(1) in the number of layers instead of O(L). Each layer still has its own weights; they're just stored as (L, …) arrays rather than as L separate (…) arrays.
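The "one copy" claim can be checked with a minimal pure-JAX sketch. It uses plain lax.scan rather than nnx.scan, a made-up tanh layer, and example sizes L=6, F=4, and it counts the matmul primitives in each traced program. The Python loop records one per layer, while the scan records one in total.

```python
import jax
import jax.numpy as jnp
from jax import lax

L, F = 6, 4  # example sizes: L layers, each F -> F
keys = jax.random.split(jax.random.key(0), L)
# Stacked params: one (L, F, F) array instead of L separate (F, F) arrays.
Ws = jax.vmap(lambda k: jax.random.normal(k, (F, F)))(keys)
W_list = [Ws[i] for i in range(L)]  # the unstacked equivalent

def unrolled(x, W_list):
    for W in W_list:  # Python loop: the tracer records L copies of the layer
        x = lax.tanh(lax.dot(x, W))
    return x

def scanned(x, Ws):
    def body(x, W):  # traced once; W is one (F, F) slice per iteration
        return lax.tanh(lax.dot(x, W)), None
    x, _ = lax.scan(body, x, Ws)
    return x

def count_dots(fn, *args):
    # Count matmul primitives in the traced program (the jaxpr).
    return str(jax.make_jaxpr(fn)(*args)).count("dot_general")

x = jnp.ones(F)
n_unrolled = count_dots(unrolled, x, W_list)
n_scanned = count_dots(scanned, x, Ws)
print(n_unrolled, n_scanned)  # 6 1
```

Both functions compute the same result. Only the size of the traced program differs.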

The previous problem used lax.scan over a time axis (one shared cell, T inputs). This problem uses nnx.scan over a layer axis (L distinct blocks, one input that flows through them).

Step 1: build a stacked block

Same trick as the ensemble problem — split_rngs + vmap over the constructor:

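import jax
from flax import nnx

N, F, seed = 4, 8, 0  # example values: N layers, each F -> F

@nnx.split_rngs(splits=N)            # give each of the N layers its own init key
@nnx.vmap(in_axes=(0,), out_axes=0)  # map the constructor over the split keys
def make_block(rngs):
    return nnx.Linear(F, F, rngs=rngs)

blocks = make_block(nnx.Rngs(seed))
# blocks.kernel.value.shape == (N, F, F)
# blocks.bias.value.shape   == (N, F)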

Step 2: scan along the layer axis

nnx.scan is the lifted version of jax.lax.scan, so it accepts NNX Modules as arguments. Its in_axes and out_axes give each argument one of two roles:

  • nnx.Carry: this argument flows through the loop, and its shape is preserved across iterations. (Like the initial carry passed to lax.scan and the final carry it returns.)
  • any axis index (e.g. 0): this argument is sliced along that axis on each iteration, and the slice is fed to the body. (Like xs in lax.scan.)

So:

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, block):
    return jax.nn.relu(block(x))

out = forward(x, blocks)

forward is written for ONE layer: takes x (the carry — what we’re transforming) and block (one per-layer slice of the stacked blocks). At each iteration, scan picks blocks[i], calls forward(x, blocks[i]), and uses the return value as the new x for the next iteration. After N iterations, out is the final activation.

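out_axes=nnx.Carry says “the output is the new carry”: forward returns exactly one value, which becomes the next x. (If we also wanted the per-layer outputs stacked, we'd write out_axes=(nnx.Carry, 0) and have forward return a (carry, y) tuple. Scan then stacks the y values along a new leading axis.)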

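Why this compiles faster

XLA sees a single block's computation inside one loop (lax.scan lowers to a while loop), so compile time is roughly constant in num_layers. At runtime, each iteration runs the same compiled body on that layer's slice of the params. The compile-time savings over a Python-unrolled stack grow with depth, so they matter most for deep, transformer-sized models. Runtime throughput is usually similar to the unrolled version. It can be somewhat lower, because XLA cannot optimize across layer boundaries inside the loop.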
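The loop structure is visible in the lowered program. This pure-JAX sketch uses example sizes, and unrolled and scanned are hypothetical helper names. It lowers both versions with jax.jit(...).lower(...) and inspects the StableHLO text. The unrolled version contains one dot_general per layer, while the scanned one contains a while loop with a single dot_general in its body. The exact text format may vary across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax

L, F = 6, 4  # example sizes
Ws = jax.random.normal(jax.random.key(0), (L, F, F))  # stacked params

def unrolled(x, Ws):
    for i in range(L):  # emits L copies of the layer
        x = lax.tanh(lax.dot(x, Ws[i]))
    return x

def scanned(x, Ws):
    def body(x, W):  # emitted once, as the loop body
        return lax.tanh(lax.dot(x, W)), None
    return lax.scan(body, x, Ws)[0]

x = jnp.ones(F)
hlo_unrolled = jax.jit(unrolled).lower(x, Ws).as_text()
hlo_scanned = jax.jit(scanned).lower(x, Ws).as_text()

print("stablehlo.while" in hlo_scanned)             # True
print(hlo_unrolled.count("stablehlo.dot_general"))  # 6
print(hlo_scanned.count("stablehlo.dot_general"))   # 1
```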

When you can’t use this

Only layers that are structurally identical, with the same parameter shapes and the same input/output shape contract, can be scanned this way. If layer i has a different out_features than layer i+1 (e.g., a stem→trunk→head architecture), the per-layer params don't stack into a uniform pytree. In that case, split the network into homogeneous segments: scan each uniform run of layers and apply the remaining layers outside the scan.

Common pitfalls

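  • Forgetting split_rngs. All N layers would share an init RNG and end up with identical params. The result is effectively one layer applied N times (weight tying), not N independent layers.
  • Using (0, nnx.Carry) instead of (nnx.Carry, 0). in_axes entries match forward's arguments by position. forward takes (x, block), so the first entry must be nnx.Carry (x flows through the loop) and the second must be 0 (block is sliced per layer).
  • Mismatched in/out axes. out_axes=nnx.Carry (just the carry) vs out_axes=(nnx.Carry, 0) (carry + per-layer stacked output). Pick deliberately, and make forward's return value match: a single array for the former, a (carry, y) tuple for the latter.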

Problem

Implement scan_layers_forward(seed, x, num_layers, features):

  1. Cast N = int(num_layers), F = int(features), s = int(seed).
  2. Build a stacked block via split_rngs(splits=N) + nnx.vmap over a constructor that returns nnx.Linear(F, F, rngs=rngs).
  3. Define @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) def forward(x, block): return jax.nn.relu(block(x)).
  4. Call forward(x, blocks) to get the final activation (F,).
  5. Return out.reshape(-1).

Inputs:

  • seed: float (cast to int).
  • x: 1-D (F,) — the input vector. Shape is preserved through layers.
  • num_layers: float (cast to int) — N.
  • features: float (cast to int) — F (each layer is F → F).

Output: 1-D (F,) — final activation.

