Scan over Layer Stack
Why this matters
Modern architectures (deep MLPs, ResNets, transformer stacks) often repeat the same block N times. Writing a Python for-loop over the layers unrolls the computation at trace time. The traced program grows in proportion to N, and so do compile time and compile-time memory. Scanning over a stacked parameter tensor instead produces an XLA program whose size does not depend on depth, because the loop body is traced and compiled only once.
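One way to see the difference is to count the top-level operations in each function's jaxpr, the intermediate program JAX builds at trace time. This is a minimal sketch that assumes a toy linear stack. The exact counts depend on the JAX version, so only the trend matters: the unrolled version grows with N, and the scanned one stays flat. The helper names unrolled, scanned and eqn_count are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(weights, x):
    # Python loop: traced once per layer, so the jaxpr grows with N
    for i in range(weights.shape[0]):
        x = weights[i] @ x
    return x


def scanned(weights, x):
    # lax.scan: the body is traced once, independent of N
    def body(y, w):
        return w @ y, None
    y, _ = lax.scan(body, x, weights)
    return y


def eqn_count(fn, n_layers, d=4):
    # Number of top-level operations in the traced program
    weights = jnp.ones((n_layers, d, d)) / d
    x = jnp.ones(d)
    return len(jax.make_jaxpr(fn)(weights, x).jaxpr.eqns)


for n in (4, 32):
    print(n, eqn_count(unrolled, n), eqn_count(scanned, n))
```

Both functions compute the same result. Only the size of the traced program differs.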
This pattern is fundamental to JAX-native model definitions: stack all layer parameters along axis 0, then lax.scan over them. Flax (through its nn.scan transform, which is built on lax.scan) and Equinox both support this idiom.
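Real layers usually hold more than one array. This is a minimal sketch that assumes hypothetical per-layer parameters stored as dicts with a weight "w" and a bias "b". jax.tree_util.tree_map stacks the N per-layer pytrees leaf by leaf, and lax.scan then passes the body one layer's slice of every leaf. The init_layer helper and the tanh activation are illustrative choices and are not part of the original text.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_layer(key, d):
    # Hypothetical per-layer parameters: a d x d weight and a length-d bias
    wkey, bkey = jax.random.split(key)
    return {
        "w": jax.random.normal(wkey, (d, d)) / jnp.sqrt(d),
        "b": 0.01 * jax.random.normal(bkey, (d,)),
    }


N, d = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), N)
layers = [init_layer(k, d) for k in keys]  # Python list of N dicts

# Stack leaf by leaf: each leaf gains a leading layer axis of size N
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *layers)

# Alternative: vmap the initializer over the keys to build the stack directly
stacked_v = jax.vmap(lambda k: init_layer(k, d))(keys)


def body(y, params):
    # params holds one layer's slice: {"w": (d, d), "b": (d,)}
    return jnp.tanh(params["w"] @ y + params["b"]), None


y, _ = lax.scan(body, jnp.ones(d), stacked)
print(stacked["w"].shape, stacked["b"].shape, y.shape)  # (3, 4, 4) (3, 4) (4,)
```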
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax
# Two 2×2 weight matrices stacked → shape (2, 2, 2)
weights = jnp.array([[[1., 0.], [0., 1.]],   # identity
                     [[0., 1.], [1., 0.]]])  # swap

def layer(y, w):
    return w @ y, None  # carry = new activation; per-step output = None (unused)

final, _ = lax.scan(layer, jnp.array([3., 4.]), weights)
# Step 0: identity @ [3, 4] → [3, 4]
# Step 1: swap @ [3, 4] → [4, 3]
# final → [4., 3.]
The xs argument to lax.scan is the stacked-weights tensor; scan
slices it along axis 0 and passes each slice as w to the body.
The carry is the running activation y.
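To watch the carry evolve, the body can also return the new activation as its per-step output. scan then stacks those outputs along a new axis 0, with one row per layer. This sketch reuses the two weight matrices from the mini-example, and layer_with_trace is an illustrative name. The Python loop at the end is the unrolled equivalent and is included only as a cross-check.

```python
import jax.numpy as jnp
from jax import lax

weights = jnp.array([[[1., 0.], [0., 1.]],   # identity
                     [[0., 1.], [1., 0.]]])  # swap


def layer_with_trace(y, w):
    new_y = w @ y
    return new_y, new_y  # (carry for the next step, output recorded for this step)


final, history = lax.scan(layer_with_trace, jnp.array([3., 4.]), weights)
# history has shape (2, 2): history[0] == [3., 4.] (after identity),
# history[1] == [4., 3.] (after swap); final == history[-1]

# Unrolled reference: the same computation as an explicit Python loop
y = jnp.array([3., 4.])
for w in weights:
    y = w @ y
```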
Common pitfalls
- Returning a single value instead of a pair from body. Even when you don't need per-step outputs, you must return (new_carry, None), not just new_carry.
- Wrong argument order in body. lax.scan calls body(carry, x_i). Here the carry is y and x_i is w, so don't swap them.
- Confusing the xs axis. lax.scan always slices xs along axis 0, so your weight stack must have layers along axis 0 (weights[i] = layer i).
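The first and third pitfalls are easy to reproduce. This is a minimal sketch that assumes weights were accidentally stored with the layer index last, giving shape (d, d, N). jnp.moveaxis moves that axis to the front before scanning. A body that returns only the carry is rejected at trace time. Recent JAX versions raise a TypeError for this, but the exact exception may vary, so the sketch catches it broadly.

```python
import jax.numpy as jnp
from jax import lax

d, N = 3, 5
# Layers stored along the LAST axis: shape (d, d, N). Scanning this directly
# would slice along axis 0 into (d, N) pieces, which is wrong.
weights_last = jnp.stack([jnp.eye(d) * (i + 1) for i in range(N)], axis=-1)

# Fix for the third pitfall: move the layer axis to axis 0, giving shape (N, d, d)
weights = jnp.moveaxis(weights_last, -1, 0)


def good_body(y, w):
    return w @ y, None  # always a (carry, output) pair


def bad_body(y, w):
    return w @ y  # first pitfall: carry only, no pair


x = jnp.ones(d)
y, _ = lax.scan(good_body, x, weights)
# Layer i scales by (i + 1), so y == 1*2*3*4*5 * x == 120 * x

try:
    lax.scan(bad_body, x, weights)
    raised = False
except Exception as err:  # recent JAX versions raise TypeError here
    raised = True
    print(type(err).__name__, err)
```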
Problem
Implement scan_apply_layers(weights, x) that applies N sequential
dense layers (no activation) using lax.scan. The body is
new_y = w @ y.
- weights: 3-D JAX array of shape (N, d, d).
- x: 1-D JAX array of shape (d,).
- Returns: 1-D array of shape (d,), the output after all N layers.
Do not use Python loops.
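Because the layers are linear, the result after all N layers is W[N-1] @ ... @ W[1] @ W[0] @ x. The following is a minimal sketch of a checker you could run against your own scan_apply_layers using random test inputs. reference_apply_layers and check_solution are hypothetical helper names. The reference deliberately uses a Python loop, which the problem forbids only in the solution itself.

```python
import jax
import jax.numpy as jnp


def reference_apply_layers(weights, x):
    # Unrolled reference (a Python loop is fine here, not in the solution)
    y = x
    for w in weights:
        y = w @ y
    return y


def check_solution(fn, n_layers=6, d=5, seed=0):
    # Compare a candidate implementation against the reference on random inputs
    wkey, xkey = jax.random.split(jax.random.PRNGKey(seed))
    weights = jax.random.normal(wkey, (n_layers, d, d)) / jnp.sqrt(d)
    x = jax.random.normal(xkey, (d,))
    got = fn(weights, x)
    want = reference_apply_layers(weights, x)
    assert got.shape == (d,), f"expected shape {(d,)}, got {got.shape}"
    assert jnp.allclose(got, want, rtol=1e-4, atol=1e-4), "values differ from reference"
    return True


# Usage, once you have written scan_apply_layers:
# check_solution(scan_apply_layers)
```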