
NNX Batched Init

Why this matters

For ensemble methods, hyperparameter sweeps, or evolutionary strategies, you often want to instantiate N independent models in parallel. The naive Python loop:

models = [nnx.Linear(F, F, rngs=nnx.Rngs(seed + i)) for i in range(N)]

works, but:

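  • It’s serial: each nnx.Linear(...) call runs on its own, dispatching its own initializer ops and allocating its own arrays.
  • It returns a Python list, not a stacked pytree: to get something vmappable later you’d have to extract each model’s state and stack it, e.g. jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(m) for m in models]).
  • Its cost grows linearly in N: Python overhead when run eagerly, and trace/compile time if you wrap the loop in jax.jit. That’s a problem when you want N=1000 for, say, an evolutionary-strategies loop.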

The vmap-of-init pattern fixes all three:

from flax import nnx

N, F, seed = 4, 8, 0  # example sizes

@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_model(rngs):
    return nnx.Linear(F, F, rngs=rngs)

ensemble = make_model(nnx.Rngs(seed))
# ensemble.kernel.value.shape == (N, F, F)
# one Module-shaped object, params already stacked

Reading the recipe inside-out

  • The body — nnx.Linear(F, F, rngs=rngs) — describes ONE model.
  • @nnx.vmap(in_axes=(0,), out_axes=0) lifts that body to a vectorized version that takes a stacked rngs and returns a stacked Module. Each lane gets a different RNG, so each lane gets different params.
  • @nnx.split_rngs(splits=N) is the boilerplate for “turn one nnx.Rngs(seed) into N split keys along axis 0”, which is what the vmap wants.

The two decorators stack: the outer split_rngs runs first and replaces each key in nnx.Rngs(seed) with N split keys stacked along axis 0; the inner vmap then maps the constructor over that leading axis.

Inspecting the output

ensemble is a single nnx.Linear instance (same class and static attributes, e.g. in_features == F), but every Variable’s value has gained a leading axis of size N:

print(ensemble.kernel.value.shape)   # (N, F, F)
print(ensemble.bias.value.shape)     # (N, F)

To use it for inference, you’d vmap the forward with in_axes=(0, None) (model batched, x shared) — see the nnx-vmap-over-ensemble problem.

Why jax.vmap directly is awkward here

You COULD write:

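import jax
from flax import nnx
keys = jax.random.split(jax.random.key(seed), N)
@jax.vmap
def init_state(key):
    # jax.vmap can only return pytrees of arrays, so return the State, not the Module
    return nnx.state(nnx.Linear(F, F, rngs=nnx.Rngs(key)))
states = init_state(keys)  # every leaf gains a leading axis of size N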

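but since jax.vmap can only return pytrees of arrays, you’d have to split the Module into its state yourself, rebuild it afterwards with nnx.merge and a graphdef obtained outside the vmap, and reason about which bits of the Module are nnx.Variables versus static structure. The nnx.vmap lift handles all of that for you.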

Common pitfalls

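  • Forgetting nnx.split_rngs. The keys in nnx.Rngs(seed) then have no leading axis, so in_axes=(0,) has nothing to map over and vmap fails with a rank error instead of building an ensemble.
  • in_axes=(None,) for the constructor. That broadcasts one rngs to every lane (and, with no mapped input, vmap also needs an explicit axis_size=N). Every lane then draws the same keys, so you get N copies of the same model: “an ensemble” with zero diversity.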
  • Summing or averaging the params inside the constructor. The whole point is that each model is independent. If you want a single averaged init, do it AFTER construction, e.g. jax.tree.map(lambda x: x.mean(0), nnx.state(ensemble)).

Problem

Implement batched_init_count(seed, num_inits, features):

  1. Cast N = int(num_inits), F = int(features), s = int(seed).

  2. Build:

    @nnx.split_rngs(splits=N)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def make_model(rngs):
        return nnx.Linear(F, F, rngs=rngs)
    
    ensemble = make_model(nnx.Rngs(s))

  3. Read leading = ensemble.kernel.value.shape[0].

  4. Return jnp.array([N, F, leading]). The third element should equal N — this is how the test verifies the params got the expected leading axis.

Inputs:

  • seed: float (cast to int).
  • num_inits: float (cast to int) — N.
  • features: float (cast to int) — F.

Output: a 1-D array of shape (3,) holding [N, F, leading_axis_of_kernel].

Hints

flax nnx lifted-transforms vmap init ensemble
