
vmap with in_axes

Why this matters

jax.vmap vectorizes a function: you write the logic for a single example, and vmap lifts it to operate on an entire batch at once, turning it into batched array operations instead of a Python loop. The in_axes argument is the key knob: it tells vmap which axis of each input array carries the batch dimension.

Understanding in_axes is the first step toward more advanced patterns: nested vmap (all-pairs), broadcasting shared parameters (None), and composing vmap with jax.grad for per-example gradients.
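The advanced patterns named above can be sketched briefly. The helper functions sq_dist and loss, and all of the data, are hypothetical and chosen only for illustration. The first part uses nested vmap with None to compute every pairing of two point sets. The second wraps jax.grad (which differentiates with respect to the first argument by default) in vmap to get one gradient per example while the weights are shared.

```python
import jax
import jax.numpy as jnp

# --- Nested vmap: all pairs --------------------------------------------
def sq_dist(x, y):
    # Hypothetical helper: squared Euclidean distance between two vectors (d,).
    return jnp.sum((x - y) ** 2)

xs = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # 3 points, d = 2
ys = jnp.array([[1.0, 1.0], [3.0, 0.0]])              # 2 points, d = 2

# Inner vmap: map over ys while holding one x fixed (None = not batched).
# Outer vmap: map over xs while holding all of ys fixed.
pairwise = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))(xs, ys)
# pairwise[i, j] == sq_dist(xs[i], ys[j]); shape (3, 2)

# --- vmap + grad: per-example gradients ---------------------------------
def loss(w, x, y):
    # Hypothetical loss: squared error of a linear model on ONE example.
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, 2.0])                                # shared weights
feats = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 examples
targets = jnp.array([0.0, 0.0, 0.0])

# jax.grad(loss) returns d(loss)/dw for a single example; vmap batches
# feats and targets over axis 0 and broadcasts w via None.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, feats, targets)
# one gradient row per example; shape (3, 2)
```

Nesting two vmaps with complementary None entries is what produces the all-pairs grid; a single vmap with in_axes=(0, 0) would instead pair xs[i] with ys[i] and require both batches to have the same size.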

Worked mini-example

import jax, jax.numpy as jnp

def square(x):
    return x ** 2

# square is written for a single scalar example.
# jax.vmap(square) maps it over axis 0 of the input array.
result = jax.vmap(square)(jnp.array([1.0, 2.0, 3.0]))
# → [1.0, 4.0, 9.0]

For two batched arguments, specify in_axes=(0, 0) so vmap knows to slice along axis 0 of both inputs:

def add(a, b):
    return a + b

result = jax.vmap(add, in_axes=(0, 0))(
    jnp.array([1.0, 2.0]),
    jnp.array([10.0, 20.0]),
)
# → [11.0, 22.0]
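Slicing along axis 0 matters most when each example is itself a vector. The sketch below (row_dot is a hypothetical helper) batches a per-example dot product over the rows of two matrices and checks it against an explicit Python loop. Calling the same function on the unbatched matrices gives a matrix product instead, which is exactly the confusion vmap avoids.

```python
import jax
import jax.numpy as jnp

def row_dot(a, b):
    # Hypothetical helper written for ONE example: a and b have shape (d,).
    return jnp.dot(a, b)

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])      # 2 examples, d = 2
b = jnp.array([[10.0, 20.0], [30.0, 40.0]])

# vmap slices row i from both inputs and calls row_dot on each pair.
batched = jax.vmap(row_dot, in_axes=(0, 0))(a, b)   # == [50., 250.]

# Reference: loop over axis 0 by hand, then stack the per-row results.
looped = jnp.stack([row_dot(a[i], b[i]) for i in range(a.shape[0])])

# Without vmap, jnp.dot on two 2-D arrays is a matrix product of shape (2, 2).
unbatched = row_dot(a, b)
```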

Common pitfalls

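  • in_axes length must match argument count. A single int (or None) is applied to every argument, but a tuple needs exactly one entry per positional argument; if your function takes three arguments and you pass a two-entry tuple, JAX raises an error.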
  • Default is 0, not None. Omitting in_axes means every argument is batched over axis 0. Use None explicitly for arguments that are shared (not batched).
  • Wrong axis for column-batched data. If your batch dimension is axis 1 (columns), pass in_axes=1, not 0.
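All three pitfalls can be reproduced in a few lines. This is a sketch under stated assumptions: the function affine and its data are hypothetical, and the except clauses rely on JAX raising ValueError for these mistakes, which is its behavior for malformed in_axes and for mapping a 0-d argument.

```python
import jax
import jax.numpy as jnp

def affine(x, w, b):
    # Hypothetical per-example logic: x (d,), w (d,), b is a scalar.
    return jnp.dot(w, x) + b

xs = jnp.ones((4, 3))            # 4 examples stacked along axis 0, d = 3
w = jnp.array([1.0, 2.0, 3.0])   # shared weights (not batched)
b = jnp.array(0.5)               # shared bias (not batched, 0-d)

# Pitfall 1: a tuple in_axes needs one entry per positional argument.
try:
    jax.vmap(affine, in_axes=(0, None))(xs, w, b)   # 2 entries, 3 arguments
    tuple_error = False
except ValueError:
    tuple_error = True

# Pitfall 2: omitting in_axes batches EVERY argument over axis 0, so JAX
# also tries to slice w and the 0-d b, and raises an error.
try:
    jax.vmap(affine)(xs, w, b)
    default_error = False
except ValueError:
    default_error = True

# Fix: batch xs over axis 0 and mark the shared arguments with None.
out_rows = jax.vmap(affine, in_axes=(0, None, None))(xs, w, b)  # shape (4,)

# Pitfall 3: if examples are stored as columns, shape (d, N), map axis 1.
out_cols = jax.vmap(affine, in_axes=(1, None, None))(xs.T, w, b)
```

Both out_rows and out_cols contain 6.5 for every example (1 + 2 + 3 plus the bias), which confirms that in_axes=1 recovers the same per-example slices from the transposed layout.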

Problem

Implement batched_outer(a_batch, b_batch) that computes the outer product of each pair of corresponding rows.

  • a_batch, b_batch: 2-D jax arrays of shape (N, d).
  • Returns: 3-D array of shape (N, d, d) where out[i] = jnp.outer(a_batch[i], b_batch[i]).

Use jax.vmap with an explicit in_axes; do not loop.
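One possible sketch, not an official reference solution: jnp.outer already handles a single pair of 1-D rows, so vmap only needs to slice axis 0 of both inputs. The test data below is illustrative, and jnp.einsum serves as an independent, loop-free check of the result.

```python
import jax
import jax.numpy as jnp

def batched_outer(a_batch, b_batch):
    # jnp.outer handles one pair of rows: (d,), (d,) -> (d, d).
    # in_axes=(0, 0) slices row i from both inputs; vmap stacks the N results.
    return jax.vmap(jnp.outer, in_axes=(0, 0))(a_batch, b_batch)

a = jnp.arange(6.0).reshape(3, 2)   # N = 3, d = 2 (illustrative data)
b = jnp.ones((3, 2))
out = batched_outer(a, b)           # shape (3, 2, 2)

# Independent check: out[n, i, j] = a[n, i] * b[n, j].
expected = jnp.einsum("ni,nj->nij", a, b)
```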

Hints

jax vmap in-axes
