hard primitives

Nested vmap (vmap-of-vmap)

Why this matters

Many machine learning algorithms require computing a function over all pairs (x_i, y_j) from two sets: pairwise distances, attention scores, kernel matrices, and cross-similarity tables. Nested jax.vmap is the canonical JAX pattern for these all-pairs computations: write the function for a single pair, then vectorize twice to cover all N × M combinations.

Nested vmap is also a stepping stone toward understanding how JAX transforms compose: each vmap introduces one batch dimension, and they stack cleanly.
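To make the "one batch dimension per vmap" idea concrete, here is a minimal sketch that tracks output shapes as vmaps are stacked. The rbf function, its gamma parameter, and the array sizes are illustrative assumptions chosen for this example, not part of the problem.

```python
import jax
import jax.numpy as jnp

def rbf(x, y, gamma=0.5):
    # Hypothetical kernel for one pair of vectors, each of shape (d,).
    return jnp.exp(-gamma * jnp.sum((x - y) ** 2))

X = jnp.ones((3, 4))   # N=3 points, d=4 (illustrative values)
Y = jnp.zeros((5, 4))  # M=5 points, d=4 (illustrative values)

# No vmap: one pair in, one scalar out.
single = rbf(X[0], Y[0])

# One vmap over y: adds one batch dimension of size M.
row = jax.vmap(rbf, in_axes=(None, 0))(X[0], Y)

# A second vmap over x: adds another batch dimension of size N.
gram = jax.vmap(jax.vmap(rbf, in_axes=(None, 0)), in_axes=(0, None))(X, Y)

print(single.shape, row.shape, gram.shape)  # () (5,) (3, 5)
```

Each wrapper leaves rbf itself untouched; only the shape of the inputs it is applied over changes.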

Worked mini-example

For N=2, M=2, the unrolled computation is:

out[0, 0] = dist(X[0], Y[0])
out[0, 1] = dist(X[0], Y[1])
out[1, 0] = dist(X[1], Y[0])
out[1, 1] = dist(X[1], Y[1])

The nested vmap structure:

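import jax
import jax.numpy as jnp

def dist(x, y):
    # L2 distance between two vectors of shape (d,)
    return jnp.linalg.norm(x - y)

# Example inputs (illustrative values) with N=2, M=2, d=2
X = jnp.array([[0.0, 0.0], [3.0, 4.0]])
Y = jnp.array([[0.0, 0.0], [6.0, 8.0]])

# Inner vmap: fix x, map over all y  → shape (M,)
inner = jax.vmap(dist, in_axes=(None, 0))

# Outer vmap: map over all x, fix y  → shape (N, M)
result = jax.vmap(inner, in_axes=(0, None))(X, Y)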

Common pitfalls

  • Getting in_axes backwards. INNER vmap fixes x (None) and maps over y (axis 0). OUTER vmap maps over x (axis 0) and fixes y (None). Swapping them gives a transposed result.
  • Confusing nested vmap with two separate vmaps. The outer vmap wraps the inner vmap; it's one expression, not two sequential calls.
  • Forgetting that in_axes in the outer vmap controls the argument to the INNER vmap, not to the innermost function. The inner vmap already handles Y; the outer just needs to handle X.
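The first pitfall is easiest to see when N and M differ, because the transposition then shows up in the output shape. The sketch below uses arbitrary example arrays (an assumption for illustration) and compares the intended in_axes order with the swapped one.

```python
import jax
import jax.numpy as jnp

def dist(x, y):
    return jnp.linalg.norm(x - y)

X = jnp.arange(6.0).reshape(3, 2)  # N=3, d=2 (illustrative values)
Y = jnp.arange(8.0).reshape(4, 2)  # M=4, d=2 (illustrative values)

# Intended order: inner maps over Y, outer maps over X → (N, M)
correct = jax.vmap(jax.vmap(dist, in_axes=(None, 0)), in_axes=(0, None))(X, Y)

# Swapped order: inner maps over X, outer maps over Y → (M, N)
swapped = jax.vmap(jax.vmap(dist, in_axes=(0, None)), in_axes=(None, 0))(X, Y)

print(correct.shape)                           # (3, 4)
print(swapped.shape)                           # (4, 3)
print(bool(jnp.allclose(swapped.T, correct)))  # True
```

With square inputs (N equal to M) the shapes would match and the mistake would only surface as wrong values, so a non-square test case is a cheap safeguard.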

Problem

Implement pairwise_distances(X, Y) that computes all N×M pairwise L2 distances via nested jax.vmap.

  • X: 2-D jax array of shape (N, d).
  • Y: 2-D jax array of shape (M, d).
  • Returns: 2-D array of shape (N, M) where out[i, j] = ‖X[i] - Y[j]‖₂.

Do not use loops or jnp.einsum; use nested vmap.
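One way to test an attempt locally is to compare it against an independent reference on non-square random inputs. The sketch below is an assumption-laden harness, not the site's grader: reference_pairwise and check_pairwise are hypothetical helper names, the reference uses broadcasting only as a checker (it is not a valid submission under the constraint above), and the candidate shown simply mirrors the worked mini-example.

```python
import jax
import jax.numpy as jnp

def reference_pairwise(X, Y):
    # Broadcasting reference: (N, 1, d) - (1, M, d) → (N, M, d), norm over d.
    diff = X[:, None, :] - Y[None, :, :]
    return jnp.linalg.norm(diff, axis=-1)

def check_pairwise(candidate, n=3, m=5, d=4, seed=0):
    # Hypothetical helper: checks shape and values of `candidate` on random data.
    kx, ky = jax.random.split(jax.random.PRNGKey(seed))
    X = jax.random.normal(kx, (n, d))
    Y = jax.random.normal(ky, (m, d))
    out = candidate(X, Y)
    assert out.shape == (n, m), f"expected {(n, m)}, got {out.shape}"
    assert bool(jnp.allclose(out, reference_pairwise(X, Y), atol=1e-5))
    return True

def pairwise_distances(X, Y):
    # One possible candidate, following the nested-vmap structure shown earlier.
    def dist(x, y):
        return jnp.linalg.norm(x - y)
    return jax.vmap(jax.vmap(dist, in_axes=(None, 0)), in_axes=(0, None))(X, Y)

print(check_pairwise(pairwise_distances))  # True
```

Using n different from m means a transposed result fails the shape assertion immediately rather than passing by accident.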

Hints

jax vmap nested
