NNX vmap over Ensemble
Why this matters
“Deep ensembles” — train N independent models and average their predictions — are a strong baseline for both accuracy AND uncertainty estimation. Beating them with fancier methods (Bayesian deep learning, SWAG, etc.) is hard. The catch: you have to train N models. If “training” means “wait an hour per model”, you can’t afford 30 of them.
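As a minimal sketch of "average their predictions", assume the N models' outputs are already stacked into one array of shape (N, D_out). The numbers below are made up for illustration and do not come from any trained model; the mean over the model axis is the ensemble prediction, and the spread over that axis is a simple disagreement-based uncertainty signal.

```python
import jax.numpy as jnp

# Hypothetical predictions from N = 3 models, 2 outputs each (illustrative values).
preds = jnp.array([
    [1.0, 4.0],
    [2.0, 4.0],
    [3.0, 4.0],
])

# Ensemble prediction: average over the model axis (axis 0).
ensemble_mean = preds.mean(axis=0)

# Disagreement between members: standard deviation over the model axis.
# It is zero where all members agree and positive where they differ.
ensemble_std = preds.std(axis=0)

print(ensemble_mean, ensemble_std)
```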
The trick: vmap over the model. Stack the N models’ params along a new leading axis, vmap a single forward pass over that axis, and you get N forward passes in ONE accelerator call. The total arithmetic is still N models’ worth, but it runs as one batched computation instead of N separate launches, so small models that would each underuse the GPU finish in far less wall-clock time.
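The stacking idea can be sketched without nnx, using plain `jax.vmap` over a dict of parameters. The helper names `init_one` and `forward_one` and the sizes are assumptions made for this illustration; the point is only that one vmapped call reproduces what a Python loop over the individual models computes.

```python
import jax
import jax.numpy as jnp

N, D_in, D_out = 4, 3, 2  # illustrative sizes

def init_one(key):
    # Parameters for ONE linear model (hypothetical helper).
    return {
        "kernel": jax.random.normal(key, (D_in, D_out)),
        "bias": jnp.zeros((D_out,)),
    }

def forward_one(params, x):
    # Forward pass written for ONE model and ONE sample.
    return x @ params["kernel"] + params["bias"]

keys = jax.random.split(jax.random.PRNGKey(0), N)
per_model = [init_one(k) for k in keys]

# Stack matching leaves so every parameter gains a leading axis of size N.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *per_model)

x = jnp.ones((D_in,))

# One vmapped call: params are batched over axis 0, x is shared.
batched_out = jax.vmap(forward_one, in_axes=(0, None))(stacked, x)

# Reference: N separate forward passes in a Python loop.
looped_out = jnp.stack([forward_one(p, x) for p in per_model])

print(batched_out.shape)
```

The nnx version in the rest of this page does the same thing, except that the stacked parameters live inside a single Module-shaped object rather than a dict.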
The previous problem used in_axes=(None, 0): shared model, batched
inputs. This one flips it: in_axes=(0, None) — batched models,
shared input. Each of the N models sees the same x and produces
its own output.
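To make the flip concrete, here is a small sketch with plain `jax.vmap` and a bare kernel standing in for the model (an assumption for brevity; the nnx version behaves the same way with respect to shapes). Only the `in_axes` tuple changes between the two calls.

```python
import jax
import jax.numpy as jnp

def forward(kernel, x):
    # ONE model (a single kernel) applied to ONE sample.
    return x @ kernel

kernel = jnp.ones((3, 2))      # one model
kernels = jnp.ones((5, 3, 2))  # five stacked models
x = jnp.ones((3,))             # one sample
xs = jnp.ones((7, 3))          # seven samples

# Previous pattern: shared model, batched inputs.
shared_model = jax.vmap(forward, in_axes=(None, 0))(kernel, xs)   # (7, 2)

# This pattern: batched models, shared input.
shared_input = jax.vmap(forward, in_axes=(0, None))(kernels, x)   # (5, 2)

print(shared_model.shape, shared_input.shape)
```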
Building a stacked ensemble
The cleanest way in nnx is to combine nnx.split_rngs (split one
RNG into N) with nnx.vmap over the constructor:
from flax import nnx

# N, D_in, D_out and seed are placeholders for your own values.
@nnx.split_rngs(splits=N)
@nnx.vmap(in_axes=(0,), out_axes=0)
def make_model(rngs):
    return nnx.Linear(D_in, D_out, rngs=rngs)

rngs = nnx.Rngs(seed)
ensemble = make_model(rngs)  # one Module-shaped object
                             # whose params have leading axis N
Reading bottom-up:
- nnx.Linear(...) is the per-model constructor, written for ONE model.
- @nnx.vmap(in_axes=(0,), out_axes=0) vmaps that constructor over the leading RNG axis: each lane gets a different rngs, hence different params.
- @nnx.split_rngs(splits=N) is the helper that turns a single nnx.Rngs(seed) into N split keys along axis 0, ready for the vmap above.
The result is a single Python object whose internals look like an
nnx.Linear but whose .kernel.value has shape (N, D_in, D_out)
instead of (D_in, D_out). Same for .bias.value.
Forward pass: in_axes=(0, None)
@nnx.vmap(in_axes=(0, None))
def fwd(model, x):
    return model(x)  # written for ONE model, ONE sample

out = fwd(ensemble, x)  # x: (D_in,) -> out: (N, D_out)
0 for the model: vmap iterates over the leading axis of every
Variable inside the model. None for x: same x to every lane.
Output gets a leading N axis — one output per model.
Comparison: shared vs batched in_axes
| pattern | model in_axes | x in_axes | meaning |
|---|---|---|---|
| (None, 0) (prev problem) | shared | batched 0 | one model, many inputs |
| (0, None) (this problem) | batched 0 | shared | many models, one input |
| (0, 0) (less common) | batched 0 | batched 0 | per-input model (e.g. MoE) |
All three are valid, though (0, 0) comes up less often; pick by which axis you want lanes to live on.
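The (0, 0) row is the only one where both arguments carry a batch axis, so the two leading axes must have the same length and lane i pairs model i with input i. A minimal sketch with plain `jax.vmap` and bare kernels (an assumption made to keep the example short):

```python
import jax
import jax.numpy as jnp

def forward(kernel, x):
    # ONE model (a single kernel) applied to ONE sample.
    return x @ kernel

N, D_in, D_out = 5, 3, 2  # illustrative sizes
kernels = jax.random.normal(jax.random.PRNGKey(0), (N, D_in, D_out))
xs = jax.random.normal(jax.random.PRNGKey(1), (N, D_in))

# Lane i computes forward(kernels[i], xs[i]).
paired = jax.vmap(forward, in_axes=(0, 0))(kernels, xs)

# Reference: the same pairing written as a Python loop.
manual = jnp.stack([xs[i] @ kernels[i] for i in range(N)])

print(paired.shape)
```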
Common pitfalls
- Forgetting nnx.split_rngs: without it, every lane gets the same RNG, hence identical params. The “ensemble” is N copies of the same model.
- Calling nnx.Linear(...) N times in a Python loop and stacking manually. Works, but loses parallel init: each Linear allocates separately, and you pay a Python loop’s worth of trace overhead.
- Confusing in_axes=(0, None) with (None, 0). The first arg is the MODEL; the order matters.
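The first pitfall is easy to miss because the shapes look right either way. One way to catch it is a sanity check on the stacked kernel. The helper `members_are_distinct` below is hypothetical, and the example reproduces the symptom with plain `jax.random` keys rather than nnx: repeating one key across lanes gives identical members, while splitting it gives different ones.

```python
import jax
import jax.numpy as jnp

def members_are_distinct(stacked_kernel):
    # True only if every member after the first differs from member 0 somewhere.
    # Expects shape (N, D_in, D_out). Hypothetical helper for this sketch.
    first = stacked_kernel[0]
    differs = jnp.any(stacked_kernel[1:] != first, axis=(1, 2))
    return bool(jnp.all(differs))

N = 4
key = jax.random.PRNGKey(0)

def init_kernel(k):
    return jax.random.normal(k, (3, 2))

# Same key in every lane: N copies of one model.
same_keys = jnp.stack([key] * N)
copies = jax.vmap(init_kernel)(same_keys)

# One key split N ways: N different models.
split_keys = jax.random.split(key, N)
ensemble_kernels = jax.vmap(init_kernel)(split_keys)

print(members_are_distinct(copies), members_are_distinct(ensemble_kernels))
```

The same check can be applied to the stacked kernel of an nnx ensemble, since it has the same (N, D_in, D_out) layout.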
Problem
Implement vmap_over_ensemble(seed, x, num_models, features):
- Cast N = int(num_models), F = int(features).
- Build the stacked ensemble:

      @nnx.split_rngs(splits=N)
      @nnx.vmap(in_axes=(0,), out_axes=0)
      def make_model(rngs):
          return nnx.Linear(x.shape[-1], F, rngs=rngs)

      ensemble = make_model(nnx.Rngs(int(seed)))

- Define the forward function:

      @nnx.vmap(in_axes=(0, None))
      def fwd(model, x):
          return model(x)

- Call out = fwd(ensemble, x) to get an array of shape (N, F).
- Return out.reshape(-1), a 1-D (N * F,) array.
Inputs:
- seed: float (cast to int).
- x: 1-D (D_in,), one sample, broadcast to all models.
- num_models: float (cast to int), N.
- features: float (cast to int), D_out.
Output: 1-D (N * features,) — flattened ensemble output.