Composed lifts: nn.scan + nn.vmap

Why this matters

Real models combine both axes of structure: a stack of N layers (a depth axis) AND a batch of B examples (a parallel axis). With Flax lifted transforms you can express both in a single Module and choose the param layout along each axis:

  • Per-layer params: one set of weights for layer 0, another for layer 1, …, and another for layer N-1. That’s nn.scan over layers.
  • Shared params across the batch (the standard case — same weights for every example). That’s nn.vmap with variable_axes={"params": None}.

Compose them and you have the spine of a real transformer-style forward pass, expressed declaratively, with all the parameter bookkeeping handled by Flax.

How lift composition works

Lifts compose like decorators: the inner lift wraps the Module first, the outer lift wraps the result.

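import flax.linen as nn
N, F = 3, 4  # illustrative depth (number of layers) and width (features)

class Block(nn.Module):  # one layer as a scan body: carry x -> relu(Dense(x))
    features: int
    @nn.compact
    def __call__(self, x, _):
        return nn.relu(nn.Dense(self.features)(x)), None

ScanLayer = nn.scan(Block, variable_axes={"params": 0},
                    split_rngs={"params": True}, length=N)
BatchedScanLayer = nn.vmap(
    ScanLayer,
    in_axes=0, out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
model = BatchedScanLayer(features=F)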

Read inside-out:

  1. Start with Block (a single-layer per-example Module).
  2. nn.scan lifts it to a stack-of-N-layers per-example Module whose params have a leading N axis.
  3. nn.vmap lifts THAT to a stack-of-N-layers batched Module — same N-layer params shared across B examples in the batch.

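The order matters because it fixes where each axis sits relative to the loop. With nn.vmap outermost, every example runs its own N-step scan, and the vmap’s variable_axes={"params": None} keeps the N-layer params shared across the batch. Reversing the composition to nn.scan(nn.vmap(Block, ...), ...) moves the batch axis inside the loop instead: the scan carry becomes the whole (B, F) batch, and each step applies one layer to every example. The order alone does not make params per-example; the vmap’s variable_axes does. With variable_axes={"params": None} the reversed stack still has (N, ...) params. With variable_axes={"params": 0} on the inner vmap, you get per-example params that are then per-layer-stacked.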

What about variable_axes for each lift?

The two lifts configure DIFFERENT axes of the same "params" collection. Each lift’s variable_axes declaration applies only to the axis that ITS lift adds:

  • nn.scan adds the leading layer axis: variable_axes={"params": 0}.
  • nn.vmap declares whether IT (vmap) adds a batch axis on top: variable_axes={"params": None} says “no, params are shared across the batch”. So every leaf of the final params tree has a leading N axis and no batch axis (for example, the Dense kernel is (N, F, F)).

Likewise split_rngs:

  • nn.scan with split_rngs={"params": True} — N different init keys.
  • nn.vmap with split_rngs={"params": False} — params init RNG is shared across the batch (one tree, replicated logically).

Carry / output semantics

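The inner Block returns (new_x, None) because the scan only threads a carry and emits no per-layer outputs. The second call argument, passed as None, is the scanned input xs. Since there is nothing to scan over, length=N is what sets the number of iterations. After both lifts, calling model.apply(params, x_batch, None) returns (final, ys). Here params is the full variables dict returned by model.init, final has the BATCHED carry shape (B, F), and ys is None.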

Common pitfalls

  • Lift order reversed: nn.scan(nn.vmap(Block, ...), ...) instead of nn.vmap(nn.scan(Block, ...), ...). This puts the batch axis INSIDE the scan, so each step processes the whole (B, F) batch. That is a different composition from the one this problem specifies.
  • variable_axes={"params": 0} on the vmap: gives every batch example its own copy of the N-layer stack (an ensemble). Output shape works, but you’ve now got B * N independent layers.
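  • split_rngs={"params": True} on the vmap: this pairs with variable_axes={"params": 0} to make a real ensemble. With variable_axes={"params": None} the two settings conflict. Each example gets its own init key and would produce its own params, but the params axis is declared as broadcast, so expect an error at init rather than shared weights. Stick to False for a shared-param vmap.
  • x.shape[-1] != features: Dense(features) maps a width-D input to width features, so the carry would change shape after the first layer. The scan requires the carry shape to stay fixed across iterations, so expect a shape error at init. Make sure the input width matches features.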

Problem

Implement composed_lifts_forward(seed, x_batch, num_layers, features):

  1. Define Block(nn.Module) with field features. @nn.compact __call__(self, x, _): Dense(features) → relu. Return (new_x, None).
  2. Build ScanLayer = nn.scan(Block, variable_axes={"params": 0}, split_rngs={"params": True}, length=num_layers).
  3. Build BatchedScanLayer = nn.vmap(ScanLayer, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False}).
  4. Instantiate model = BatchedScanLayer(features=features).
  5. Init and apply with (params, x_batch, None).
  6. Return the batched final activation flattened.

x_batch.shape[-1] == features is required.

Inputs:

  • seed: int.
  • x_batch: 2-D (B, features).
  • num_layers: int N.
  • features: int F.

Output: 1-D — flattened (B * F,) final activation.

