vmap with None broadcasting
Why this matters
In many ML computations a single parameter tensor (e.g. a weight
matrix W) is shared across an entire batch: every example uses the
same W. When you vectorize with jax.vmap, you must tell it which
arguments are batched and which are shared. in_axes=None for a
position means "do not slice this argument; broadcast it as-is to
every call of the inner function."
This is the pattern behind vmap(grad(loss), in_axes=(None, None, 0, 0))
for per-example gradients: the first two arguments of loss (the model
parameters) are shared, and only the last two (the data) are batched
along axis 0.
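A minimal sketch of that per-example-gradient pattern follows. The loss function, its argument order (W, b, x, y), and the sample data are assumptions made for illustration; the text above only specifies the in_axes tuple.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example squared-error loss; the name and the
# (W, b, x, y) signature are assumptions, not taken from the lesson.
def loss(W, b, x, y):
    pred = W @ x + b                      # x: (d_in,), pred: (d_out,)
    return jnp.sum((pred - y) ** 2)

W = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # (d_out, d_in), shared
b = jnp.zeros(2)                          # (d_out,), shared
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (N, d_in)
y_batch = jnp.zeros((3, 2))               # (N, d_out)

# jax.grad differentiates with respect to the first argument (W) by default.
# in_axes=(None, None, 0, 0): W and b are passed whole, x and y are sliced.
per_example_grad = jax.vmap(jax.grad(loss), in_axes=(None, None, 0, 0))
grads_W = per_example_grad(W, b, x_batch, y_batch)

print(grads_W.shape)  # (3, 2, 2): one (d_out, d_in) gradient per example
```

The leading axis of the result has size N, so grads_W[i] is the gradient of the loss on example i alone rather than a gradient summed over the batch.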
Worked mini-example
import jax, jax.numpy as jnp
W = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # shape (2, 2), shared
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2), batched
result = jax.vmap(lambda W, x: W @ x, in_axes=(None, 0))(W, x_batch)
# result = [[1.0, 2.0], [3.0, 4.0]], shape (2, 2)
# Same W applied to each x separately.
Common pitfalls
- Forgetting None and using (0, 0): vmap would slice W along axis 0, effectively applying row i of W to example i. Completely wrong for a shared matrix.
- Confusing N and d_in. W shape is (d_out, d_in), x shape is (d_in,) for a single example. The batch axis 0 of x_batch is N, not d_in.
- Forgetting that None applies to the whole argument. You cannot use None to broadcast along one axis while batching along another within the same tensor; use regular broadcasting for that.
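The first two pitfalls can be seen concretely in a short sketch. The helper name matvec and the sample arrays are assumptions chosen for illustration; W is deliberately not the identity so that the correct and incorrect results differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: the per-example computation.
def matvec(W, x):
    return W @ x

W = jnp.array([[1.0, 2.0], [3.0, 4.0]])        # (d_out=2, d_in=2), shared
x_batch = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # (N=2, d_in=2), batched

# Intended: the whole of W is passed to every call.
shared = jax.vmap(matvec, in_axes=(None, 0))(W, x_batch)
print(shared.shape)   # (2, 2), i.e. (N, d_out); rows are [1, 3] and [2, 4]

# Pitfall: W is sliced as well, so call i only sees row i of W.
sliced = jax.vmap(matvec, in_axes=(0, 0))(W, x_batch)
print(sliced.shape)   # (2,), i.e. (N,); values are [1, 4]

# N and d_in are independent sizes: here N=3, d_in=2, d_out=4.
W_tall = jnp.ones((4, 2))
x_three = jnp.ones((3, 2))
out = jax.vmap(matvec, in_axes=(None, 0))(W_tall, x_three)
print(out.shape)      # (3, 4), i.e. (N, d_out)
```

Note that the (0, 0) variant runs without error here only because N happens to equal d_out, which is part of what makes this mistake easy to miss.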
Problem
Implement affine_per_example(W, x_batch) that applies the shared
weight matrix W to each example in the batch via jax.vmap.
- W: 2-D jax array of shape (d_out, d_in), shared across all N.
- x_batch: 2-D jax array of shape (N, d_in).
- Returns: 2-D array of shape (N, d_out) where out[i] = W @ x_batch[i].
Use jax.vmap with in_axes=(None, 0).
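One possible way to meet this specification is sketched below. It is not the site's reference solution; the sample inputs and the cross-check against x_batch @ W.T are assumptions added for illustration.

```python
import jax
import jax.numpy as jnp

def affine_per_example(W, x_batch):
    """Apply shared W of shape (d_out, d_in) to each row of x_batch (N, d_in)."""
    # in_axes=(None, 0): W is passed whole, x_batch is sliced along axis 0.
    return jax.vmap(lambda W, x: W @ x, in_axes=(None, 0))(W, x_batch)

# Illustrative inputs with d_out=3, d_in=2, N=2 so the three sizes differ.
W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (d_out, d_in)
x_batch = jnp.array([[1.0, 1.0], [0.0, 2.0]])        # (N, d_in)

out = affine_per_example(W, x_batch)
print(out.shape)  # (2, 3), i.e. (N, d_out)

# The same result written as a single batched matrix product.
reference = x_batch @ W.T
print(bool(jnp.allclose(out, reference)))  # True
```

Choosing d_out, d_in, and N to be three different numbers in a test like this makes shape mix-ups surface immediately instead of passing by coincidence.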