medium
primitives
lax.map vs vmap
Why this matters
lax.map(f, xs) applies f to each element of xs along the leading
axis. It looks like jax.vmap(f)(xs), but with a critical difference:
- vmap vectorizes: it processes all batch elements together as one batched computation, so its memory use grows with the batch size times the per-element memory.
- lax.map runs sequentially: it processes one element at a time, so its working memory does not grow with the batch size (only the stacked output does). It is equivalent to a lax.scan with no carry.
The trade-off is speed vs memory: vmap is faster for typical workloads; lax.map is essential when the batch is so large that vectorizing would OOM.
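A minimal sketch comparing the three formulations on a small batch. The per-element function f (sin scaled by 2) is arbitrary and chosen only for illustration. The scan version mirrors the "scan with no carry" description: the carry is an empty tuple that is passed through unchanged.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(xi):
    # Per-element function (illustrative): receives one slice of xs along the leading axis.
    return jnp.sin(xi) * 2.0


xs = jnp.arange(6.0).reshape(3, 2)  # batch of 3 elements, each of shape (2,)

# Sequential: lax.map applies f to xs[0], xs[1], xs[2] one at a time.
out_map = lax.map(f, xs)


# The same computation as a scan whose carry is an empty tuple that never changes.
def step(carry, xi):
    return carry, f(xi)


_, out_scan = lax.scan(step, (), xs)

# Vectorized: vmap turns f into one batched computation over all 3 elements.
out_vmap = jax.vmap(f)(xs)

print(jnp.allclose(out_map, out_scan), jnp.allclose(out_map, out_vmap))
```

On a batch this small all three give the same values; the difference only shows up in memory use and run time as the batch grows.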
Worked mini-example
from jax import lax
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
out = lax.map(lambda xi: xi * 2, x)
# → [2.0, 4.0, 6.0]
The function f receives a single element (same shape as xs[0]),
not the full batch.
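To see the per-element shape concretely, here is a small sketch that maps over the rows of a 2-D array; row_sum is a hypothetical helper used only for illustration.

```python
import jax.numpy as jnp
from jax import lax

m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])  # shape (2, 3): two elements, each of shape (3,)


def row_sum(row):
    # `row` is one slice m[i] with shape m.shape[1:] == (3,), not (2, 3).
    # Python-level checks like this assert run while the function is traced.
    assert row.shape == (3,)
    return jnp.sum(row)


sums = lax.map(row_sum, m)
# sums has shape (2,): [6.0, 15.0]
```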
Common pitfalls
- Using lax.map by default. Because it runs sequentially, it is usually slower than vmap. Use vmap unless you have a concrete memory constraint.
- f receives one element. Unlike a batched function, f takes a single slice of shape xs.shape[1:], not xs.shape.
- Stateful side effects. As with all JAX transforms, f must be pure: it is traced rather than executed as plain Python for each element, so Python side effects do not run once per element.
- Output shape. lax.map stacks outputs along a new leading axis, so the result has shape (len(xs),) + f_output_shape.
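A short sketch of the output-shape rule, assuming a per-element function that returns a tuple. lax.map accepts any pytree as output and stacks each leaf separately, so every leaf picks up the same new leading axis.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(6.0).reshape(3, 2)  # 3 elements, each of shape (2,)


def f(xi):
    # Per-element outputs (illustrative): a (2, 2) outer product and a scalar norm.
    return jnp.outer(xi, xi), jnp.linalg.norm(xi)


outer, norms = lax.map(f, xs)
# Each output leaf gains a leading axis of length len(xs) == 3:
# outer.shape == (3, 2, 2), norms.shape == (3,)
```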
Problem
Implement lax_map_square(x) that squares each element using lax.map.
Do not apply jnp.square or x ** 2 to the whole array, and do not use vmap; the point is
to learn the lax.map API.
- x: 1-D JAX array.
- Returns: 1-D array of the same shape containing the element-wise square.
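One possible implementation, offered as a sketch rather than the site's reference solution: the squaring happens inside the mapped function, on one scalar at a time, so neither jnp.square nor x ** 2 is applied to the whole array.

```python
import jax.numpy as jnp
from jax import lax


def lax_map_square(x):
    # For a 1-D x, lax.map passes each scalar x[i] (shape x.shape[1:] == ()) to the lambda,
    # then stacks the per-element results back into a 1-D array.
    return lax.map(lambda xi: xi * xi, x)


x = jnp.array([1.0, -2.0, 3.0])
y = lax_map_square(x)
# y → [1.0, 4.0, 9.0]
```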
jax
lax-map
vmap