Batched init via jax.vmap
Why this matters
Sometimes you need MANY copies of the same model with DIFFERENT initializations:
- Hyperparameter / seed sweeps — train K models with different init RNGs and pick the winner.
- Monte Carlo / deep ensembles — average predictions across K independently-initialized networks for calibrated uncertainty.
- Bayesian deep learning — variational families parametrized by ensembles of weights.
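Initializing K models with for i in range(K): params_i = model.init(...)
works, but it makes K separate init calls (unrolled into K copies if traced
under jax.jit) and leaves you with K separate param trees rather than one
stacked tree. The JAX-native way: use jax.vmap over model.init with K
different RNG keys.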
The pattern
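import jax
import jax.numpy as jnp
import flax.linen as nn
# F, K, seed and in_dim are placeholders for your own sizes and seed.
model = nn.Dense(features=F)
keys = jax.random.split(jax.random.PRNGKey(seed), K)  # (K, 2)
x = jnp.ones((in_dim,))  # dummy input, used only for shape inference
vmapped_init = jax.vmap(model.init, in_axes=(0, None))
stacked = vmapped_init(keys, x)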
in_axes=(0, None) is the crucial part: the FIRST argument (the
RNG key) is mapped along axis 0 — each call gets its own key. The
SECOND argument (the dummy input x) is broadcast — every call
sees the same input shape.
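To see the in_axes=(0, None) behaviour in isolation, here is a small sketch using a toy function in place of model.init. The function noisy and the sizes are hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp

K, in_dim = 4, 3  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), K)
x = jnp.ones((in_dim,))

def noisy(key, x):
    # Inside vmap, each call sees ONE key and the FULL, unmapped x.
    return x + jax.random.normal(key, x.shape)

# Map over the keys (axis 0), broadcast x to every call.
out = jax.vmap(noisy, in_axes=(0, None))(keys, x)  # shape (K, in_dim)
```

Each row of out comes from a different key applied to the same x, which is the same arrangement vmapped model.init relies on.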
The result stacked is a param tree where every leaf has gained
a leading axis of size K:
- stacked["params"]["kernel"].shape == (K, in_dim, F)
- stacked["params"]["bias"].shape == (K, F)
K independent Dense layers, all in one tree. To use one of them:
single = jax.tree_util.tree_map(lambda p: p[i], stacked) and then
model.apply(single, x_i).
Why jax.vmap, not nn.vmap?
nn.vmap is designed to lift a Module’s __call__ — the forward
pass — and integrate with Flax’s variable bookkeeping at apply time.
Init is just a function (rng, *args) -> params_dict; vanilla
jax.vmap works fine and is the idiomatic move here. (You CAN use
nn.vmap with variable_axes={"params": 0}, split_rngs={"params": True}
to get the same effect at apply time, but for init only, jax.vmap
is simpler.)
Common pitfalls
- Forgetting in_axes=(0, None): vmap defaults to mapping ALL args along axis 0, so JAX tries to map x along its leading axis too. If x has no leading axis of size K, you get a shape error; if it does, each call sees a different slice of x, which is probably not what you want.
- Wrong number of keys: jax.random.split(key, K) returns (K, 2). Pass that whole array; vmap unstacks per-row.
- Reusing a single key with jnp.tile: gives K identical param trees. The point of the split is independent inits.
Problem
Implement batched_init_param_count(seed, num_inits, features):
- Build model = nn.Dense(features=features).
- Split PRNGKey(seed) into num_inits keys.
- Vmap model.init with in_axes=(0, None) over the keys and a dummy input x = jnp.ones((4,)) (any 1-D shape suffices for shape inference).
- Read off the leading axis of params["params"]["kernel"]; it should equal num_inits.
- Return jnp.array([float(num_inits), float(features), float(leading_axis)]), a 1-D length-3 vector confirming the shape.
The third element should equal the first; if not, your vmap didn’t produce a stacked param tree.
Inputs:
- seed: int.
- num_inits: int K.
- features: int F.
Output: 1-D (3,) — [num_inits, features, kernel_leading_axis].