NNX Composed Transforms
Why this matters
A “real” deep-net forward composes at least two lifts:
- scan over layers — to keep compile time O(1) in network depth.
- vmap over the batch — to write per-sample logic and lift it to a minibatch.
nnx.scan and nnx.vmap compose naturally — each is a pure
transformation over Modules, so applying one then the other is
well-defined. But composition has a subtlety: order matters.
The inner lift wraps the module first; the outer lift wraps the
already-lifted thing.
The pattern
import jax
from flax import nnx

# Inner: scan-over-layers (per sample)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def scan_layers(x, block):
    return jax.nn.relu(block(x))

# Outer: vmap-over-batch (lifts the per-sample function)
@nnx.vmap(in_axes=(0, None))
def batched(x, blocks):
    return scan_layers(x, blocks)

out = batched(x_batch, blocks)
Reading inside-out:
- scan_layers runs the N-layer forward for ONE sample. The params (blocks) get sliced along axis 0 internally; the carry (x) is rolled through.
- batched is a vmap over the batch dim. For each batch element, it calls scan_layers(x_i, blocks) with the same blocks shared across the batch (in_axes=(0, None)).
The result of batched(x_batch, blocks) is (B, F) — one final
activation per sample.
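To make the shapes concrete, here is a minimal sketch of the same pattern in plain JAX, using jax.lax.scan and jax.vmap instead of the NNX lifts. The sizes N, B, F, the random kernels, and the (kernels, biases) tuple are illustrative assumptions standing in for the stacked nnx.Linear blocks; it is an analog of the pattern, not the NNX implementation.

```python
import jax
import jax.numpy as jnp

N, B, F = 3, 4, 5  # layers, batch size, features (illustrative values)

# Stacked parameters: leading axis 0 indexes the layer.
kernels = jax.random.normal(jax.random.PRNGKey(0), (N, F, F)) * 0.1
biases = jnp.zeros((N, F))

def scan_layers(x, params):
    """Run all N layers on ONE sample x of shape (F,)."""
    def step(carry, layer):
        kernel, bias = layer  # one layer's slice: (F, F) and (F,)
        return jax.nn.relu(carry @ kernel + bias), None
    final, _ = jax.lax.scan(step, x, params)
    return final

# Map over the batch axis of x; share the same stacked params across samples.
batched = jax.vmap(scan_layers, in_axes=(0, None))

x_batch = jnp.ones((B, F))
out = batched(x_batch, (kernels, biases))
print(out.shape)  # (4, 5)
```

The in_axes=(0, None) here plays the same role as in the NNX version: axis 0 of x_batch is mapped, and the parameters are broadcast unchanged to every sample.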
Why “scan inside vmap” not “vmap inside scan”
Both work. The difference is HLO shape:
- scan inside vmap: the per-sample loops are fused into a single loop whose body is batched, so each scan iteration sees a batched matmul. This is what you want for normal training.
- vmap inside scan: each layer does a vmap over the batch. Functionally identical for this problem; differs subtly for patterns where you’d want different vmap axes per layer (rare).
Idiomatic JAX puts vmap on the outside.
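The claim that both orders compute the same values for this problem can be checked with a small plain-JAX sketch. The layer function, the sizes, and the random inputs are assumptions made for illustration; the sketch compares numerical outputs only and says nothing about the compiled HLO.

```python
import jax
import jax.numpy as jnp

N, B, F = 3, 4, 5  # illustrative sizes
kernels = jax.random.normal(jax.random.PRNGKey(0), (N, F, F)) * 0.1
x_batch = jax.random.normal(jax.random.PRNGKey(1), (B, F))

def layer(x, kernel):
    # Per-sample layer: x is (F,), kernel is (F, F).
    return jax.nn.relu(x @ kernel)

# Order 1: scan inside vmap (per-sample loop, lifted to the batch).
def per_sample(x, kernels):
    final, _ = jax.lax.scan(lambda c, k: (layer(c, k), None), x, kernels)
    return final

scan_in_vmap = jax.vmap(per_sample, in_axes=(0, None))(x_batch, kernels)

# Order 2: vmap inside scan (each scan step applies a batched layer).
batched_layer = jax.vmap(layer, in_axes=(0, None))
vmap_in_scan, _ = jax.lax.scan(
    lambda c, k: (batched_layer(c, k), None), x_batch, kernels
)

same = bool(jnp.allclose(scan_in_vmap, vmap_in_scan, atol=1e-5))
print(same)
```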
Building the stacked params (recap)
Same trick from earlier problems:
@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_block(rngs):
    return nnx.Linear(F, F, rngs=rngs)

blocks = make_block(nnx.Rngs(seed))
blocks.kernel.value.shape == (N, F, F) after this.
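The idea behind the stacked shape can be sketched in plain JAX: split one key into N keys, then vmap an initializer over them so each layer gets its own parameters along a new leading axis. The init_kernel helper and its scaling are hypothetical and are not what nnx.Linear uses internally.

```python
import jax
import jax.numpy as jnp

N, F = 3, 5  # illustrative sizes

def init_kernel(key):
    # Hypothetical initializer for one (F, F) kernel.
    return jax.random.normal(key, (F, F)) / jnp.sqrt(F)

keys = jax.random.split(jax.random.PRNGKey(0), N)  # one key per layer
stacked = jax.vmap(init_kernel)(keys)              # map over the key axis
print(stacked.shape)  # (3, 5, 5)
```

Because each layer receives a different key, the N kernels are independent draws rather than N copies of the same matrix.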
Common pitfalls
- Putting vmap inside scan accidentally. @nnx.scan(...) outside @nnx.vmap(...) instead of the other way around. The result will run, but on most workloads the layout is suboptimal.
- Letting in_axes get mangled. (nnx.Carry, 0) for the scan, (0, None) for the vmap. Mixing them up gives shape errors.
- Calling batched(x, blocks) with x of shape (F,) instead of (B, F). vmap requires the leading axis to exist.
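For the last pitfall, one way to handle a single sample is to add a batch axis of size 1 before calling the batched function and to index it away afterwards. The sketch below uses plain jax.vmap with an identity kernel as a stand-in for the blocks; the names are illustrative.

```python
import jax
import jax.numpy as jnp

F = 5
kernel = jnp.eye(F)  # identity kernel, so the output is easy to check

def per_sample(x, kernel):
    # Expects x of shape (F,).
    return jax.nn.relu(x @ kernel)

batched = jax.vmap(per_sample, in_axes=(0, None))

x_single = jnp.arange(F, dtype=jnp.float32)  # shape (F,), no batch axis
out = batched(x_single[None, :], kernel)     # add a batch axis: (1, F)
single_out = out[0]                          # back to shape (F,)
print(out.shape, single_out.shape)
```

Passing x_single directly would make vmap map over the feature axis instead, which is not what the per-sample function expects.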
Problem
Implement composed_transforms(seed, x_batch, num_layers, features):
- Build stacked blocks = make_block(nnx.Rngs(int(seed))) with int(num_layers) blocks of nnx.Linear(F, F).
- Define scan_layers(x, block), decorated with @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry), which returns jax.nn.relu(block(x)).
- Define batched(x, blocks), decorated with @nnx.vmap(in_axes=(0, None)), which returns scan_layers(x, blocks).
- Call batched(x_batch, blocks) to get (B, F).
- Return out.reshape(-1).
Inputs:
- seed: float (cast to int).
- x_batch: 2-D (B, F).
- num_layers: float (cast to int); this is N.
- features: float (cast to int); this is F.
Output: 1-D (B * F,) — flattened batched output.
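The input casts and the final flatten can be illustrated on their own. The sketch below uses a stand-in array in place of the real batched output; it shows that reshape(-1) is row-major, so sample i occupies entries i * F through (i + 1) * F - 1 of the flat result. The values are made up for illustration.

```python
import jax.numpy as jnp

# Inputs arrive as floats and are cast before use.
num_layers, features = 3.0, 4.0
N, F = int(num_layers), int(features)

B = 2
out = jnp.arange(B * F, dtype=jnp.float32).reshape(B, F)  # stand-in for (B, F) output
flat = out.reshape(-1)                                     # shape (B * F,)

# Row-major layout: element (i, j) lands at index i * F + j.
i, j = 1, 2
print(flat.shape, float(flat[i * F + j]) == float(out[i, j]))
```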