hard primitives

nn.scan over layers

Why this matters

Stacking N identical-shape transformer blocks via a Python loop:

for _ in range(N):
    x = TransformerBlock()(x)

is correct but slow to compile. Each iteration adds the full block to the JAX trace; for N=96 (GPT-3 scale), trace time and binary size become real problems. Worse, you can’t easily pickle/checkpoint the per-layer params as a single stacked tensor: they live in 96 separate submodule slots (TransformerBlock_0, TransformerBlock_1, …, TransformerBlock_95).

nn.scan over layers solves both:

  • The trace records the body ONCE and lowers it to a single loop (unrolled only if requested), so trace time and program size stay roughly constant in N instead of growing linearly.
  • Params for all N layers live in a single stacked tensor with a leading axis of size N — a single kernel of shape (N, in, out) instead of N separate kernels.

JAX-based libraries such as MaxText and T5X use the same idea for very deep transformers.
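To make the stacked-parameter idea concrete, here is a minimal pure-JAX sketch of what scanning over layers amounts to. It is an illustration under assumptions, not Flax's actual implementation: the sizes N and F, the helper name layer, and the scaled-normal init are all made up for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

N, F = 6, 4  # illustrative sizes: number of layers, feature width

# One stacked kernel with a leading layer axis, instead of N separate kernels.
keys = jax.random.split(jax.random.PRNGKey(0), N)
stacked_kernel = jax.vmap(
    lambda k: jax.random.normal(k, (F, F)) / jnp.sqrt(F)
)(keys)

def layer(x, kernel):
    # Hypothetical layer body: scan traces it once and feeds it
    # stacked_kernel[i] at step i.
    return jax.nn.relu(x @ kernel), None

x0 = jnp.ones((F,))
scanned, _ = lax.scan(layer, x0, stacked_kernel)

# Reference: the equivalent Python loop, which calls the body N times.
looped = x0
for i in range(N):
    looped, _ = layer(looped, stacked_kernel[i])
```

Both paths should produce the same activation; the difference is that the loop version puts N copies of the body into the trace.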

The two key arguments

Two arguments differ from the scan-over-time case (pos 79):

  • variable_axes={"params": 0} — params have a leading axis of size N. Each layer gets its OWN params, stored stacked. The first axis indexes the layer. NOT broadcast (which would share one set across all layers).
  • split_rngs={"params": True} — when model.init runs, it generates N independent RNG keys (one per layer) so each layer’s params get their own random init. Without this, all N layers would init from the same key and start identical.

You also need length=N to tell scan how many iterations to run.
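The effect of splitting the RNG can be mimicked in plain JAX. The sketch below is only an analogy for split_rngs (it does not call Flax), and init_kernel, N and F are hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp

N, F = 3, 4  # illustrative sizes: number of layers, feature width

def init_kernel(key):
    # Hypothetical per-layer initializer.
    return jax.random.normal(key, (F, F))

root = jax.random.PRNGKey(0)

# Analogue of split_rngs={"params": True}: one independent key per layer.
split_keys = jax.random.split(root, N)
distinct = jax.vmap(init_kernel)(split_keys)    # shape (N, F, F)

# Analogue of split_rngs={"params": False}: the same key reused for every layer.
shared_keys = jnp.stack([root] * N)
identical = jax.vmap(init_kernel)(shared_keys)  # shape (N, F, F)
```

With split keys every slice along the leading axis is a different random matrix; with a shared key all N slices are the same matrix.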

Carry semantics

nn.scan is lax.scan under the hood, so the body must return a (carry, output) pair and keep the carry’s shape and dtype fixed across iterations. We pass the activation x AS the carry (it’s what flows from layer to layer). Each block returns (new_x, None); the second element is None because we don’t need a per-layer output.

For shape consistency, the block must preserve x’s shape: a Dense(features) only does so if x.shape[-1] == features. So the input must already live in the residual width, which is exactly the “residual stream” pattern transformers use.
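The fixed-carry requirement can be seen directly with lax.scan. This is a small sketch with two hypothetical bodies (shape_preserving and shape_changing); it assumes the mismatch surfaces as a TypeError, which is how lax.scan reports carry type errors.

```python
import jax.numpy as jnp
from jax import lax

F = 4  # illustrative residual width

def shape_preserving(x, _):
    # (F,) -> (F,): the carry keeps its shape, so scan accepts it.
    return jnp.tanh(x), None

def shape_changing(x, _):
    # (F,) -> (2F,): the carry changes shape, so scan rejects it.
    return jnp.concatenate([x, x]), None

x = jnp.ones((F,))
out, ys = lax.scan(shape_preserving, x, None, length=3)

try:
    lax.scan(shape_changing, x, None, length=3)
    carry_error = None
except TypeError as err:
    carry_error = str(err)
```

Because the body returns None as its per-step output, ys is None as well; only the final carry is of interest.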

Putting it together

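import jax
import jax.numpy as jnp
import flax.linen as nn

N, F = 4, 8  # example sizes: number of layers, residual width

class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, _):
        # x is the carry; the second argument is the unused per-step input.
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        return x, None  # (new carry, no per-layer output)

ScanBlock = nn.scan(
    Block,
    variable_axes={"params": 0},  # stack params along a leading layer axis
    split_rngs={"params": True},  # one init key per layer
    length=N,
)
model = ScanBlock(features=F)
rng = jax.random.PRNGKey(0)
x = jnp.ones((F,))  # input already at residual width F
params = model.init(rng, x, None)
final, _ = model.apply(params, x, None)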

Inspect params["params"]["Dense_0"]["kernel"] and you’ll see shape (N, F, F) — N layers, each with an (F, F) kernel.

Common pitfalls

  • x.shape[-1] != features: the carry shape changes between iterations and lax.scan raises an error of the form “carry input and carry output must have equal types”. Project x to features before scanning, or design x to already have shape (F,).
  • split_rngs={"params": False}: all N layers init from the same RNG key — they’re literally identical. Useless.
  • Forgetting length=N: scan can’t infer it (no per-step input axis carries it). Pass length explicitly.
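The length pitfall is easiest to reproduce at the lax.scan level. The sketch below uses a hypothetical body function and assumes the missing length is reported as a ValueError by lax.scan; how the error is surfaced through nn.scan may differ.

```python
import jax.numpy as jnp
from jax import lax

def body(x, _):
    # Hypothetical shape-preserving step.
    return x * 0.5, None

x = jnp.ones((4,))

# No per-step inputs (xs=None), so the iteration count must come from `length`.
out, _ = lax.scan(body, x, None, length=3)

try:
    lax.scan(body, x, None)  # neither xs nor length: nothing to infer from
    length_error = None
except ValueError as err:
    length_error = str(err)
```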

Problem

Implement scan_layers_forward(seed, x, num_layers, features):

  1. Define Block(nn.Module) with field features. In @nn.compact __call__(self, x, _): Dense(features) → relu. Return (new_x, None).
  2. Build ScanBlock = nn.scan(Block, variable_axes={"params": 0}, split_rngs={"params": True}, length=num_layers).
  3. Init and apply.
  4. Return the final activation flattened.

x.shape[-1] == features is required (residual width).

Inputs:

  • seed: int.
  • x: 1-D (features,).
  • num_layers: int.
  • features: int.

Output: 1-D (features,) — output after N layers.

Hints

flax nn-scan deep-nets lifted-transforms
