
vmap with out_axes

Why this matters

out_axes is the output counterpart to in_axes. Just as in_axes controls which axis of each input carries the batch dimension, out_axes controls where the mapped (batch) axis appears in each output. The default is 0 (batch axis leading), but you can place it at any valid axis of the output; negative values count from the end.
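A minimal sketch of how out_axes moves the batch axis, assuming a made-up per-example function `scale` that returns a length-3 vector for each scalar input (the names `scale`, `lead`, `trail` and `last` are illustrative only):

```python
import jax
import jax.numpy as jnp

def scale(x):
    # x is one scalar from the batch; the result has shape (3,)
    return x * jnp.arange(3.0)

xs = jnp.arange(4.0)  # a batch of 4 scalars, shape (4,)

lead = jax.vmap(scale, out_axes=0)(xs)    # batch axis first
trail = jax.vmap(scale, out_axes=1)(xs)   # batch axis second
last = jax.vmap(scale, out_axes=-1)(xs)   # negative index: batch axis last

print(lead.shape)   # (4, 3)
print(trail.shape)  # (3, 4)
print(last.shape)   # (3, 4)
```

The three calls compute the same numbers; only the layout differs, so `trail` is the transpose of `lead`.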

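A common use case: your data is stored column-major (batch along axis 1), so you pass in_axes=1 to map over columns. Because out_axes defaults to 0, the per-column results already come back stacked along axis 0, and no transpose is needed. If the function returns a vector per column and you want the output to keep the column-major layout, pass out_axes=1 instead.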
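Both layouts can be sketched side by side. This assumes a hypothetical helper `normalize` that rescales one column so it sums to 1; only `jax.vmap` and its in_axes/out_axes arguments come from the text above.

```python
import jax
import jax.numpy as jnp

# Column-major storage: shape (d, N) with d=2 features and N=3 examples.
data = jnp.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

def normalize(col):
    # col is one column, shape (d,); rescale it so its entries sum to 1
    return col / col.sum()

# Default out_axes=0: results come back example-major, shape (N, d).
rows = jax.vmap(normalize, in_axes=1)(data)

# out_axes=1: results keep the column-major layout of the input, shape (d, N).
cols = jax.vmap(normalize, in_axes=1, out_axes=1)(data)

print(rows.shape)  # (3, 2)
print(cols.shape)  # (2, 3)
```

Choosing out_axes=1 here saves a later transpose when downstream code expects the same (d, N) layout as the input.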

Worked mini-example

import jax
import jax.numpy as jnp

# matrix shape: (d, N) — batch dimension is axis 1 (columns)
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)

# Map over columns (axis 1), return results stacked at axis 0
col_norms = jax.vmap(jnp.linalg.norm, in_axes=1, out_axes=0)(matrix)
# col_norms → [norm([1,3]), norm([2,4])] = [√10, √20]  shape: (2,)
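As a sanity check, the vmapped result can be compared with a plain reduction. This sketch relies only on `jnp.linalg.norm(matrix, axis=0)` computing one vector norm per column of a 2-D input.

```python
import jax
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (d, N) = (2, 2)

# vmap version: slice out columns (in_axes=1), stack the scalar norms at axis 0.
col_norms = jax.vmap(jnp.linalg.norm, in_axes=1, out_axes=0)(matrix)

# Reduction version: axis=0 collapses the rows, leaving one norm per column.
expected = jnp.linalg.norm(matrix, axis=0)

# Both should equal [sqrt(10), sqrt(20)].
assert jnp.allclose(col_norms, expected)
assert jnp.allclose(col_norms, jnp.sqrt(jnp.array([10.0, 20.0])))
```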

Common pitfalls

  • in_axes and out_axes are independent. Setting in_axes=1 does NOT automatically move the output axis. If you omit out_axes, the default 0 still applies — which is usually what you want.
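  • Wrong axis, wrong answer. If you use in_axes=0 when the batch dimension is actually axis 1, JAX does not complain; it simply maps over rows instead of columns. For a (d, N) matrix you get d results instead of N, and for a square matrix the shape even looks right while the values are wrong.
  • Scalar outputs leave out_axes little to choose. If the inner function returns a scalar, the batched result is 1-D, so the batch axis can only sit at position 0 (equivalently -1), and the default is the setting to use.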
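The pitfalls above can be reproduced with `jnp.sum` as the per-example function. The matrices `m` and `sq` are made up for illustration; the non-square one makes an axis mix-up visible in the shape, while the square one hides it.

```python
import jax
import jax.numpy as jnp

# Non-square on purpose: d=2 rows, N=3 columns.
m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

right = jax.vmap(jnp.sum, in_axes=1)(m)  # one sum per column: [5., 7., 9.]
wrong = jax.vmap(jnp.sum, in_axes=0)(m)  # one sum per row:    [6., 15.]

# Square matrix: both calls return shape (2,), so only the values reveal the mistake.
sq = jnp.array([[1.0, 2.0], [3.0, 4.0]])
col_sums = jax.vmap(jnp.sum, in_axes=1)(sq)  # [4., 6.]
row_sums = jax.vmap(jnp.sum, in_axes=0)(sq)  # [3., 7.]

# Scalar per-example outputs give a 1-D result, so out_axes=0 and -1 coincide.
a = jax.vmap(jnp.sum, in_axes=1, out_axes=0)(m)
b = jax.vmap(jnp.sum, in_axes=1, out_axes=-1)(m)
```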

Problem

Implement vmap_with_out_axis(matrix) that computes the sum of squares per column of a matrix whose batch dimension is axis 1.

  • matrix: 2-D jax array of shape (d, N).
  • Returns: 1-D array of shape (N,) where out[j] = sum(matrix[:, j] ** 2).

Use jax.vmap with explicit in_axes=1 and out_axes=0.
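One possible sketch that meets the specification, not the site's reference solution; the inner helper `sum_sq` is a name chosen here for illustration.

```python
import jax
import jax.numpy as jnp

def vmap_with_out_axis(matrix):
    """Sum of squares per column of a (d, N) matrix whose batch axis is 1."""
    def sum_sq(col):
        # col is one column, shape (d,); reduce it to a scalar
        return jnp.sum(col ** 2)

    # in_axes=1 slices out columns; out_axes=0 stacks the N scalars into shape (N,)
    return jax.vmap(sum_sq, in_axes=1, out_axes=0)(matrix)

out = vmap_with_out_axis(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
# out -> [10., 20.]  (1**2 + 3**2 and 2**2 + 4**2)
```

Testing with a non-square input as well is worthwhile, since a square matrix cannot catch an axis mix-up through the output shape.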

Hints

Look up the out_axes argument of jax.vmap.
