
lax.map vs vmap

Why this matters

lax.map(f, xs) applies f to each element of xs along the leading axis. For a pure f it computes the same result as jax.vmap(f)(xs), but it executes in a critically different way:

  • vmap vectorizes: it processes all batch elements in parallel, so peak memory grows with the batch size times the per-element working memory.
  • lax.map runs sequentially: it processes one element at a time, so its working memory is roughly that of a single element, plus the buffer for the stacked outputs (which still grows with the batch size). It is equivalent to a lax.scan with no carry.
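The "lax.scan with no carry" equivalence can be made concrete. The sketch below is an illustration, not JAX's own source: the helper name `map_via_scan` is hypothetical. It threads an empty carry (`None`) through lax.scan, collects only the per-element outputs, and compares the result against lax.map and vmap.

```python
import jax
import jax.numpy as jnp
from jax import lax


def map_via_scan(f, xs):
    """Hypothetical helper: apply f to each leading-axis slice of xs via lax.scan.

    The carry is an empty placeholder (None) passed through unchanged, so
    scan only stacks the per-element outputs f(x).
    """
    def body(carry, x):
        return carry, f(x)

    _, ys = lax.scan(body, None, xs)
    return ys


def f(xi):
    return jnp.sin(xi) * 2.0


xs = jnp.arange(4.0)
via_scan = map_via_scan(f, xs)   # sequential, built from scan
via_map = lax.map(f, xs)         # sequential, built-in
via_vmap = jax.vmap(f)(xs)       # vectorized
```

All three produce the same values; they differ only in how the work is scheduled.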

The trade-off is speed vs memory: vmap is usually faster, while lax.map becomes necessary when the batch is so large that vectorizing would run out of memory (OOM).
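A pairwise-distance computation shows where the memory difference comes from. In this sketch (the function names are hypothetical), each row needs an intermediate of shape (n, d). Under vmap, that intermediate exists for every row at once, conceptually (n, n, d); under lax.map only one (n, d) intermediate is live per step. XLA fusion can sometimes avoid materializing the full vmap intermediate, so actual peak memory should be measured rather than assumed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def row_distances(row, points):
    # Distances from one point (shape (d,)) to all points (shape (n, d)).
    # The intermediate `row - points` has shape (n, d).
    return jnp.sqrt(jnp.sum((row - points) ** 2, axis=-1))


def pairwise_vmap(points):
    # All rows at once: the intermediate is conceptually (n, n, d).
    return jax.vmap(lambda r: row_distances(r, points))(points)


def pairwise_map(points):
    # One row per step: only an (n, d) intermediate is live at a time,
    # though the final (n, n) result is still stacked at the end.
    return lax.map(lambda r: row_distances(r, points), points)


points = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
d_vmap = pairwise_vmap(points)
d_map = pairwise_map(points)
```

Both return the same (256, 256) distance matrix; lax.map trades throughput for a smaller working set.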

Worked mini-example

from jax import lax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
out = lax.map(lambda xi: xi * 2, x)
# → [2.0, 4.0, 6.0]

The function f receives a single element (same shape as xs[0]), not the full batch.
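With a 2-D input the per-element shape becomes visible. In this sketch (`row_sum` and `seen_shapes` are hypothetical, for illustration only), shapes are static while JAX traces f, so appending `row.shape` to a list records what f actually receives: a single row of shape (4,), not the (3, 4) batch.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(12.0).reshape(3, 4)   # batch of 3 rows, each of shape (4,)

seen_shapes = []

def row_sum(row):
    # Trace-time inspection only: records the shape of the slice f receives.
    seen_shapes.append(row.shape)
    return row.sum()

out = lax.map(row_sum, xs)
# out → [ 6. 22. 38.], shape (3,)
# seen_shapes contains (4,), i.e. xs.shape[1:]
```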

Common pitfalls

  • Using lax.map by default. Because it processes elements one at a time, it is usually slower than vmap. Use vmap unless you have a measured memory constraint.
  • f receives one element. Unlike writing a batched function, f takes a single slice — shape xs.shape[1:], not xs.shape.
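  • Stateful side effects. As with all JAX transforms, f must be pure. f is traced rather than called once per element, so Python side effects such as print or appending to a list run at trace time.
  • Output shape. lax.map stacks outputs along a new leading axis: the result has shape (len(xs),) + f_output_shape. If f returns a pytree, each leaf is stacked this way.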
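The last two pitfalls can be checked directly. In this sketch (`f` and `calls` are hypothetical), f returns a dict, and each leaf gains a leading axis of length len(xs). The Python-side `calls.append` runs while f is traced, so the list ends up far shorter than the number of elements, typically a single entry.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(10.0).reshape(5, 2)   # 5 elements, each of shape (2,)
calls = []

def f(row):
    # Python side effect: happens during tracing, not once per element.
    calls.append(1)
    # A pytree output: each leaf is stacked along a new leading axis.
    return {"outer": jnp.outer(row, row), "total": row.sum()}

out = lax.map(f, xs)
# out["outer"].shape == (5, 2, 2), i.e. (len(xs),) + (2, 2)
# out["total"].shape == (5,),      i.e. (len(xs),) + ()
# len(calls) stays below 5 (typically 1)
```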

Problem

Implement lax_map_square(x), which squares each element using lax.map. Do not use jnp.square, x ** 2, or vmap; the point is to learn the lax.map API.

  • x: a 1-D JAX array.
  • Returns: a 1-D array of the same shape containing the element-wise squares.
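To test an attempt without seeing a solution, a small self-check can encode the contract above. `check_lax_map_square` is a hypothetical helper, not part of the original problem. Because lax.map is built on lax.scan, it also looks for "scan" in the printed jaxpr; that string search is only a heuristic, not a proof that lax.map was used.

```python
import jax
import jax.numpy as jnp


def check_lax_map_square(candidate):
    """Hypothetical self-check for a lax_map_square implementation."""
    x = jnp.array([-2.0, 0.5, 3.0])
    out = candidate(x)
    # Contract: same shape as the input, element-wise squares.
    assert out.shape == x.shape, f"expected shape {x.shape}, got {out.shape}"
    assert jnp.allclose(out, x * x), f"wrong values: {out}"
    # Heuristic: lax.map lowers to a scan, which shows up in the jaxpr text.
    jaxpr_text = str(jax.make_jaxpr(candidate)(x))
    assert "scan" in jaxpr_text, "no scan in the jaxpr; was lax.map used?"
    return True
```

Call it as check_lax_map_square(lax_map_square) once your function is defined; an AssertionError names the first requirement that failed.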
