vmap with out_axes
Why this matters
out_axes is the output counterpart to in_axes. Just as in_axes
controls which axis of each input carries the batch dimension,
out_axes controls where the mapped (batch) axis appears in each
output. The default is 0 (batch axis leading), but you can place it at
any valid axis of the output, including negative indices counted from the end.
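Here is a minimal sketch of how out_axes changes only where the batch axis lands, not the computed values. The per-example function scale and the input xs are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def scale(x):
    # Hypothetical per-example function: length-4 vector in, length-4 vector out.
    return 2.0 * x

xs = jnp.arange(12.0).reshape(3, 4)  # batch of 3 vectors stacked along axis 0

# Default out_axes=0: the batch axis leads the output.
out0 = jax.vmap(scale, in_axes=0, out_axes=0)(xs)
# out_axes=1: same results, batch axis moved to position 1.
out1 = jax.vmap(scale, in_axes=0, out_axes=1)(xs)
# out_axes=-1: negative indices count from the end of the output shape.
out_last = jax.vmap(scale, in_axes=0, out_axes=-1)(xs)

print(out0.shape)                      # (3, 4)
print(out1.shape)                      # (4, 3)
print(out_last.shape)                  # (4, 3)
print(jnp.allclose(out1, out0.T))      # True: only the layout differs
```

For a 2-D output, out_axes=1 and out_axes=-1 name the same position, which is why both give shape (4, 3).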
A common use-case: your data is laid out with one example per column
(batch along axis 1), so you need in_axes=1 to map over columns.
With the default out_axes=0, the per-column results are stacked along a
new leading axis. If you want the output to keep the column layout, set
out_axes=1 instead of transposing the result afterwards.
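The sketch below shows the column-layout case. It assumes a hypothetical normalize function that returns a vector per column, so the placement of the batch axis is visible in the output shape.

```python
import jax
import jax.numpy as jnp

def normalize(col):
    # Hypothetical per-column function: scale the column so its entries sum to 1.
    return col / jnp.sum(col)

# Column layout: shape (d, N) = (3, 2); each column is one example.
data = jnp.array([[1.0, 2.0],
                  [1.0, 3.0],
                  [2.0, 5.0]])

# in_axes=1 maps over columns; the default out_axes=0 stacks results as rows.
rows = jax.vmap(normalize, in_axes=1)(data)              # shape (2, 3)
# out_axes=1 puts the batch axis back in the column position.
cols = jax.vmap(normalize, in_axes=1, out_axes=1)(data)  # shape (3, 2)

# The two results differ only by a transpose.
print(jnp.allclose(cols, rows.T))  # True
```

Choosing out_axes=1 here means the output has the same (d, N) layout as the input, so downstream code that expects columns needs no extra transpose.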
Worked mini-example
```python
import jax
import jax.numpy as jnp

# matrix shape: (d, N); the batch dimension is axis 1 (columns)
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)

# Map over columns (axis 1), return results stacked at axis 0
col_norms = jax.vmap(jnp.linalg.norm, in_axes=1, out_axes=0)(matrix)
# col_norms → [norm([1, 3]), norm([2, 4])] = [√10, √20], shape (2,)
```
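One way to sanity-check a vmap over columns is to compare it with the equivalent axis reduction. This sketch assumes only that jnp.linalg.norm with axis=0 computes one 2-norm per column, which is its documented behavior for vector norms.

```python
import jax
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (d, N) = (2, 2)

col_norms = jax.vmap(jnp.linalg.norm, in_axes=1, out_axes=0)(matrix)

# Reference: reduce over axis 0 directly, giving one norm per column.
reference = jnp.linalg.norm(matrix, axis=0)

print(col_norms)                           # ≈ [3.1623, 4.4721], i.e. [√10, √20]
print(jnp.allclose(col_norms, reference))  # True
```

For a single reduction like this, the axis argument is simpler; vmap pays off when the per-column function is more complex than a built-in reduction.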
Common pitfalls
- in_axes and out_axes are independent. Setting in_axes=1 does NOT automatically move the output axis. If you omit out_axes, the default 0 still applies, which is usually what you want.
- Wrong axis = wrong result. If you use in_axes=0 when the batch dimension is actually axis 1, JAX silently batches over rows instead of columns. For a non-square (d, N) input you get shape (d,) instead of (N,); for a square input the shape looks right but the values are wrong.
- Scalar outputs leave little room for out_axes. If the inner function returns a scalar, the batched output is 1-D, so the only valid integer out_axes values are 0 and -1; anything larger is out of range and raises an error.
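The pitfalls above can be reproduced in a few lines. This is a sketch assuming a hypothetical col_sum helper; the exact exception type and message for an out-of-range out_axes may differ between JAX versions, so the demo catches broadly.

```python
import jax
import jax.numpy as jnp

def col_sum(v):
    # Hypothetical helper: reduces a vector to a scalar.
    return jnp.sum(v)

data = jnp.ones((3, 5))  # shape (d, N) = (3, 5): 5 columns of length 3

# Correct: map over columns, one scalar per column.
good = jax.vmap(col_sum, in_axes=1)(data)
print(good.shape)  # (5,)

# Pitfall: the wrong in_axes silently maps over rows instead; no error is raised.
wrong = jax.vmap(col_sum, in_axes=0)(data)
print(wrong.shape)  # (3,)

# A scalar-returning function gives a 1-D batched output, so -1 is valid...
last = jax.vmap(col_sum, in_axes=1, out_axes=-1)(data)
print(last.shape)  # (5,)

# ...but out_axes=1 is out of range for a 1-D output.
try:
    jax.vmap(col_sum, in_axes=1, out_axes=1)(data)
    raised = False
except Exception as err:  # exact exception type/message may vary by JAX version
    raised = True
    print(type(err).__name__)
print(raised)  # True
```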
Problem
Implement vmap_with_out_axis(matrix) that computes the sum of
squares per column of a matrix whose batch dimension is axis 1.
- matrix: 2-D jax array of shape (d, N).
- Returns: 1-D array of shape (N,) where out[j] = sum(matrix[:, j] ** 2).
Use jax.vmap with explicit in_axes=1 and out_axes=0.
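One possible implementation is sketched below. It is not the site's reference solution, and the inner helper sum_sq is a hypothetical name. It follows the specification above and cross-checks against a plain axis reduction.

```python
import jax
import jax.numpy as jnp

def vmap_with_out_axis(matrix):
    """Sum of squares of each column of a (d, N) matrix; returns shape (N,)."""

    def sum_sq(col):
        # Hypothetical helper: col has shape (d,); reduce it to a scalar.
        return jnp.sum(col ** 2)

    # in_axes=1 iterates over columns; out_axes=0 stacks the N scalars into shape (N,).
    return jax.vmap(sum_sq, in_axes=1, out_axes=0)(matrix)

m = jnp.array([[1.0, 2.0, 0.0],
               [3.0, 4.0, 1.0]])  # shape (2, 3)
out = vmap_with_out_axis(m)       # values 10, 20, 1 (float32)
print(jnp.allclose(out, jnp.sum(m ** 2, axis=0)))  # True
```

Because sum_sq returns a scalar, out_axes=0 is the natural (and, apart from -1, the only valid) placement for the batch axis here.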