medium primitives

nn.vmap with shared params

Why this matters

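jax.vmap is THE way to batch a per-example function. But Flax Modules carry parameters, and jax.vmap knows nothing about Flax variable collections. Naively vmapping a Module's init broadcasts every param to a leading batch axis, and vmapping its apply with the default in_axes=0 tries to slice the params along axis 0 as well. Per-example params are almost never what you want.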

nn.vmap is the lifted version. It takes a Module class and returns a new Module class whose __call__ is vmapped internally. It also lets you specify, per variable collection, whether those variables are SHARED across the batch or HAVE A PER-EXAMPLE AXIS.

For a normal layer like nn.Dense, you want shared params: the same weight matrix applied to every example in the batch (otherwise you have B independent layers, one per example, which is an ensemble — see pos 86).

The canonical incantation

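import jax
import jax.numpy as jnp
import flax.linen as nn

B, IN, F = 4, 3, 5                    # example sizes: batch, input dim, features
rng = jax.random.PRNGKey(0)
x_batch = jnp.ones((B, IN))

VmappedDense = nn.vmap(
    nn.Dense,
    in_axes=0,                        # map over the leading (batch) axis of x
    out_axes=0,                       # stack outputs along a leading batch axis
    variable_axes={"params": None},   # no batch axis on params: one shared copy
    split_rngs={"params": False},     # one init key for the whole batch
)
model = VmappedDense(features=F)
params = model.init(rng, x_batch)   # x_batch shape (B, in)
out = model.apply(params, x_batch)  # out shape (B, F)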

The crucial knob: variable_axes={"params": None}. The None means “no batch axis on params”, i.e., shared. If you wrote 0, every weight would get a leading batch axis: B copies of the layer (identical if the init key is shared, independently initialized if it is split).

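split_rngs={"params": False} means the param-init RNG is shared too. With True, every batch example gets its own init key. That only makes sense together with variable_axes={"params": 0} (ensemble territory again). Paired with None it is inconsistent, since B different inits cannot be stored in one shared, unbatched copy.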

Why bother with nn.vmap over plain jax.vmap?

For a single Dense applied to a batch you don’t even need vmap: nn.Dense already broadcasts over leading axes. The lifted nn.vmap shines when:

  • You’re vmapping a more complex Module that DOESN’T natively handle batched inputs (e.g., a custom recurrent cell that assumes a 1-D input).
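  • You want to vmap inside another Module's __call__ (e.g., over a submodule). Wrapping a submodule call in a raw jax.vmap there mixes a JAX transform with Flax's bound-module state, which Flax rejects with an error. nn.vmap threads params, RNGs, batch_stats, and every other variable collection through the transform.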
  • You want explicit control: “params are shared, dropout RNG is split per-example,” achievable in a single declaration.

Shape semantics

  • in_axes=0 — the batch axis is leading on the inputs. x_batch shape (B, in_dim).
  • out_axes=0 — the batch axis is leading on the output.
  • With variable_axes={"params": None}: weights are unbatched (kernel (in_dim, F), bias (F,)).
  • Output: (B, F).

Common pitfalls

  • variable_axes={"params": 0} — accidentally creates an ensemble of B layers, each operating on its own example. Output shape is the same (B, F), but params now have a leading B axis you have to manage.
  • split_rngs={"params": True} — gives every example its own init key. Combined with variable_axes={"params": 0}, this is how you’d build an ensemble (and is exactly the wrong default for a normal Dense layer).
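  • Wrong in_axes: if your batch is on axis 1 instead of 0, pass in_axes=1. With in_axes=0, vmap iterates over the feature dimension instead and can still run without error, just with the wrong shapes. Be deliberate.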

Problem

Implement nn_vmap_forward(seed, x_batch, features):

  1. Build VmappedDense = nn.vmap(nn.Dense, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False}).
  2. Instantiate model = VmappedDense(features=features).
  3. Init with PRNGKey(seed) and x_batch.
  4. Apply.
  5. Return the output flattened.

Inputs:

  • seed: int.
  • x_batch: 2-D (B, in_dim).
  • features: int F.

Output: 1-D — flattened (B * F,).

Hints

flax nn-vmap lifted-transforms batching
