nn.vmap with shared params
Why this matters
jax.vmap is THE way to batch a per-example function. But Flax
Modules carry parameters, and jax.vmap has no special notion of
them: it treats params as just another array input or output, so
naively vmapping a Module’s init or apply can leave you with
per-example params, which is almost never what you want.
nn.vmap is the lifted version. It takes a Module class, returns
a new Module class whose apply is internally vmapped — and lets
you specify, per variable collection, whether params should be
SHARED across the batch or HAVE A PER-EXAMPLE AXIS.
For a normal layer like nn.Dense, you want shared params:
the same weight matrix applied to every example in the batch
(otherwise you have B independent layers, one per example, which
is an ensemble — see pos 86).
The canonical incantation
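```python
import jax
import jax.numpy as jnp
import flax.linen as nn

B, IN, F = 4, 3, 5                     # illustrative sizes
rng = jax.random.PRNGKey(0)
x_batch = jnp.ones((B, IN))            # shape (B, in)
VmappedDense = nn.vmap(
    nn.Dense,
    in_axes=0,                         # batch axis of the inputs
    out_axes=0,                        # batch axis of the outputs
    variable_axes={"params": None},    # no batch axis on params: shared
    split_rngs={"params": False},      # one init key for the whole batch
)
model = VmappedDense(features=F)
params = model.init(rng, x_batch)      # x_batch shape (B, in)
out = model.apply(params, x_batch)     # out shape (B, F)
```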
The crucial knob: variable_axes={"params": None}. The None
means “no batch axis on params” — i.e., shared. If you wrote 0,
you’d get a leading batch axis on every weight, B copies of the
layer.
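split_rngs={"params": False} means the param-init RNG is
shared too: every example sees the same init key. With True, you’d
get B different init keys, one per batch example, which only makes
sense when each example also owns its own params
(variable_axes={"params": 0}): again, ensemble territory.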
Why bother with nn.vmap over plain jax.vmap?
For a single Dense applied to a batch you don’t even need vmap:
nn.Dense already broadcasts over leading axes. The lifted
nn.vmap shines when:
- You’re vmapping a more complex Module that DOESN’T natively handle batched inputs (e.g., a custom recurrent cell that assumes a 1-D input).
- You want to vmap WHILE the rest of your tree handles params / RNGs / batch_stats correctly: nn.vmap integrates with all Flax variable collections.
- You want explicit control, e.g. “params are shared, dropout RNG is split per-example”, in a single declaration.
Shape semantics
- in_axes=0: the batch axis is leading on the inputs; x_batch has shape (B, in_dim).
- out_axes=0: the batch axis is leading on the output.
- With variable_axes={"params": None}: weights are unbatched (shapes (in_dim, F) and (F,)).
- Output: (B, F).
Common pitfalls
- variable_axes={"params": 0}: accidentally creates an ensemble of B layers, each operating on its own example. The output shape is the same (B, F), but params now have a leading B axis you have to manage.
- split_rngs={"params": True}: gives every example its own init key. Combined with variable_axes={"params": 0}, this is how you’d build an ensemble (and is exactly the wrong default for a normal Dense layer).
- Wrong in_axes: if your batch is on axis 1 instead of 0, vmap iterates over the wrong dimension. Be deliberate.
Problem
Implement nn_vmap_forward(seed, x_batch, features):
- Build VmappedDense = nn.vmap(nn.Dense, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False}).
- Instantiate model = VmappedDense(features=features).
- Init with PRNGKey(seed) and x_batch.
- Apply.
- Return the output flattened.
Inputs:
- seed: int.
- x_batch: 2-D, shape (B, in_dim).
- features: int F.
Output: 1-D, flattened to shape (B * F,).