Custom lift: roll your own ensemble
Why this matters
Flax lifts (nn.scan, nn.vmap, nn.remat, …) are built on top
of plain JAX transforms (lax.scan, jax.vmap, jax.checkpoint)
plus some Flax-internal bookkeeping for variable collections, RNGs,
and Module structure.
Most of the time you should reach for the lifts: they handle RNG splitting, variable collections, and Module structure correctly. But understanding what’s underneath helps:
- Debugging weird lift errors: knowing how a lift maps to the underlying JAX primitive lets you bisect “is this Flax’s bug or JAX’s?”.
- Cases the lifts don’t fit: sometimes you want partial axis mapping, irregular shapes, or non-standard variable handling that nn.vmap doesn’t expose. Roll your own.
- Performance instinct: when a lift is slow, knowing the underlying transform helps you reason about why.
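To make the "plain JAX underneath" point concrete, here is a minimal sketch of the same ensemble pattern with no Flax at all. A hand-written dense layer has its init vmapped over K keys and its apply vmapped over the stacked params. dense_init and dense_apply are hypothetical helpers written for illustration, not Flax APIs, and the scaled-normal initializer is an assumption rather than nn.Dense's default.

```python
import jax
import jax.numpy as jnp


def dense_init(key, in_dim, features):
    # Hypothetical stand-in for a Dense init: scaled-normal kernel, zero bias.
    kernel = jax.random.normal(key, (in_dim, features)) / jnp.sqrt(in_dim)
    bias = jnp.zeros((features,))
    return {"kernel": kernel, "bias": bias}


def dense_apply(params, x):
    # Plain affine map: x @ W + b.
    return x @ params["kernel"] + params["bias"]


K, in_dim, features = 4, 3, 2
keys = jax.random.split(jax.random.PRNGKey(0), K)

# Map over keys only; the sizes are static Python ints, so close over them.
stacked = jax.vmap(lambda k: dense_init(k, in_dim, features))(keys)
x = jnp.ones((in_dim,))

# Map over the leading K axis of every param leaf and broadcast x to all copies.
outs = jax.vmap(dense_apply, in_axes=(0, None))(stacked, x)  # shape (K, features)
ensemble_mean = outs.mean(axis=0)  # shape (features,)
```

This is the same split-keys / stacked-params / broadcast-input structure that the Flax version below relies on. The difference is that Flax's init and apply also manage variable collections and RNG streams for you.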
This problem rolls a tiny “ensemble lift” by hand: K independent copies of a Module, each with its own params, applied to the same input, with outputs averaged.
The recipe
- Init K param trees in parallel: jax.vmap(model.init, in_axes=(0, None)) with K split keys (this is exactly pos 86). Result: a stacked param tree in which every leaf has leading axis K.
- Apply each copy: Python-loop over i in range(K), slice out the i-th param tree with tree_map(lambda p: p[i], stacked), call model.apply(single, x), and accumulate the outputs.
- Reduce: divide by K to get the ensemble mean.
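```python
import jax
import flax.linen as nn


def ensemble_mean(model: nn.Module, seed: int, x, K: int):
    # One key per copy, so each copy gets different initial params.
    keys = jax.random.split(jax.random.PRNGKey(seed), K)
    # Batched init: keys are mapped (axis 0), x is broadcast (None).
    params_stacked = jax.vmap(model.init, in_axes=(0, None))(keys, x)
    out = 0.0
    for i in range(K):
        # Slice the i-th copy out of the stacked tree and run a normal forward pass.
        single = jax.tree_util.tree_map(lambda p, i=i: p[i], params_stacked)
        out = out + model.apply(single, x)
    return out / K
```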
Compare to nn.vmap with variable_axes={"params": 0}, split_rngs={"params": True} (plus in_axes=None and axis_size=K, because every copy sees the same x). That is essentially the same ensemble, but
Flax does the slicing under the hood and you get a single batched
apply call instead of a Python loop (faster, JAX-native).
Why not vectorize the apply?
You COULD wrap the apply in another jax.vmap to make all K
forward passes happen in parallel. The Python-loop version above
is intentionally pedagogical: it makes the slicing explicit. For
production, pair the vmapped init with jax.vmap(model.apply, in_axes=(0, None))(stacked, x).
That gives you batched init AND batched apply with no Python loop. The result has
shape (K, F), so take .mean(axis=0) for the ensemble mean.
Lambda capture gotcha
lambda p: p[i] captures the variable i by reference (late binding), not its
value at creation time. Inside a Python for i in range(K) loop, all the
closures share the same i, so after the loop every one of them sees i = K-1.
In the loop above this never bites: tree_map calls the lambda immediately,
before i advances. It does bite when the closures are stored and called
later, for example a list of per-copy slicing functions built up front and
used after the loop. The defensive habit is to bind the loop variable as a
default argument: lambda p, i=i: p[i].
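The late-binding behavior is plain Python, so it can be shown without JAX. This sketch builds the closures in a loop and only calls them after the loop has finished, which is exactly the situation where the bug appears.

```python
K = 3
data = [10, 20, 30]

late, bound = [], []
for i in range(K):
    # Late binding: looks up `i` when called, so it sees the final value.
    late.append(lambda seq: seq[i])
    # Default argument: `i=i` is evaluated now, freezing this iteration's value.
    bound.append(lambda seq, i=i: seq[i])

# Called only after the loop has finished.
late_results = [f(data) for f in late]    # [30, 30, 30]
bound_results = [f(data) for f in bound]  # [10, 20, 30]
```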
Common pitfalls
- Reusing one RNG key: gives K identical param trees, not an ensemble.
- Forgetting in_axes=(0, None): the default maps every argument along axis 0, so vmap also tries to map the dummy input x along its leading axis. That raises a size-mismatch error unless len(x) == K, and even then each copy would see a scalar slice instead of the full input, breaking Dense’s shape inference.
- Indexing with the wrong shape: params_stacked["params"]["kernel"] has shape (K, in_dim, F). [i] gives (in_dim, F), which is what model.apply expects.
- Forgetting to divide by K: gives the SUM of K outputs, not the mean. Either is a valid ensemble, but the reference uses the mean.
Problem
Implement custom_lift_forward(seed, x, n_copies, features):
- Build model = nn.Dense(features=features).
- Split PRNGKey(seed) into n_copies keys.
- Init in parallel: params_stacked = jax.vmap(model.init, in_axes=(0, None))(keys, x).
- Loop i in range(n_copies): slice the i-th params tree with jax.tree_util.tree_map(lambda p, i=i: p[i], params_stacked), apply, and accumulate the output.
- Return the accumulated sum divided by n_copies (the ensemble mean).
Inputs:
- seed: int.
- x: 1-D array of shape (in_dim,), the input to all K copies.
- n_copies: int K.
- features: int F.
Output: 1-D array of shape (F,), the mean of K Dense forward passes on x.