
Batched init via jax.vmap

Why this matters

Sometimes you need MANY copies of the same model with DIFFERENT initializations:

  • Hyperparameter / seed sweeps — train K models with different init RNGs and pick the winner.
  • Monte Carlo / deep ensembles — average predictions across K independently-initialized networks for calibrated uncertainty.
  • Bayesian deep learning — variational families parametrized by ensembles of weights.

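Initializing K models with for i in range(K): params_i = model.init(...) works, but it is a plain Python loop. It dispatches K separate init calls one after another, and under jax.jit it is unrolled into K copies of the init computation. The JAX-native way is to use jax.vmap over model.init with K different RNG keys. This traces init once and returns all K param sets in a single stacked tree.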

The pattern

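import jax
import jax.numpy as jnp
import flax.linen as nn

seed, K, in_dim, F = 0, 8, 4, 3  # illustrative sizes; pick your own
model = nn.Dense(features=F)
keys = jax.random.split(jax.random.PRNGKey(seed), K)  # (K, 2): one key per member
x = jnp.ones((in_dim,))  # dummy input; init only needs its shape
vmapped_init = jax.vmap(model.init, in_axes=(0, None))  # map keys, broadcast x
stacked = vmapped_init(keys, x)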

in_axes=(0, None) is the crucial part: the FIRST argument (the RNG keys) is mapped along axis 0, so each call gets its own key. The SECOND argument (the dummy input x) is broadcast (None means "not mapped"), so every call sees the same x; init only uses its shape to infer the kernel's input dimension.
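To see in_axes=(0, None) in isolation, away from Flax, here is a minimal sketch with a tiny hypothetical function noisy (not part of any library) that, like init, takes a key and an input. The keys are mapped, x is shared, and each row of the output comes from a different key.

```python
import jax
import jax.numpy as jnp

def noisy(key, x):
    # Hypothetical stand-in for init: draw noise shaped like x from this call's key.
    return x + jax.random.normal(key, x.shape)

K = 3
keys = jax.random.split(jax.random.PRNGKey(0), K)  # (K, 2): one key per call
x = jnp.zeros((5,))                                # shared input, not batched

# Map the key along axis 0; broadcast x unchanged to every call.
out = jax.vmap(noisy, in_axes=(0, None))(keys, x)
print(out.shape)  # (3, 5): one row per key
```

Inside noisy, x.shape is (5,) because x is not mapped; only the key argument loses its leading axis.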

The result stacked is a param tree where every leaf has gained a leading axis of size K:

  • stacked["params"]["kernel"].shape == (K, in_dim, F)
  • stacked["params"]["bias"].shape == (K, F)

K independent Dense layers, all in one tree. To use one of them: single = jax.tree_util.tree_map(lambda p: p[i], stacked) and then model.apply(single, x_i).

Why jax.vmap, not nn.vmap?

nn.vmap lifts a Module (typically its __call__) so that Flax's variable and RNG bookkeeping is vectorized along with the computation; it is the right tool when the ensemble has to live inside a larger Module. Init, however, is just a function (rng, *args) -> variables dict, so vanilla jax.vmap works fine and is the idiomatic move here. (You CAN wrap nn.Dense with nn.vmap using variable_axes={"params": 0} and split_rngs={"params": True}; the wrapped module's init then returns the same kind of stacked params, and its apply runs all K members. For init alone, though, jax.vmap is simpler.)

Common pitfalls

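  • Forgetting in_axes=(0, None): the default in_axes=0 maps ALL positional args along axis 0, so JAX tries to map x along its leading axis too. If that axis's size differs from K (or x is a scalar), vmap raises an inconsistent-sizes error; if it happens to equal K, each init silently sees a different slice of x, which is probably not what you want.
  • Wrong number of keys: jax.random.split(key, K) returns K keys, an array of shape (K, 2) for the legacy jax.random.PRNGKey keys used here. Pass that whole array; vmap unstacks it per row, so the number of rows sets the ensemble size.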
  • Reusing a single key with jnp.tile — gives K identical param trees. The point of the split is independent inits.

Problem

Implement batched_init_param_count(seed, num_inits, features):

  1. Build model = nn.Dense(features=features).
  2. Split PRNGKey(seed) into num_inits keys.
  3. Vmap model.init with in_axes=(0, None) over the keys and a dummy input x = jnp.ones((4,)) (any 1-D shape suffices for shape inference).
  4. Read off the leading axis of params["params"]["kernel"] — it should equal num_inits.
  5. Return jnp.array([float(num_inits), float(features), float(leading_axis)]) — a 1-D length-3 vector confirming the shape.

The third element should equal the first; if not, your vmap didn’t produce a stacked param tree.

Inputs:

  • seed: int.
  • num_inits: int K.
  • features: int F.

Output: 1-D array of shape (3,): [num_inits, features, kernel_leading_axis].
