
NNX vmap over Batch

Why this matters

A linear layer’s matmul is happy with batches: x @ W works whether x is (D,) or (B, D). So in this trivial case, nnx.vmap is overkill — model(x_batch) already does the right thing. Why bother?

Because many modules don’t broadcast cleanly. Custom blocks with (shape-based) Python control flow, layers that index along a specific axis, attention with positional encodings: when the per-example logic is written assuming a single sample, you would otherwise have to rewrite it to handle a batch dimension. Vmap lifts the per-sample function to a per-batch function automatically.

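nnx.vmap is the lifted version of jax.vmap for NNX Modules. Under the hood it splits each Module argument into its static graph definition and its State (the parameter arrays), vmaps over the State according to in_axes, and merges the result back into a Module. As a result, your function can take a Module as an ordinary argument.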

The canonical incantation: shared params, batched inputs

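import jax.numpy as jnp
from flax import nnx

D_in, D_out, B, seed = 3, 4, 8, 0    # illustrative sizes
x_batch = jnp.ones((B, D_in))
model = nnx.Linear(D_in, D_out, rngs=nnx.Rngs(seed))

@nnx.vmap(in_axes=(None, 0))     # model shared, x batched on axis 0
def fwd(model, x):
    return model(x)              # written for ONE sample (D_in,)

out = fwd(model, x_batch)        # x_batch: (B, D_in) -> (B, D_out)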

in_axes=(None, 0) says:

  • Argument 0 (model) is not vmapped over — its parameters are shared across all batch elements. This is the standard “one model, many inputs” pattern.
  • Argument 1 (x) is vmapped over axis 0. Each batch element sees its own (D_in,) slice.

Why None for the model

Passing None for the model’s in_axes is what makes vmap share the params across the batch. If you used 0 instead, vmap would expect a model whose params have a leading axis of size B — i.e., one model per batch element (an ensemble; that’s the next problem). For training a single model on a minibatch, you want None.

Conceptually, nnx.vmap(in_axes=(None, 0)) does this:

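import jax
from flax import nnx

def fwd_conceptual(model, x_batch):
    graphdef, state = nnx.split(model)     # static structure + param arrays

    @jax.vmap(in_axes=(None, 0))           # state shared, x mapped over axis 0
    def lifted(state, x):
        m = nnx.merge(graphdef, state)     # rebuild a Module over the shared arrays
        return m(x)

    return lifted(state, x_batch)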

State is shared (axis None) and x is mapped (axis 0). jax.vmap traces the lifted body once, so nnx.merge builds a single model over the SAME parameter arrays. Every batch element reads those arrays, and no per-example copies of the weights are made.

Common pitfalls

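  • Forgetting in_axes=(None, 0): the default in_axes=0 tries to vmap over BOTH the model and x. The model's params have no leading batch axis of size B, so vmap ends up mapping over the kernel's D_in axis and the bias's D_out axis, and fails on the inconsistent sizes.
  • Passing a 1-D x instead of a 2-D batch: vmap over axis 0 of a (D_in,) vector maps over its D_in features, so the body receives scalars instead of (D_in,) samples, and the Linear matmul rejects them. If you have a single sample, add a leading batch dim with x[None, :].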
  • Reusing this for ensembles — for “one model per batch element”, see the next problem (nnx-vmap-over-ensemble). The in_axes is (0, None) there: state batched, x shared.

Problem

Implement vmap_over_batch(seed, x_batch, features):

  1. Build model = nnx.Linear(x_batch.shape[-1], int(features), rngs=nnx.Rngs(int(seed))).
  2. Define @nnx.vmap(in_axes=(None, 0)) def fwd(model, x): return model(x). The function is written as if x is a single sample of shape (D_in,).
  3. Call fwd(model, x_batch) to get (B, features).
  4. Return out.reshape(-1) — a 1-D (B * features,) array.

Inputs:

  • seed: float (cast to int).
  • x_batch: 2-D (B, D_in).
  • features: float (cast to int) — D_out.

Output: 1-D (B * features,) — flattened batched output.

Hints

flax nnx lifted-transforms vmap
