vmap with in_axes
Why this matters
jax.vmap vectorizes a function: you write the logic for a single
example, and vmap lifts it to run over an entire batch in parallel. The
in_axes argument is the key knob: it tells vmap which axis of each input
array carries the batch dimension.
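A minimal sketch of what in_axes controls, assuming a hypothetical per-example function dot that takes two vectors. Here the first input is batched along axis 0 and the second along axis 1, and the vmapped result matches an explicit loop that slices each input along its own batch axis. vmap does not actually run a Python loop; it rewrites the computation to operate on the whole batch at once.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Hypothetical per-example logic: a and b are 1-D vectors of length d.
    return jnp.dot(a, b)

a = jnp.arange(6.0).reshape(3, 2)  # batch on axis 0: shape (N=3, d=2)
b = jnp.arange(6.0).reshape(2, 3)  # batch on axis 1: shape (d=2, N=3)

# in_axes=(0, 1): slice a along axis 0 and b along axis 1.
batched = jax.vmap(dot, in_axes=(0, 1))(a, b)
# → [3.0, 14.0, 33.0]

# Same values computed with an explicit Python loop, for comparison only.
looped = jnp.stack([dot(a[i], b[:, i]) for i in range(3)])
```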
Understanding in_axes is the first step toward more advanced patterns:
nested vmap (all-pairs), broadcasting shared parameters (None), and
composing vmap with jax.grad for per-example gradients.
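The patterns named above can be sketched as follows. The helpers sq_dist and loss and the toy data are illustrative assumptions, not part of the problem. For all-pairs, an inner vmap with in_axes=(None, 0) maps over the rows of y while holding one row of x fixed, and an outer vmap with in_axes=(0, None) maps over the rows of x. For per-example gradients, in_axes=(None, 0, 0) shares the weights w across the batch while xs and ys are sliced along axis 0.

```python
import jax
import jax.numpy as jnp

# Nested vmap (all-pairs): squared distance between every x[i] and y[j].
def sq_dist(u, v):
    return jnp.sum((u - v) ** 2)

# Inner vmap maps over y (u fixed); outer vmap maps over x (y passed whole).
all_pairs = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.array([[0.0, 0.0], [1.0, 1.0]])              # shape (2, 2)
y = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 2.0]])  # shape (3, 2)
dists = all_pairs(x, y)  # shape (2, 3): dists[i, j] = sq_dist(x[i], y[j])

# Shared parameters (None) + per-example gradients.
def loss(w, x_i, y_i):
    # Squared error of a linear model on a single example.
    return (jnp.dot(w, x_i) - y_i) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ys = jnp.array([0.0, 1.0, 2.0])

# jax.grad differentiates w.r.t. the first argument (w) by default.
# w is shared (None); xs and ys are batched along axis 0.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
# shape (3, 2): one gradient of the loss w.r.t. w per example
```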
Worked mini-example
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

# square is written for a single scalar example.
# vmap maps it over axis 0 of the input array (in_axes defaults to 0).
result = jax.vmap(square)(jnp.array([1.0, 2.0, 3.0]))
# → [1.0, 4.0, 9.0]
For two batched arguments, specify in_axes=(0, 0) so vmap knows to
slice along axis 0 of both inputs:
def add(a, b):
    return a + b

result = jax.vmap(add, in_axes=(0, 0))(
    jnp.array([1.0, 2.0]),
    jnp.array([10.0, 20.0]),
)
# → [11.0, 22.0]
Common pitfalls
- in_axes length must match argument count. When in_axes is a tuple, it needs exactly one entry per positional argument, or JAX raises an error. A single int (or None) is applied to every argument, so passing one int to a two-argument function is fine.
- Default is 0, not None. Omitting in_axes means every argument is batched over axis 0. Use None explicitly for arguments that are shared (not batched).
- Wrong axis for column-batched data. If your batch dimension is axis 1 (columns), pass in_axes=1, not 0.
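A short sketch of the three pitfalls, using a hypothetical two-argument function scale and a hypothetical helper raises. It assumes recent JAX behavior, where both a mismatched in_axes tuple and an attempt to map over axis 0 of a rank-0 argument raise ValueError; the exact messages differ between versions.

```python
import jax
import jax.numpy as jnp

def scale(w, x):
    return w * x

def raises(fn):
    # Hypothetical helper: True if calling fn raises ValueError.
    try:
        fn()
        return False
    except ValueError:
        return True

w = jnp.array(2.0)               # shared scalar parameter, shape ()
xs = jnp.array([1.0, 2.0, 3.0])  # batch of 3 examples

# Pitfall 1: a tuple in_axes needs one entry per positional argument.
mismatch_fails = raises(lambda: jax.vmap(scale, in_axes=(0,))(w, xs))
# A single int is applied to every argument, so this is valid.
both_batched = jax.vmap(scale, in_axes=0)(jnp.array([1.0, 2.0, 3.0]), xs)

# Pitfall 2: the default in_axes=0 also tries to batch the scalar w.
default_fails = raises(lambda: jax.vmap(scale)(w, xs))
shared = jax.vmap(scale, in_axes=(None, 0))(w, xs)  # w broadcast to every example

# Pitfall 3: one example per column, so the batch axis is 1.
def total(v):
    return jnp.sum(v)

cols = jnp.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])             # shape (d=2, N=3)
col_sums = jax.vmap(total, in_axes=1)(cols)     # one sum per column: 3 results
row_sums = jax.vmap(total, in_axes=0)(cols)     # wrong axis: 2 results, one per row
```

The mistake in pitfall 3 is silent: in_axes=0 runs without error but returns one value per row instead of one per example.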
Problem
Implement batched_outer(a_batch, b_batch) that computes the outer product
of each pair of corresponding rows.
- a_batch, b_batch: 2-D JAX arrays of shape (N, d).
- Returns: a 3-D array of shape (N, d, d) where out[i] = jnp.outer(a_batch[i], b_batch[i]).
Use jax.vmap with an explicit in_axes; do not loop.