
NNX Composed Transforms

Why this matters

A “real” deep-net forward pass composes at least two lifts:

  • scan over layers: the layer body is traced and compiled once, so compile time stays roughly constant in network depth.
  • vmap over the batch: write per-sample logic, then lift it to a minibatch.

nnx.scan and nnx.vmap compose naturally: each lifts a function that takes Modules into another such function, so applying one and then the other is well-defined. The subtlety is that order matters. The inner lift wraps the function first; the outer lift wraps the already-lifted result.

The pattern

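import jax
from flax import nnx

# Inner: scan over layers (per sample). x is the carry; blocks is sliced along axis 0.
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def scan_layers(x, block):
    return jax.nn.relu(block(x))

# Outer: vmap over the batch (lifts the per-sample function). blocks is shared via None.
@nnx.vmap(in_axes=(0, None))
def batched(x, blocks):
    return scan_layers(x, blocks)

# x_batch has shape (B, F); blocks holds the stacked layers (see the recap further down).
out = batched(x_batch, blocks)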

Reading inside-out:

  1. scan_layers runs the N-layer forward pass for ONE sample. The params (blocks) get sliced along axis 0 internally, one layer per scan step; the carry (x) is rolled through.
  2. batched is a vmap over the batch dim. For each batch element, it calls scan_layers(x_i, blocks) with the same blocks shared across the batch (in_axes=(0, None)).

The result of batched(x_batch, blocks) has shape (B, F): one final activation vector per sample.
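To make the inside-out reading concrete, here is a minimal sketch of what the composed transform is assumed to compute, written as plain Python loops. It uses raw arrays in place of nnx.Linear and assumes each layer is relu(x @ kernel + bias). The names reference_forward, kernels, and biases and the sizes are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

N, B, F = 3, 4, 5  # hypothetical sizes: layers, batch, features

k_w, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
kernels = jax.random.normal(k_w, (N, F, F))  # stacked weights, layer axis first
biases = jax.random.normal(k_b, (N, F))      # stacked biases
x_batch = jax.random.normal(k_x, (B, F))

def reference_forward(x_batch, kernels, biases):
    outs = []
    for x in x_batch:  # the vmap's job: one sample of shape (F,) at a time
        for kernel, bias in zip(kernels, biases):  # the scan's job: one layer slice per step
            x = jax.nn.relu(x @ kernel + bias)     # the carry is rolled through the layers
        outs.append(x)
    return jnp.stack(outs)  # one final activation per sample, shape (B, F)

ref_out = reference_forward(x_batch, kernels, biases)
print(ref_out.shape)
```

The loops are only a readable reference; the lifted version exists so that the layer body is traced once rather than unrolled N times.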

Why “scan inside vmap” not “vmap inside scan”

Both work and compute the same values. Any difference is in how the program is lowered to HLO:

  • scan inside vmap: vmap batches the body of the per-sample loop, so the result is a single loop in which each scan iteration sees a batched matmul. This is what you want for normal training.
  • vmap inside scan: each layer step applies a vmap over the batch. Functionally identical for this problem; it differs only for patterns where you’d want different vmap axes per layer (rare).

Idiomatic JAX puts vmap on the outside.
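The claim that both orders agree can be checked numerically. The sketch below uses jax.lax.scan and jax.vmap as stand-ins for the NNX lifts and a bias-free layer for brevity. The function names and sizes are hypothetical, and it compares outputs only; it says nothing about the generated HLO.

```python
import jax
import jax.numpy as jnp

N, B, F = 3, 4, 5  # hypothetical sizes: layers, batch, features

k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
kernels = jax.random.normal(k_w, (N, F, F)) / jnp.sqrt(F)  # stacked weights
x_batch = jax.random.normal(k_x, (B, F))

def layer(x, kernel):
    # Per-sample layer: x has shape (F,), kernel has shape (F, F).
    return jax.nn.relu(x @ kernel)

def scan_inside_vmap(x_batch, kernels):
    def per_sample(x):
        def step(carry, kernel):
            return layer(carry, kernel), None
        out, _ = jax.lax.scan(step, x, kernels)  # loop over the layer axis
        return out
    return jax.vmap(per_sample)(x_batch)         # lift to the batch

def vmap_inside_scan(x_batch, kernels):
    def step(carry, kernel):
        # The carry is the whole batch; each step vmaps the layer over it.
        return jax.vmap(layer, in_axes=(0, None))(carry, kernel), None
    out, _ = jax.lax.scan(step, x_batch, kernels)
    return out

a = scan_inside_vmap(x_batch, kernels)
b = vmap_inside_scan(x_batch, kernels)
print(a.shape, bool(jnp.allclose(a, b, atol=1e-4)))
```

In the first version the per-sample function never sees the batch axis, which is the property the vmap-outside idiom is after.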

Building the stacked params (recap)

Same trick from earlier problems:

@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_block(rngs):
    return nnx.Linear(F, F, rngs=rngs)

blocks = make_block(nnx.Rngs(seed))

After this, blocks.kernel.value.shape == (N, F, F): the N layers are stacked along a new leading axis.

Common pitfalls

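  • Putting vmap inside scan accidentally, i.e. writing @nnx.scan(...) outside @nnx.vmap(...) instead of the other way around. The result will run, but it departs from the vmap-outside idiom and on some workloads the resulting layout may be less favorable.
  • Mixing up the in_axes. The scan takes (nnx.Carry, 0): x is the carry and blocks is sliced along axis 0. The vmap takes (0, None): x is mapped over the batch axis and blocks is broadcast. Swapping them gives shape errors.
  • Calling batched(x, blocks) with x of shape (F,) instead of (B, F). With in_axes=0, vmap maps over the leading axis, so that axis must be the batch axis; for a single sample, pass x[None, :].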
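The last pitfall can be illustrated in plain JAX. In this minimal sketch, feature_sum is a hypothetical stand-in for per-sample logic. With this stand-in the mistake passes silently; in the NNX setting a scalar would more likely trigger a shape error inside the Linear layer instead.

```python
import jax
import jax.numpy as jnp

F = 4  # hypothetical feature size

def feature_sum(x):
    # Stand-in for per-sample logic; written for x of shape (F,).
    return jnp.sum(x)

batched_sum = jax.vmap(feature_sum, in_axes=0)

x_single = jnp.ones((F,))

# Wrong: vmap treats the F features as F separate "samples" (scalars).
wrong = batched_sum(x_single)

# Right: add a leading batch axis so the input is a batch of one sample.
right = batched_sum(x_single[None, :])

print(wrong.shape, right.shape)
```

A shape check such as x.ndim == 2 before calling the batched function is a cheap way to catch this early.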

Problem

Implement composed_transforms(seed, x_batch, num_layers, features):

  1. Build stacked blocks = make_block(nnx.Rngs(int(seed))) with int(num_layers) blocks of nnx.Linear(F, F).
  2. Define @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) def scan_layers(x, block): return jax.nn.relu(block(x)).
  3. Define @nnx.vmap(in_axes=(0, None)) def batched(x, blocks): return scan_layers(x, blocks).
  4. Call out = batched(x_batch, blocks) to get an array of shape (B, F).
  5. Return out.reshape(-1).

Inputs:

  • seed: float (cast to int).
  • x_batch: 2-D (B, F).
  • num_layers: float (cast to int) — N.
  • features: float (cast to int) — F.

Output: 1-D (B * F,) — flattened batched output.

Hints

flax, nnx, lifted-transforms, scan, vmap, composition
