NNX vmap over Batch
Why this matters
A linear layer’s matmul is happy with batches: x @ W works whether
x is (D,) or (B, D). So in this trivial case, nnx.vmap is
overkill — model(x_batch) already does the right thing. Why bother?
Because most modules don’t broadcast cleanly. Custom blocks with Python control flow, layers that index along a specific axis, attention with positional encodings: when the per-example logic is written assuming a single sample, you’d otherwise rewrite it to handle a batch dim. Vmap lifts the per-sample function to a per-batch function automatically, without rewriting.
nnx.vmap is the lifted version of jax.vmap for nnx Modules.
It does the right split/merge under the hood so your function can
take a Module as a regular argument.
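To make the "doesn't broadcast cleanly" point concrete, here is a minimal sketch using plain jax.vmap. The function per_sample is a hypothetical stand-in (not from the original) for per-example logic that indexes along a specific axis. Called directly on a batch it runs without error but computes something else; the vmapped version returns one value per sample.

```python
import jax
import jax.numpy as jnp

def per_sample(x):
    # Hypothetical per-example logic, written for ONE sample of shape (D,).
    # Here x[0] is meant to be a scalar: the first feature of the sample.
    return x[0] * jnp.sum(x ** 2)

x_batch = jnp.arange(6.0).reshape(2, 3)   # (B=2, D=3)

# Direct call on the batch: x[0] is now the first ROW, so the result
# has shape (3,) and does not correspond to any per-sample value.
naive = per_sample(x_batch)

# vmap lifts the per-sample function over axis 0: one scalar per sample.
batched = jax.vmap(per_sample)(x_batch)   # shape (2,)
```

The direct call does not even raise an error here, which is why lifting with vmap is usually safer than hoping a single-sample function happens to broadcast.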
The canonical incantation: shared params, batched inputs
from flax import nnx

model = nnx.Linear(D_in, D_out, rngs=nnx.Rngs(seed))

@nnx.vmap(in_axes=(None, 0))   # model shared, x batched on axis 0
def fwd(model, x):
    return model(x)            # written for ONE sample (D_in,)

out = fwd(model, x_batch)      # x_batch: (B, D_in) -> (B, D_out)
in_axes=(None, 0) says:
- Argument 0 (model) is not vmapped over: its parameters are shared across all batch elements. This is the standard “one model, many inputs” pattern.
- Argument 1 (x) is vmapped over axis 0. Each batch element sees its own (D_in,) slice.
Why None for the model
Passing None for the model’s in_axes is what makes vmap
share the params across the batch. If you used 0 instead,
vmap would expect a model whose params have a leading axis of size
B — i.e., one model per batch element (an ensemble; that’s the
next problem). For training a single model on a minibatch, you
want None.
Conceptually, nnx.vmap(in_axes=(None, 0)) does this:
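import jax

graphdef, state = nnx.split(model)

def lifted(state, x):
    m = nnx.merge(graphdef, state)   # rebuild the Module from the shared state
    return m(x)

# jax.vmap takes the function as its first argument, so apply it directly
out = jax.vmap(lifted, in_axes=(None, 0))(state, x_batch)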
State is shared (axis None), x is mapped (axis 0). Inside the
lifted body, every batch slice sees a freshly merged model, and all
of them point at the SAME parameter arrays, so nothing is copied.
Common pitfalls
- Forgetting in_axes=(None, 0): the default in_axes=0 would try to vmap over BOTH the model and x, which fails because the model isn’t batched.
- Passing a 1-D x instead of a 2-D batch: vmap would then map over the feature axis and hand the model scalars rather than (D_in,) samples. Add a leading dim with x[None, :] if you have a single sample.
- Reusing this for ensembles: for “one model per batch element”, see the next problem (nnx-vmap-over-ensemble). The in_axes is (0, None) there: state batched, x shared.
Problem
Implement vmap_over_batch(seed, x_batch, features):
- Build model = nnx.Linear(x_batch.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
- Define @nnx.vmap(in_axes=(None, 0)) def fwd(model, x): return model(x). The function is written as if x is a single sample of shape (D_in,).
- Call fwd(model, x_batch) to get (B, features).
- Return out.reshape(-1), a 1-D (B * features,) array.
Inputs:
- seed: float (cast to int).
- x_batch: 2-D (B, D_in).
- features: float (cast to int), D_out.
Output: 1-D (B * features,) — flattened batched output.