NNX Batched Init
Why this matters
For ensemble methods, hyperparameter sweeps, or evolutionary strategies, you often want to instantiate N independent models in parallel. The naive Python loop:
models = [nnx.Linear(F, F, rngs=nnx.Rngs(seed + i)) for i in range(N)]
works, but:
- It’s serial: each nnx.Linear(...) call runs separately in Python and allocates its own parameters.
- It returns a Python list, not a stacked pytree; you’d have to stack the params yourself, e.g. jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *[nnx.state(m) for m in models]), to get something vmappable later.
- Its cost grows linearly in N: the constructor runs N times in Python, and inside jax.jit the traced program grows with N as well. That’s a problem when you want N=1000 for, say, an evolutionary strategies loop.
The vmap-of-init pattern fixes all three:
@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_model(rngs):
    return nnx.Linear(F, F, rngs=rngs)
ensemble = make_model(nnx.Rngs(seed))
# ensemble.kernel.value.shape == (N, F, F)
# one Module-shaped object, params already stacked
Reading the recipe inside-out
- The body, nnx.Linear(F, F, rngs=rngs), describes ONE model.
- @nnx.vmap(in_axes=(0,), out_axes=0) lifts that body to a vectorized version that takes a stacked rngs and returns a stacked Module. Each lane gets a different RNG key, so each lane gets different params.
- @nnx.split_rngs(splits=N) is the boilerplate for “turn one nnx.Rngs(seed) into N split keys along axis 0”, which is what the vmap wants.
The two decorators compose from the outside in: the outer split_rngs runs first and splits the incoming nnx.Rngs so its keys carry a leading axis of size N, then the inner vmap maps the constructor over that leading axis.
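In plain JAX terms, “N split keys along axis 0” looks like the following. This sketch only illustrates the key shapes the inner vmap consumes; it is not a description of how nnx.split_rngs is implemented internally.

```python
import jax

N = 5
key = jax.random.key(0)          # one typed PRNG key, shape ()
keys = jax.random.split(key, N)  # N independent keys stacked along axis 0
print(keys.shape)                # (5,)

# A vmap over axis 0 hands each lane exactly one of these keys.
draws = jax.vmap(lambda k: jax.random.normal(k, ()))(keys)
print(draws.shape)               # (5,)
```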
Inspecting the output
ensemble is a single nnx.Linear instance (not a list of N Modules), but every Variable.value has gained a leading axis of size N:
print(ensemble.kernel.value.shape) # (N, F, F)
print(ensemble.bias.value.shape) # (N, F)
To use it for inference, you’d vmap the forward pass with in_axes=(0, None) (model batched, x shared); see the nnx-vmap-over-ensemble problem.
Why jax.vmap directly is awkward here
You COULD write:
keys = jax.random.split(jax.random.key(seed), N)
@jax.vmap
def init_fn(key):
    return nnx.Linear(F, F, rngs=nnx.Rngs(key)).split(...)
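but you’d have to split the Module into its static GraphDef and its State yourself, return only the State from the vmapped function, and nnx.merge the pieces afterwards, reasoning all along about which bits of the Module are nnx.Variables and which are static structure. The nnx.vmap lift handles that bookkeeping for you.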
Common pitfalls
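- Forgetting nnx.split_rngs. nnx.Rngs(seed) then holds a single key rather than N keys stacked along axis 0, so the lanes cannot each draw from their own RNG. With in_axes=(0,) you should get an error, since there is no axis 0 to map over; the dangerous variants are workarounds that run anyway by sharing one RNG across lanes, which silently gives you N copies of the same model, i.e. an “ensemble” with zero diversity.
- in_axes=(None,) for the constructor. That broadcasts the same rngs to every lane instead of giving each lane its own slice of the split keys, which defeats the split.
- Trying to sum or average the params inside the vmapped constructor. The whole point is that each model is independent. If you want a single averaged init, do it AFTER construction, e.g. jax.tree.map(lambda x: x.mean(0), nnx.state(ensemble)).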
Problem
Implement batched_init_count(seed, num_inits, features):
- Cast N = int(num_inits), F = int(features), s = int(seed).
- Build:
    @nnx.split_rngs(splits=N)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def make_model(rngs):
        return nnx.Linear(F, F, rngs=rngs)
    ensemble = make_model(nnx.Rngs(s))
- Read leading = ensemble.kernel.value.shape[0].
- Return jnp.array([N, F, leading]). The third element should equal N; this is how the test verifies that the params got the expected leading axis.
Inputs:
- seed: float (cast to int).
- num_inits: float (cast to int), the ensemble size N.
- features: float (cast to int), the feature size F.
Output: a 1-D array of shape (3,): [N, F, leading_axis_of_kernel].