hard primitives

Custom lift: roll your own ensemble

Why this matters

Flax lifts (nn.scan, nn.vmap, nn.remat, …) are built on top of plain JAX transforms (lax.scan, jax.vmap, jax.checkpoint) plus some Flax-internal bookkeeping for variable collections, RNGs, and Module structure.

Most of the time you should reach for the lifts: they handle RNG splitting, variable collections, and Module scoping correctly. But understanding what's underneath helps:

  • Debugging weird lift errors — knowing how a lift maps to the underlying JAX primitive lets you bisect “is this Flax’s bug or JAX’s?”.
  • Cases the lifts don’t fit — sometimes you want partial axis mapping, irregular shapes, or non-standard variable handling that nn.vmap doesn’t expose. Roll your own.
  • Performance instinct — when a lift is slow, knowing the underlying transform helps you reason about why.

This problem rolls a tiny “ensemble lift” by hand: K independent copies of a Module, each with its own params, applied to the same input, with outputs averaged.

The recipe

  1. Init K param trees in parallel: jax.vmap(model.init, in_axes=(0, None)) with K split keys (this is exactly pos 86). Result: a stacked param tree in which every leaf has a leading axis of size K.

  2. Apply each copy — Python-loop over i in range(K), slice out the i-th param tree with tree_map(lambda p: p[i], stacked), call model.apply(single, x), accumulate outputs.

  3. Reduce — divide by K to get the ensemble mean.

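import jax

def ensemble_mean(model, seed, x, K):
    # One independent key per copy (reusing a key would give K identical copies).
    keys = jax.random.split(jax.random.PRNGKey(seed), K)
    # Batched init: every leaf of params_stacked has a leading axis of size K.
    params_stacked = jax.vmap(model.init, in_axes=(0, None))(keys, x)
    out = 0.0
    for i in range(K):
        # Slice out the i-th copy's params; i=i binds the current loop index.
        single = jax.tree_util.tree_map(lambda p, i=i: p[i], params_stacked)
        out = out + model.apply(single, x)
    return out / K  # ensemble mean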

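Compare to nn.vmap with variable_axes={"params": 0}, split_rngs={"params": True}, in_axes=None and axis_size=K: it builds the same kind of ensemble (a "params" collection stacked along axis 0, each copy initialized from its own RNG), but Flax does the slicing under the hood and all K forward passes run as one batched apply instead of a Python loop, which is typically faster. The individual param values need not match the manual version exactly, because Flax derives the per-copy keys its own way.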

Why not vectorize the apply?

You COULD wrap the apply in another jax.vmap so that all K forward passes run in parallel. The Python-loop version above is intentionally pedagogical: it makes the slicing explicit. For production, jax.vmap(model.apply, in_axes=(0, None))(params_stacked, x) maps over the leading K axis of the params and broadcasts x; combined with the vmapped init, that gives you batched init and batched apply with no Python loop.

Lambda capture gotcha

lambda p: p[i] captures the variable i by closure, not its value at creation time. If you build lambdas inside a Python for i in range(K) loop and call them later, they all share the same i, so after the loop every one of them sees i = K-1.

In the loop above this cannot bite: tree_map calls the lambda immediately, before i advances. It does bite as soon as the lambdas are stored and called later (for example, a list of per-copy slicers built up front). Binding the loop variable as a default argument, lambda p, i=i: p[i], freezes the current value and is a cheap habit either way.
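The late-binding behavior is plain Python, so a tiny sketch with a list (standing in for a stacked param leaf) shows it without any JAX:

```python
# Late binding: each lambda looks up i when it is CALLED, not when it is created.
data = [10, 20, 30]

late = [lambda p: p[i] for i in range(3)]        # all three share one i
bound = [lambda p, i=i: p[i] for i in range(3)]  # i frozen as a default arg

print([f(data) for f in late])   # [30, 30, 30]
print([f(data) for f in bound])  # [10, 20, 30]
```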

Common pitfalls

  • Reusing one RNG key — gives K identical param trees, not an ensemble.
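  • Forgetting in_axes=(0, None): by default jax.vmap maps every argument along axis 0, so the dummy input x gets mapped along its leading axis too. If x.shape[0] != K, vmap raises a size-mismatch error; if it happens to equal K, each copy sees a scalar slice of x and Dense cannot infer its input dimension.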
  • Indexing with the wrong shape: params_stacked["params"]["kernel"] has shape (K, in_dim, F), and p[i] gives (in_dim, F), which is what model.apply expects.
  • Forgetting to divide by K — gives the SUM of K outputs, not the mean. Either is a valid ensemble, but the reference uses the mean.

Problem

Implement custom_lift_forward(seed, x, n_copies, features):

  1. Build model = nn.Dense(features=features).
  2. Split PRNGKey(seed) into n_copies keys.
  3. Init in parallel: params_stacked = jax.vmap(model.init, in_axes=(0, None))(keys, x).
  4. Loop i in range(n_copies): slice the i-th params tree with jax.tree_util.tree_map(lambda p, i=i: p[i], params_stacked), apply, accumulate the output.
  5. Return the accumulated sum divided by n_copies (ensemble mean).

Inputs:

  • seed: int.
  • x: 1-D — input to all K copies.
  • n_copies: int K.
  • features: int F.

Output: 1-D (F,) — mean of K Dense forward passes on x.

