nn.scan over layers
Why this matters
Stacking N identical-shape transformer blocks via a Python loop:
    for _ in range(N):
        x = TransformerBlock()(x)
is correct but slow to compile. Each iteration adds the full block
to the JAX trace; for N=96 (GPT-3 scale), trace time and binary
size become real problems. Worse, you can't easily checkpoint
the per-layer params as a single stacked tensor: they live in 96
separate submodule slots (TransformerBlock_0, TransformerBlock_1, …, TransformerBlock_95).
nn.scan over layers solves both:
- The trace records the body ONCE; XLA unrolls (or rolls) the loop at compile time as it sees fit. Trace size, and largely compile time, no longer grow with N.
- Params for all N layers live in a single stacked tensor with a leading axis of size N: a single kernel of shape (N, in, out) instead of N separate kernels.
The same idea is what JAX-based codebases (e.g., MaxText, T5X) use for very deep transformers.
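To make the trace-size point concrete, here is a minimal pure-JAX sketch (no Flax) that compares a Python loop with lax.scan. The names layer, loop_forward, scan_forward, and num_top_level_eqns are illustrative, and a single matmul plus tanh stands in for a real transformer block.

```python
import jax
import jax.numpy as jnp
from jax import lax


def layer(x, w):
    # Stand-in for a transformer block: one matmul plus a nonlinearity.
    return jnp.tanh(x @ w)


def loop_forward(x, ws):
    # Python loop: every iteration is traced separately.
    for i in range(ws.shape[0]):
        x = layer(x, ws[i])
    return x


def scan_forward(x, ws):
    # lax.scan: the body is traced once; ws is sliced along its leading axis.
    def body(carry, w):
        return layer(carry, w), None

    out, _ = lax.scan(body, x, ws)
    return out


def num_top_level_eqns(fn, num_layers, width=4):
    # Count the equations in the outermost jaxpr of fn.
    x = jnp.ones((width,))
    ws = jnp.ones((num_layers, width, width))
    return len(jax.make_jaxpr(fn)(x, ws).jaxpr.eqns)


for n in (2, 8, 32):
    print(n, num_top_level_eqns(loop_forward, n), num_top_level_eqns(scan_forward, n))
```

The loop's equation count should grow with the number of layers, while the scan version stays flat because the body sits inside a single scan equation. Both functions compute the same result for the same stacked weights.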
The two key arguments
Different from scan-over-time (pos 79):
- variable_axes={"params": 0}: params have a leading axis of size N. Each layer gets its OWN params, stored stacked. The first axis indexes the layer. NOT broadcast (which would share one set across all layers).
- split_rngs={"params": True}: when model.init runs, it generates N independent RNG keys (one per layer) so each layer's params get their own random init. Without this, all N layers would init from the same key and start identical.
Plus you need length=N to tell scan how many iterations.
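The effect of these two settings can be imitated in plain JAX. The sketch below is an analogue, not Flax's internals: stacked_init and init_kernel are hypothetical helpers that build an (N, F, F) kernel stack either from N split keys (like split_rngs={"params": True}) or from one repeated key (like False).

```python
import jax
import jax.numpy as jnp


def init_kernel(key, features):
    # Hypothetical per-layer initializer: small random (F, F) kernel.
    return 0.1 * jax.random.normal(key, (features, features))


def stacked_init(key, num_layers, features, split=True):
    if split:
        # One independent key per layer.
        keys = jax.random.split(key, num_layers)
    else:
        # The same key repeated for every layer.
        keys = jnp.stack([key] * num_layers)
    # vmap over the key axis yields a stack with leading axis num_layers.
    return jax.vmap(lambda k: init_kernel(k, features))(keys)


key = jax.random.PRNGKey(0)
distinct = stacked_init(key, num_layers=3, features=4, split=True)
shared = stacked_init(key, num_layers=3, features=4, split=False)
print(distinct.shape, shared.shape)
```

With split=True the slices along the first axis differ; with split=False every slice is the same matrix, which is the "all layers start identical" failure described above.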
Carry semantics
nn.scan is lax.scan under the hood, so it expects a (carry, output) return and a fixed carry shape. We pass the activation
x AS the carry (it’s what flows from layer to layer). Each block
returns (new_x, None) — None because we don’t need a per-layer
output.
For shape consistency, the block must preserve x's shape: a
Dense(features) only does so if x.shape[-1] == features. So
the input must already live in the residual width — exactly the
“residual stream” pattern transformers use.
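The (carry, output) contract is easiest to see on lax.scan directly. In this sketch, run_layers is an illustrative name and a bias-free matmul plus relu stands in for the block; the optional per-layer norm shows what the second return slot is for when it is not None.

```python
import jax.numpy as jnp
from jax import lax


def run_layers(x, kernels, collect_norms=False):
    # kernels has shape (N, F, F); scan slices it along the leading axis.
    def body(carry, kernel):
        new_carry = jnp.maximum(carry @ kernel, 0.0)  # matmul -> relu, shape stays (F,)
        # Second slot: a per-layer output that scan stacks, or None to skip it.
        per_layer = jnp.linalg.norm(new_carry) if collect_norms else None
        return new_carry, per_layer

    return lax.scan(body, x, kernels)


x = jnp.ones((4,))
kernels = 0.5 * jnp.ones((3, 4, 4))
final, ys = run_layers(x, kernels)                        # ys is None
final2, norms = run_layers(x, kernels, collect_norms=True)  # norms has one entry per layer
print(final.shape, ys, norms.shape)
```

Returning None keeps only the final carry, which is all a stack of layers needs; returning a value per step would give a stacked array with a leading axis of size N.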
Putting it together
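    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    N, F = 4, 8  # example sizes: number of layers, residual width

    class Block(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x, _):
            # x is the carry; the second argument is the (unused) per-step input.
            x = nn.Dense(self.features)(x)
            x = nn.relu(x)
            return x, None

    ScanBlock = nn.scan(
        Block,
        variable_axes={"params": 0},
        split_rngs={"params": True},
        length=N,
    )
    model = ScanBlock(features=F)
    rng = jax.random.PRNGKey(0)
    x = jnp.ones((F,))
    params = model.init(rng, x, None)
    final, _ = model.apply(params, x, None)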
Inspect params["params"]["Dense_0"]["kernel"] and you’ll see
shape (N, F, F) — N layers, each with an (F, F) kernel.
Common pitfalls
- x.shape[-1] != features: the carry shape mismatches between iterations and lax.scan raises "carry input and carry output must have equal types". Project to features before scanning, or just design x to already be (F,).
- split_rngs={"params": False}: all N layers init from the same RNG key, so they're literally identical. Useless.
- Forgetting length=N: scan can't infer it (no per-step input axis carries it). Pass length explicitly.
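The carry-shape and missing-length pitfalls can be reproduced on lax.scan itself. This is a sketch with illustrative helper names; the exception types are what lax.scan raises directly, and the message you see through nn.scan may be worded or wrapped differently.

```python
import jax.numpy as jnp
from jax import lax


def relu_layer(carry, kernel):
    return jnp.maximum(carry @ kernel, 0.0), None


def carry_mismatch_raises():
    x = jnp.ones((3,))             # width 3 going in...
    kernels = jnp.ones((2, 3, 5))  # ...width 5 coming out of each layer
    try:
        lax.scan(relu_layer, x, kernels)
    except TypeError:
        return True
    return False


def double(carry, _):
    return carry * 2.0, None


def missing_length_raises():
    try:
        lax.scan(double, jnp.ones((3,)), None)  # no per-step input and no length
    except ValueError:
        return True
    return False


def with_length(num_layers):
    # Passing length explicitly tells scan how many iterations to run.
    final, _ = lax.scan(double, jnp.ones((3,)), None, length=num_layers)
    return final


print(carry_mismatch_raises(), missing_length_raises(), with_length(3))
```

The fix for the first case is the one named above: project x to the residual width before the scan so every iteration maps (F,) to (F,).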
Problem
Implement scan_layers_forward(seed, x, num_layers, features):
- Define Block(nn.Module) with field features. In @nn.compact __call__(self, x, _): Dense(features) → relu. Return (new_x, None).
- Build ScanBlock = nn.scan(Block, variable_axes={"params": 0}, split_rngs={"params": True}, length=num_layers).
- Init and apply.
- Return the final activation flattened.
x.shape[-1] == features is required (residual width).
Inputs:
- seed: int.
- x: 1-D (features,).
- num_layers: int.
- features: int.
Output: 1-D (features,) — output after N layers.