NNX vmap over Ensemble

Why this matters

“Deep ensembles” — train N independent models and average their predictions — are a strong baseline for both accuracy AND uncertainty estimation. Beating them with fancier methods (Bayesian deep learning, SWAG, etc.) is hard. The catch: you have to train N models. If “training” means “wait an hour per model”, you can’t afford 30 of them.

The trick: vmap over the model. Stack N models’ params along a new leading axis, vmap a single forward over that axis, and you get N forward passes in ONE accelerator call. The total arithmetic is still N models’ worth, but it runs as one batched computation instead of N separate launches, so for small models the wall-clock cost is often close to that of a single model.

The previous problem used in_axes=(None, 0): shared model, batched inputs. This one flips it to in_axes=(0, None): batched models, shared input. Each of the N models sees the same x and produces its own output.
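To make the "stack, then vmap" idea concrete before bringing in nnx, here is a minimal plain-JAX sketch. The `forward` function, the params dict layout, and the sizes are illustrative assumptions, not part of the problem; the sketch only checks that one vmapped call matches a Python loop over the N models.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Written for ONE model and ONE sample: x has shape (D_in,).
    return x @ params["w"] + params["b"]

N, D_in, D_out = 4, 3, 2  # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), N)

# One params dict per model, each initialised from its own key.
per_model = [
    {"w": jax.random.normal(k, (D_in, D_out)), "b": jnp.zeros((D_out,))}
    for k in keys
]

# Stack leaf by leaf so every array gains a leading axis of size N.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_model)

x = jnp.ones((D_in,))

# in_axes=(0, None): map over the stacked params, reuse the same x in every lane.
batched_out = jax.vmap(forward, in_axes=(0, None))(stacked, x)

# Reference: the same N forward passes as a plain Python loop.
looped_out = jnp.stack([forward(p, x) for p in per_model])
```

Both `batched_out` and `looped_out` have shape (N, D_out); the vmapped version gets there with a single traced call.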

Building a stacked ensemble

The cleanest way in nnx is to combine nnx.split_rngs (split one RNG into N) with nnx.vmap over the constructor:

@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_model(rngs):
    return nnx.Linear(D_in, D_out, rngs=rngs)

rngs = nnx.Rngs(seed)
ensemble = make_model(rngs)        # one Module-shaped object
                                   # whose params have leading axis N

Reading bottom-up:

  • nnx.Linear(...) is the per-model constructor — written for ONE model.
  • @nnx.vmap(in_axes=(0,), out_axes=0) vmaps that constructor over the leading RNG axis: each lane gets a different rngs, hence different params.
  • @nnx.split_rngs(splits=N) is the helper that turns a single nnx.Rngs(seed) into N split keys along axis 0, ready for the vmap above.

The result is a single Python object whose internals look like an nnx.Linear but whose .kernel.value has shape (N, D_in, D_out) instead of (D_in, D_out). Same for .bias.value.

Forward pass: in_axes=(0, None)

@nnx.vmap(in_axes=(0, None))
def fwd(model, x):
    return model(x)               # written for ONE model, ONE sample

out = fwd(ensemble, x)            # x: (D_in,) -> out: (N, D_out)

0 for the model: vmap iterates over the leading axis of every Variable inside the model. None for x: same x to every lane. Output gets a leading N axis — one output per model.

Comparison: shared vs batched in_axes

pattern                     model in_axes   x in_axes       meaning
(None, 0) (prev problem)    shared (None)   batched (0)     one model, many inputs
(0, None) (this problem)    batched (0)     shared (None)   many models, one input
(0, 0)    (less common)     batched (0)     batched (0)     per-input model (e.g. MoE)

All three are valid; pick by which axis you want lanes to live on.
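The three patterns differ only in which argument carries the mapped axis, which shows up in the output shapes. Here is a small plain-JAX sketch, with a hypothetical bias-free `forward` and made-up sizes, that runs all three:

```python
import jax
import jax.numpy as jnp

def forward(w, x):
    # ONE model (D_in, D_out) applied to ONE sample (D_in,).
    return x @ w

N, B, D_in, D_out = 4, 5, 3, 2  # illustrative sizes

w_single = jnp.ones((D_in, D_out))
w_stack = jnp.ones((N, D_in, D_out))
x_single = jnp.ones((D_in,))
x_batch = jnp.ones((B, D_in))
x_per_model = jnp.ones((N, D_in))

# (None, 0): one model, many inputs -> (B, D_out)
shared_model = jax.vmap(forward, in_axes=(None, 0))(w_single, x_batch)

# (0, None): many models, one input -> (N, D_out)
shared_input = jax.vmap(forward, in_axes=(0, None))(w_stack, x_single)

# (0, 0): model i paired with input i -> (N, D_out)
paired = jax.vmap(forward, in_axes=(0, 0))(w_stack, x_per_model)
```

With (0, 0) the two leading axes must have the same length, since lane i consumes both model i and input i.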

Common pitfalls

  • Forgetting nnx.split_rngs: the Rngs then carries a single key with no leading axis for in_axes=0 to map over. If the same key does reach every lane (for example, by broadcasting it), every lane draws identical params and the “ensemble” is N copies of the same model.
  • Calling nnx.Linear(...) N times in a Python loop and stacking manually. This works, but loses the batched init: each Linear allocates its params separately, and the loop adds N constructor calls’ worth of Python overhead.
  • Confusing in_axes=(0, None) with (None, 0). The first arg is the MODEL; the order matters.
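The first pitfall can be reproduced in plain JAX without any nnx machinery. The sketch below is an analogy under assumed names (`init_kernel`, the sizes), not a description of nnx internals. It vmaps an initialiser over properly split keys and over one key repeated N times, then compares the lanes.

```python
import jax
import jax.numpy as jnp

def init_kernel(key):
    # Stand-in for a per-model initialiser: ONE key -> ONE kernel.
    return jax.random.normal(key, (3, 2))

N = 4
key = jax.random.PRNGKey(0)

split_keys = jax.random.split(key, N)   # N distinct keys
repeated_keys = jnp.stack([key] * N)    # the SAME key in every lane

good = jax.vmap(init_kernel)(split_keys)     # (N, 3, 2), lanes differ
bad = jax.vmap(init_kernel)(repeated_keys)   # (N, 3, 2), lanes identical

good_distinct = not bool(jnp.allclose(good[0], good[1]))
bad_identical = bool(jnp.allclose(bad[0], bad[1]))
```

Comparing two lanes of the stacked kernel like this is a cheap check to run on a freshly built ensemble.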

Problem

Implement vmap_over_ensemble(seed, x, num_models, features):

  1. Cast N = int(num_models), F = int(features).

  2. Build the stacked ensemble:

    @nnx.split_rngs(splits=N)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def make_model(rngs):
        return nnx.Linear(x.shape[-1], F, rngs=rngs)
    ensemble = make_model(nnx.Rngs(int(seed)))

  3. Define the forward pass:

    @nnx.vmap(in_axes=(0, None))
    def fwd(model, x):
        return model(x)

  4. Call out = fwd(ensemble, x) to get an (N, F) array.

  5. Return out.reshape(-1) — a 1-D (N * F,) array.

Inputs:

  • seed: float (cast to int).
  • x: 1-D (D_in,) — one sample, broadcast to all models.
  • num_models: float (cast to int) — N.
  • features: float (cast to int) — D_out.

Output: 1-D (N * features,) — flattened ensemble output.

Hints

flax nnx lifted-transforms vmap ensemble
