Composed lifts: nn.scan + nn.vmap
Why this matters
Real models combine both axes of structure: a stack of N layers (a depth axis) AND a batch of B examples (a parallel axis). With Flax lifted transforms, you get to express both in a single Module that captures the param layout you want for each:
- Per-layer params: one set of weights for layer 0, another for layer 1, …, layer N-1. That’s nn.scan over layers.
- Shared params across the batch (the standard case: same weights for every example). That’s nn.vmap with variable_axes={"params": None}.
Compose them and you have the spine of a real transformer-style forward pass, expressed declaratively, with all the parameter bookkeeping handled by Flax.
How lift composition works
Lifts compose like decorators: the inner lift wraps the Module first, the outer lift wraps the result.
ScanLayer = nn.scan(Block, variable_axes={"params": 0},
                    split_rngs={"params": True}, length=N)
BatchedScanLayer = nn.vmap(
    ScanLayer,
    in_axes=0, out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
model = BatchedScanLayer(features=F)
Read inside-out:
- Start with Block (a single-layer, per-example Module).
- nn.scan lifts it to a stack-of-N-layers, per-example Module whose params have a leading N axis.
- nn.vmap lifts THAT to a stack-of-N-layers batched Module: the same N-layer params are shared across the B examples in the batch.
The order matters: the params carry the scan (layer) axis but no vmap (batch) axis. Reverse the composition and the batch axis moves inside the scan body; if the vmap also maps params (variable_axes={"params": 0}), you get per-example params that are then stacked per layer, which is structurally different (and almost never what you want for transformer-style architectures).
What about variable_axes for each lift?
The two lifts configure DIFFERENT axes of the same params collection. Each lift’s variable_axes declaration applies to the axis that ITS lift adds:
- nn.scan adds the leading layer axis: variable_axes={"params": 0}.
- nn.vmap declares whether IT (vmap) adds a batch axis on top: variable_axes={"params": None} says “no, params are shared across the batch”.
So the final params tree has shape (N, ...) per kernel: N layers, no batch axis.
Likewise split_rngs:
- nn.scan with split_rngs={"params": True}: N different init keys, one per layer.
- nn.vmap with split_rngs={"params": False}: the params init RNG is shared across the batch (one tree, replicated logically).
Carry / output semantics
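The inner Block returns (new_x, None) because we use scan in carry-only mode: x is the carry threaded from layer to layer, and the None in the second slot means no per-step output is collected. After both lifts, calling model.apply(params, x_batch, None) returns (final, ys), where final has the BATCHED carry shape (B, F) and ys is None. The trailing None argument is the (empty) scanned input.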
Common pitfalls
- Lift order reversed: nn.scan(nn.vmap(Block, ...), ...) instead of nn.vmap(nn.scan(Block, ...), ...). This is structurally different: the batch axis ends up INSIDE the scan, which usually isn’t what you want.
- variable_axes={"params": 0} on the vmap: gives every batch example its own copy of the N-layer stack (an ensemble). The output shape still works, but you’ve now got B * N independent layers.
- split_rngs={"params": True} on the vmap: pairs with variable_axes={"params": 0} to make a real ensemble; with variable_axes={"params": None} the two settings conflict (the init keys are split per example, yet the params are declared shared across the batch). Stick to False for shared-param vmap.
- x.shape[-1] != features: the carry shape mismatches between scan iterations. Make sure the residual width matches features.
Problem
Implement composed_lifts_forward(seed, x_batch, num_layers, features):
- Define Block(nn.Module) with field features. Its @nn.compact __call__(self, x, _) applies Dense(features) → relu and returns (new_x, None).
- Build ScanLayer = nn.scan(Block, variable_axes={"params": 0}, split_rngs={"params": True}, length=num_layers).
- Build BatchedScanLayer = nn.vmap(ScanLayer, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False}).
- Instantiate model = BatchedScanLayer(features=features).
- Init with a PRNG key derived from seed plus (x_batch, None), then apply with (params, x_batch, None).
- Return the batched final activation, flattened.
x_batch.shape[-1] == features is required.
Inputs:
- seed: int.
- x_batch: 2-D (B, features).
- num_layers: int N.
- features: int F.
Output: 1-D — flattened (B * F,) final activation.