
Scan over Layer Stack

Why this matters

Modern architectures (deep MLPs, ResNets, transformer stacks) often repeat the same block N times. A Python for-loop over layers is unrolled at trace time: each iteration appends its own copy of the block's operations to the traced program, so program size, JIT compile time, and compile-time memory all grow roughly linearly with N. Scanning over a stacked parameter tensor instead traces the block once and emits a single loop, so the XLA program stays the same size regardless of depth.
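The difference is visible in the traced program itself. The sketch below counts the top-level equations in the jaxpr that jax.make_jaxpr produces for a Python-loop version and a lax.scan version of the same stack. The helper names unrolled, scanned and num_eqns are illustrative, not part of any library, and the equation count is only a proxy for compile cost, not a timing benchmark.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(weights, y):
    # Python loop: each iteration is traced separately, so the jaxpr grows with N
    for i in range(weights.shape[0]):
        y = weights[i] @ y
    return y


def scanned(weights, y):
    # lax.scan: the body is traced once and wrapped in a single scan primitive
    def body(carry, w):
        return w @ carry, None
    final, _ = lax.scan(body, y, weights)
    return final


def num_eqns(fn, n_layers, d=4):
    # Count top-level equations in the traced program (a proxy for program size)
    w = jnp.ones((n_layers, d, d))
    x = jnp.ones((d,))
    return len(jax.make_jaxpr(fn)(w, x).jaxpr.eqns)


for n in (2, 8, 32):
    print(f"N={n:2d}  unrolled={num_eqns(unrolled, n):3d}  scanned={num_eqns(scanned, n)}")
```

The trade-off is that XLA sees one loop body instead of N copies, which limits cross-layer optimization; lax.scan's unroll argument exists to tune that balance.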

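This pattern is fundamental to JAX-native model definitions: stack all layer parameters along axis 0, then lax.scan over them. Higher-level libraries build on the same idiom; Flax, for example, provides it as the lifted nn.scan transform, and Equinox models can apply lax.scan directly to stacked layer parameters.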
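When each layer has several parameters, the same idea applies leaf by leaf: build a list of per-layer pytrees, stack matching leaves along a new axis 0, and scan over the resulting pytree, since lax.scan slices every leaf of xs along axis 0. In this sketch the dict layout with keys "w" and "b", the bias, the tanh activation, and the helper names init_layer, stack_layers and mlp are illustrative assumptions, not part of the problem below.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_layer(key, d):
    # Hypothetical per-layer parameters: a dict holding a weight and a bias
    wkey, bkey = jax.random.split(key)
    return {"w": jax.random.normal(wkey, (d, d)) / jnp.sqrt(d),
            "b": jax.random.normal(bkey, (d,)) * 0.01}


def stack_layers(layers):
    # Stack a list of identically structured pytrees leaf by leaf along a new axis 0
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *layers)


def mlp(stacked, x):
    def body(h, p):
        # p is one layer's slice of every leaf: p["w"] is (d, d), p["b"] is (d,)
        return jnp.tanh(p["w"] @ h + p["b"]), None
    out, _ = lax.scan(body, x, stacked)
    return out


d, n_layers = 4, 6
keys = jax.random.split(jax.random.PRNGKey(0), n_layers)
layers = [init_layer(k, d) for k in keys]
stacked = stack_layers(layers)
print(stacked["w"].shape, stacked["b"].shape)  # (6, 4, 4) (6, 4)
y = mlp(stacked, jnp.ones(d))
```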

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

# Two 2×2 weight matrices stacked → shape (2, 2, 2)
weights = jnp.array([[[1., 0.], [0., 1.]],   # identity
                      [[0., 1.], [1., 0.]]])  # swap

def layer(y, w):
    return w @ y, None   # carry=new activation; output=None (unused)

final, _ = lax.scan(layer, jnp.array([3., 4.]), weights)
# Step 0: identity @ [3,4] → [3,4]
# Step 1: swap    @ [3,4] → [4,3]
# final → [4.0, 3.0]

The xs argument to lax.scan is the stacked-weights tensor; scan slices it along axis 0 and passes each slice as w to the body. The carry is the running activation y.
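The slicing and carrying described above can be modelled by a plain Python loop. The sketch below compares lax.scan against a hypothetical reference function, scan_reference, which describes what scan computes but not how it is compiled. It also changes the body to return the activation as the per-step output, so the second return value of lax.scan becomes the stack of intermediate activations instead of None.

```python
import jax.numpy as jnp
from jax import lax

# Same stacked weights and input as the mini-example
weights = jnp.array([[[1., 0.], [0., 1.]],    # identity
                     [[0., 1.], [1., 0.]]])   # swap
x0 = jnp.array([3., 4.])


def layer_with_trace(y, w):
    new_y = w @ y
    return new_y, new_y   # carry the activation AND emit it as the per-step output


def scan_reference(f, init, xs):
    # Python model of what lax.scan computes (not how it is compiled)
    carry, ys = init, []
    for i in range(xs.shape[0]):
        carry, y = f(carry, xs[i])   # xs is sliced along axis 0
        ys.append(y)
    return carry, jnp.stack(ys)      # per-step outputs are stacked along a new axis 0


final, trace = lax.scan(layer_with_trace, x0, weights)
ref_final, ref_trace = scan_reference(layer_with_trace, x0, weights)
# final -> [4., 3.]; trace has shape (2, 2): [[3., 4.], [4., 3.]]
```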

Common pitfalls

  • Returning only the carry from body. Even when you don't need per-step outputs, body must return a pair (new_carry, None), not just new_carry.
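  • Wrong argument order in body. lax.scan calls body(carry, x_i). Here carry is y and x_i is w; don't swap them. With square weights a swapped signature can fail silently: it still runs but computes each layer's transpose applied to the activation.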
  • Confusing xs axis. lax.scan always scans xs along axis 0. Your weight stack must have layers along axis 0 (weights[i] = layer i).
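The three pitfalls can be reproduced on a tiny example. The weights here are made-up, deliberately non-symmetric matrices so that an argument-order mistake changes the answer, and the function names good, swapped and carry_only are illustrative. Catching TypeError for carry_only reflects what recent JAX versions raise; treat the exact exception type as an assumption on other versions.

```python
import jax.numpy as jnp
from jax import lax

# Made-up, non-symmetric weights (layers along axis 0, shape (N, d, d))
weights = jnp.array([[[1., 2.], [0., 1.]],
                     [[1., 0.], [3., 1.]]])
x = jnp.array([1., 1.])


def good(y, w):
    return w @ y, None          # (new_carry, per-step output)


def swapped(w, y):
    # BUG: scan calls body(carry, x_i), so "w" is really the activation and
    # "y" the weight slice; w @ y then computes W_i.T @ activation, silently
    return w @ y, None


def carry_only(y, w):
    return w @ y                # BUG: not a (carry, output) pair


final_good, _ = lax.scan(good, x, weights)         # [3., 10.]
final_swapped, _ = lax.scan(swapped, x, weights)   # [10., 3.]: runs, but wrong

raised = False
try:
    lax.scan(carry_only, x, weights)
except TypeError as err:        # scan rejects a body whose output is not a pair
    raised = True
    print("carry_only rejected:", err)

# Layers stored along the LAST axis (d, d, N) must be moved to axis 0 first;
# scanning weights_last directly would still run here (every axis has length 2)
# but would slice the wrong axis.
weights_last = jnp.moveaxis(weights, 0, -1)
final_fixed, _ = lax.scan(good, x, jnp.moveaxis(weights_last, -1, 0))
```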

Problem

Implement scan_apply_layers(weights, x) that applies N sequential dense layers (no activation) using lax.scan. The body is new_y = w @ y.

  • weights: 3-D jax array of shape (N, d, d).
  • x: 1-D jax array of shape (d,).
  • Returns: 1-D array of shape (d,), the output after all N layers.

Do not use Python loops.
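A minimal sketch of one possible solution, which wraps the mini-example's body in the requested signature. composed_reference is a hypothetical test helper that uses a Python loop, so it is only for checking results and is not part of the answer; it relies on the fact that applying the layers in sequence equals multiplying x by the single matrix W_{N-1} ... W_1 W_0.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_apply_layers(weights, x):
    """Apply N dense layers y <- W_i @ y (no activation) with lax.scan.

    weights: (N, d, d), layers along axis 0; x: (d,). Returns (d,).
    """
    def body(y, w):
        return w @ y, None       # carry the activation; no per-step output
    final, _ = lax.scan(body, x, weights)
    return final


def composed_reference(weights, x):
    # Test-only cross-check: fold W_{N-1} @ ... @ W_0 into one matrix, then apply it
    m = jnp.eye(weights.shape[-1], dtype=weights.dtype)
    for w in weights:            # plain Python loop, used only for verification
        m = w @ m
    return m @ x


fast = jax.jit(scan_apply_layers)
```

Because the loop lives inside lax.scan, the traced program has the same size for any N, although a new N still triggers a retrace under jax.jit since the input shape changes.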

Hints

jax scan layers
