shard_map vs vmap
Why this matters
Both vmap and shard_map map a function over an axis of an array, but they
differ in their execution model:
- vmap is logically batched: one compilation trace with vectorised operations. The compiler decides how to execute it. No mesh required.
- shard_map is physically sharded: explicitly bound to a Mesh. Each device runs the per-shard function on its local slice. The programmer controls the partitioning.
On 1 device the two produce numerically identical results. On multiple devices
shard_map can exploit data parallelism that vmap alone cannot express.
Worked mini-example
import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)

# vmap: per-row max
vmap_result = jax.vmap(jnp.max)(x)  # [2., 4.]

# shard_map: same per-row max
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))

def per_shard(rows):
    return jax.vmap(jnp.max)(rows)

sm_result = shard_map(per_shard, mesh=mesh,
                      in_specs=PartitionSpec('data'),
                      out_specs=PartitionSpec('data'))(x)

jnp.allclose(sm_result, vmap_result)  # True
Common pitfalls
- shard_map needs a Mesh; vmap does not. Don't confuse the two APIs.
- in_specs=PartitionSpec('data') shards axis 0. For a 2-D array of shape (B, d), this splits the row (batch) axis across devices; the feature axis d is left unsharded, so each device holds every feature of its rows.
- The inner shard_map function receives ALL local rows. Use vmap inside the per-shard function to process each row independently, just as you would outside shard_map.
- jnp.stack([a, b], axis=0) requires matching shapes. Both sm_result and vmap_result should be 1-D arrays of length B.
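The points above about what the per-shard function receives can be checked with a minimal sketch like the one below. It assumes a 1-D mesh over all available devices and picks B as a multiple of the device count, since the sharded axis has to divide evenly across the mesh axis. The names seen_shapes and per_row are illustrative and not part of the lesson, and the try/except import is an assumption meant to cover JAX releases that export shard_map at the top level.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # top-level export in newer JAX releases (assumption)
except ImportError:
    from jax.experimental.shard_map import shard_map  # import used in this lesson

n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('data',))

# B is chosen as a multiple of the device count so axis 0 splits evenly.
B, d = 4 * n_dev, 3
x = jnp.arange(B * d, dtype=jnp.float32).reshape(B, d)

seen_shapes = {}  # hypothetical helper: records the shapes each function sees

def per_row(row):
    seen_shapes['row'] = row.shape      # a single row: (d,)
    return jnp.sum(row ** 2)

def per_shard(rows):
    seen_shapes['shard'] = rows.shape   # the local block: (B // n_dev, d)
    return jax.vmap(per_row)(rows)

out = shard_map(per_shard, mesh=mesh,
                in_specs=PartitionSpec('data'),
                out_specs=PartitionSpec('data'))(x)

print(seen_shapes, out.shape)
```

On a single device the recorded shard shape equals the full (B, d). With n devices it becomes (B // n, d), while the per-row shape stays (d,) in both cases.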
Problem
Implement shard_map_vs_vmap(x) that computes the per-row sum-of-squares
of a 2-D array x of shape (B, d) using both methods, then stacks the
results:
- shard_map path: wrap a per-shard function that calls jax.vmap(lambda r: jnp.sum(r ** 2))(rows) inside shard_map with in_specs=PartitionSpec('data') and out_specs=PartitionSpec('data').
- vmap path: directly call jax.vmap(lambda r: jnp.sum(r ** 2))(x).
- Stack as jnp.stack([sm_result, vmap_result], axis=0).
Single-device caveat: the test runner has exactly 1 CPU device. Both paths
produce the same numerical result, so the output is a (2, B) array where
both rows are identical.
- x: 2-D JAX array of shape (B, d).
Returns: 2-D array of shape (2, B): [shard_map_result, vmap_result].
Examples (not from the test set):
- shard_map_vs_vmap(jnp.array([[1.0, 0.0], [0.0, 1.0]])) → [[1.,1.],[1.,1.]]
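One way the three steps of the problem statement might fit together is sketched below. This is not the reference solution, which the page does not show. It assumes a single-device mesh built the same way as in the worked mini-example, and the helper name row_sq_sum and the try/except import are assumptions introduced here for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # top-level export in newer JAX releases (assumption)
except ImportError:
    from jax.experimental.shard_map import shard_map  # import used in this lesson


def shard_map_vs_vmap(x):
    """Per-row sum of squares via shard_map and via vmap, stacked to (2, B)."""
    row_sq_sum = lambda r: jnp.sum(r ** 2)  # hypothetical helper name

    # Single-device mesh, matching the single-device caveat in the problem.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

    def per_shard(rows):
        # rows holds all local rows; vmap handles each row independently.
        return jax.vmap(row_sq_sum)(rows)

    sm_result = shard_map(per_shard, mesh=mesh,
                          in_specs=PartitionSpec('data'),
                          out_specs=PartitionSpec('data'))(x)
    vmap_result = jax.vmap(row_sq_sum)(x)

    # Both results are 1-D with length B, so stacking yields shape (2, B).
    return jnp.stack([sm_result, vmap_result], axis=0)


result = shard_map_vs_vmap(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(result.shape)
```

For the input used here, the row sums of squares are 1 + 4 = 5 and 9 + 16 = 25, so both rows of the stacked output should hold those two values.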