Nested vmap (vmap-of-vmap)
Why this matters
Many machine learning algorithms require computing a function over
all pairs (x_i, y_j) from two sets: pairwise distances, attention
scores, kernel matrices, and cross-similarity tables. Nested jax.vmap
is the canonical JAX pattern for these all-pairs computations: write
the function for a single pair, then vectorize twice to cover all
(N × M) combinations.
Nested vmap is also a stepping stone toward understanding how JAX transforms compose: each vmap introduces one batch dimension, and they stack cleanly.
Worked mini-example
For N=2, M=2, the unrolled computation is:
out[0, 0] = dist(X[0], Y[0])
out[0, 1] = dist(X[0], Y[1])
out[1, 0] = dist(X[1], Y[0])
out[1, 1] = dist(X[1], Y[1])
The nested vmap structure:
import jax
import jax.numpy as jnp
def dist(x, y):
    return jnp.linalg.norm(x - y)
# X has shape (N, d), Y has shape (M, d)
# Inner vmap: fix x, map over all y -> shape (M,)
inner = jax.vmap(dist, in_axes=(None, 0))
# Outer vmap: map over all x, fix y -> shape (N, M)
result = jax.vmap(inner, in_axes=(0, None))(X, Y)
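To see the structure run end to end, here is a minimal sketch with concrete inputs. The values of X and Y are made up for illustration (N=2, M=3, d=2); M differs from N so the two batch axes are easy to tell apart in the output shape.

```python
import jax
import jax.numpy as jnp

def dist(x, y):
    # L2 distance between two single vectors of shape (d,)
    return jnp.linalg.norm(x - y)

# Illustrative inputs (assumed values, not from the exercise)
X = jnp.array([[0.0, 0.0], [3.0, 4.0]])              # (N=2, d=2)
Y = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])  # (M=3, d=2)

inner = jax.vmap(dist, in_axes=(None, 0))      # (d,), (M, d) -> (M,)
pairwise = jax.vmap(inner, in_axes=(0, None))  # (N, d), (M, d) -> (N, M)

result = pairwise(X, Y)
print(result.shape)  # (2, 3)
```

Each row of result corresponds to one X[i] and each column to one Y[j], matching the unrolled out[i, j] = dist(X[i], Y[j]) pattern.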
Common pitfalls
- Getting in_axes backwards. The inner vmap fixes x (None) and maps over y (axis 0). The outer vmap maps over x (axis 0) and fixes y (None). Swapping them gives a transposed result.
- Confusing nested vmap with two separate vmaps. The outer vmap wraps the inner vmap; it is one expression, not two sequential calls.
- Forgetting that in_axes in the outer vmap controls the argument to the inner vmap, not to the innermost function. The inner vmap already handles Y; the outer just needs to handle X.
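The first pitfall is easy to confirm by comparing shapes. The sketch below uses random inputs with assumed sizes N=2, M=3, d=4; the names correct and swapped are illustrative only.

```python
import jax
import jax.numpy as jnp

def dist(x, y):
    return jnp.linalg.norm(x - y)

# Random illustrative inputs (assumed sizes: N=2, M=3, d=4)
kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (2, 4))
Y = jax.random.normal(ky, (3, 4))

# Intended order: inner maps over y, outer maps over x -> (N, M)
correct = jax.vmap(jax.vmap(dist, in_axes=(None, 0)), in_axes=(0, None))(X, Y)

# Swapped in_axes: inner maps over x, outer maps over y -> (M, N)
swapped = jax.vmap(jax.vmap(dist, in_axes=(0, None)), in_axes=(None, 0))(X, Y)

print(correct.shape)  # (2, 3)
print(swapped.shape)  # (3, 2)
```

The swapped version still computes every pairwise distance, but indexed as [j, i] rather than [i, j], so a shape check with N different from M catches the mistake right away.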
Problem
Implement pairwise_distances(X, Y) that computes all N×M pairwise
L2 distances via nested jax.vmap.
- X: 2-D jax array of shape (N, d).
- Y: 2-D jax array of shape (M, d).
- Returns: 2-D array of shape (N, M) where out[i, j] = ‖X[i] - Y[j]‖₂.
Do not use loops or jnp.einsum; use nested vmap.
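One way to check an implementation is to compare it against an independent reference. The sketch below builds one with broadcasting, which does not itself satisfy the nested-vmap requirement and is meant only as a test oracle. The helper names reference_pairwise and check, and the default sizes, are hypothetical.

```python
import jax
import jax.numpy as jnp

def reference_pairwise(X, Y):
    # Broadcasting reference (a test oracle, not an answer to the exercise):
    # X[:, None, :] is (N, 1, d) and Y[None, :, :] is (1, M, d)
    diff = X[:, None, :] - Y[None, :, :]          # (N, M, d)
    return jnp.sqrt(jnp.sum(diff ** 2, axis=-1))  # (N, M)

def check(candidate, N=4, M=5, d=3, seed=0):
    # Compare a candidate pairwise_distances with the reference on random inputs
    kx, ky = jax.random.split(jax.random.PRNGKey(seed))
    X = jax.random.normal(kx, (N, d))
    Y = jax.random.normal(ky, (M, d))
    out = candidate(X, Y)
    if out.shape != (N, M):
        return False
    return bool(jnp.allclose(out, reference_pairwise(X, Y), atol=1e-5))
```

Calling check(pairwise_distances) should return True for a correct solution. Because N and M differ by default, a transposed result fails on the shape test.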