
vmap with None broadcasting

Why this matters

In many ML computations a single parameter tensor (e.g. a weight matrix W) is shared across an entire batch: every example uses the same W. When you vectorize with jax.vmap, you must tell it which arguments are batched and which are shared. Writing None at an argument's position in in_axes means “do not slice this argument; pass it as-is to every call of the inner function.”

This is the pattern behind vmap(grad(loss), in_axes=(None, None, 0, 0)) for per-example gradients: the first two arguments (the model parameters) are shared, and only the last two (the data, e.g. inputs and targets) are batched.
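A minimal sketch of that per-example-gradient pattern, assuming a hypothetical linear model with a squared-error loss loss(W, b, x, y) that is not defined in the original. It passes argnums=(0, 1) to jax.grad so that both shared parameters are differentiated; plain grad(loss) would return only the gradient with respect to the first argument.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss for a linear model (illustration only).
# W: (d_out, d_in), b: (d_out,), x: (d_in,), y: (d_out,)
def loss(W, b, x, y):
    pred = W @ x + b
    return jnp.sum((pred - y) ** 2)

# Differentiate w.r.t. W and b, then vectorize over examples:
# W and b are shared (None); x and y are batched along axis 0.
per_example_grads = jax.vmap(
    jax.grad(loss, argnums=(0, 1)),
    in_axes=(None, None, 0, 0),
)

key_W, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_W, (3, 2))   # d_out = 3, d_in = 2
b = jnp.zeros(3)
xs = jax.random.normal(key_x, (5, 2))  # N = 5 inputs
ys = jax.random.normal(key_y, (5, 3))  # N = 5 targets

gW, gb = per_example_grads(W, b, xs, ys)
print(gW.shape, gb.shape)  # (5, 3, 2) (5, 3)
```

The leading axis of size 5 indexes examples, so gW[i] is the gradient of example i's loss alone rather than of the batch sum.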

Worked mini-example

import jax, jax.numpy as jnp

W = jnp.array([[1.0, 0.0], [0.0, 1.0]])   # shape (2, 2): shared (the identity matrix)
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2): batched, N = 2 examples with d_in = 2

result = jax.vmap(lambda W, x: W @ x, in_axes=(None, 0))(W, x_batch)
# result → [[1.0, 2.0], [3.0, 4.0]]  shape (2, 2)
# The same W is applied to each x separately; since W is the identity, each row comes back unchanged.
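Because every dimension in the mini-example is 2, it cannot show which axis is which. The variant below uses distinct sizes (d_out = 3, d_in = 2, N = 4, chosen only for illustration) so the (N, d_out) output shape is visible, and it cross-checks vmap against an explicit Python loop and against the equivalent matrix product x_batch @ W.T.

```python
import jax
import jax.numpy as jnp

# Distinct sizes so N, d_in and d_out cannot be confused.
W = jnp.arange(6.0).reshape(3, 2)        # (d_out=3, d_in=2), shared
x_batch = jnp.arange(8.0).reshape(4, 2)  # (N=4, d_in=2), batched

out = jax.vmap(lambda W, x: W @ x, in_axes=(None, 0))(W, x_batch)
print(out.shape)  # (4, 3), i.e. (N, d_out)
print(out[0])     # W @ [0, 1] = [1., 3., 5.]

# Same result from an explicit loop over examples and from plain matrix algebra.
loop_out = jnp.stack([W @ x for x in x_batch])
assert jnp.allclose(out, loop_out)
assert jnp.allclose(out, x_batch @ W.T)
```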

Common pitfalls

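  • Forgetting None and using (0, 0). vmap then slices W along axis 0 too, so call i receives only the row W[i] (shape (d_in,)) and computes the scalar W[i] @ x_batch[i]. The output has shape (N,) instead of (N, d_out), and vmap raises a ValueError unless d_out == N. Either way, this is wrong for a shared matrix.
  • Confusing N and d_in. W has shape (d_out, d_in), and a single example x has shape (d_in,). Axis 0 of x_batch is the batch axis of size N, not d_in.
  • Forgetting that None applies to the whole argument. in_axes assigns one axis (or None) to each array, so you cannot use None to broadcast along one axis while batching along another within the same tensor. Use ordinary broadcasting inside the function for that.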
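The first pitfall can be reproduced directly. In this sketch the sizes are illustrative assumptions: N is deliberately set equal to d_out so that in_axes=(0, 0) runs silently and returns the wrong shape, and a second call with N != d_out shows the error case.

```python
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(3, 2)        # (d_out=3, d_in=2)
x_batch = jnp.arange(6.0).reshape(3, 2)  # N=3, chosen equal to d_out on purpose

correct = jax.vmap(lambda W, x: W @ x, in_axes=(None, 0))(W, x_batch)
wrong = jax.vmap(lambda W, x: W @ x, in_axes=(0, 0))(W, x_batch)

print(correct.shape)  # (3, 3): every example sees the full W
print(wrong.shape)    # (3,): call i computes the dot product W[i] @ x_batch[i]

# With N != d_out, the mapped axes have different sizes and vmap refuses.
raised = False
try:
    jax.vmap(lambda W, x: W @ x, in_axes=(0, 0))(W, jnp.ones((4, 2)))
except ValueError as err:
    raised = True
    print("ValueError:", err)
```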

Problem

Implement affine_per_example(W, x_batch) that applies the shared weight matrix W to each example in the batch via jax.vmap.

  • W: 2-D jax array of shape (d_out, d_in), shared across all N examples.
  • x_batch: 2-D jax array of shape (N, d_in).
  • Returns: 2-D array of shape (N, d_out) where out[i] = W @ x_batch[i].

Use jax.vmap with in_axes=(None, 0).
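One possible sketch that meets the specification above; it is not the site's reference solution. The inner helper name apply_one is hypothetical, and the shapes in the usage lines are illustrative.

```python
import jax
import jax.numpy as jnp

def affine_per_example(W, x_batch):
    """Apply the shared matrix W (d_out, d_in) to each row of x_batch (N, d_in).

    Returns an array of shape (N, d_out) with out[i] = W @ x_batch[i].
    """
    def apply_one(W, x):
        # x is a single example of shape (d_in,); W arrives unsliced.
        return W @ x

    # None: pass W whole to every call. 0: slice x_batch along its first axis.
    return jax.vmap(apply_one, in_axes=(None, 0))(W, x_batch)

# Usage with d_out = 3, d_in = 2, N = 4.
W = jnp.arange(6.0).reshape(3, 2)
x_batch = jnp.arange(8.0).reshape(4, 2)
out = affine_per_example(W, x_batch)
print(out.shape)  # (4, 3)
```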

Hints

jax vmap in-axes broadcast
