
shard_map vs vmap

Why this matters

Both vmap and shard_map apply a function across slices of an array, but they differ in their execution model:

  • vmap is logically batched: one trace, with operations vectorised over the mapped axis. The compiler decides how to execute it, and no mesh is required.
  • shard_map is physically sharded: it is explicitly bound to a Mesh, each device runs the per-shard function on its local slice, and the programmer controls the partitioning.
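To make the "logically batched" point concrete, here is a minimal sketch that uses only jax.vmap (no Mesh). It maps jnp.max over rows, then over columns via in_axes=1, and checks the row result against an explicit Python loop. It also prints the traced program with jax.make_jaxpr, which should contain a single batched reduction rather than a loop. The array values and variable names are illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)          # [[0., 1., 2.], [3., 4., 5.]]

# Map over rows (in_axes=0 is the default): one trace, batched ops, no Mesh.
row_max = jax.vmap(jnp.max)(x)

# Map over columns instead by choosing a different mapped axis.
col_max = jax.vmap(jnp.max, in_axes=1)(x)

# Reference: the same per-row max computed with an explicit Python loop.
loop_max = jnp.stack([jnp.max(x[i]) for i in range(x.shape[0])])

# The traced program: vmap vectorises the reduction instead of unrolling a loop.
jaxpr = jax.make_jaxpr(jax.vmap(jnp.max))(x)
print(jaxpr)
```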

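On 1 device the two produce numerically identical results. On multiple devices, shard_map assigns each slice to a specific device and runs the function there, which makes the data parallelism explicit; vmap alone only adds a batch dimension and says nothing about device placement.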
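A minimal sketch of the "each device runs on its local slice" behaviour, assuming a Mesh over all available devices and a batch size B chosen to be divisible by the device count. The hypothetical list seen_shapes records the block shape the per-shard function sees while it is traced under jax.jit. The try/except import is an assumption to cover both newer JAX (public jax.shard_map) and older JAX (jax.experimental.shard_map, as used in this page).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer JAX: public API
except ImportError:
    from jax.experimental.shard_map import shard_map   # older JAX

n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ('data',))

B, d = 4 * n_dev, 3                    # B divisible by the number of devices
x = jnp.arange(B * d, dtype=jnp.float32).reshape(B, d)

seen_shapes = []                       # hypothetical: block shapes seen at trace time

def per_shard(block):
    # Runs at trace time; block is this device's local slice, not the full array.
    seen_shapes.append(block.shape)
    return block * 2.0

f = jax.jit(shard_map(per_shard, mesh=mesh,
                      in_specs=P('data'), out_specs=P('data')))
y = f(x)

print(seen_shapes[0])                  # (B // n_dev, d)
```

On the single-device test runner the local block is the whole array; with n devices each one sees B // n rows.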

Worked mini-example

import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])   # shape (2, 2)

# vmap: per-row max
vmap_result = jax.vmap(jnp.max)(x)          # [2., 4.]

# shard_map: same per-row max, on a 1-device mesh whose single axis is named 'data'
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
def per_shard(rows):
    return jax.vmap(jnp.max)(rows)
sm_result = shard_map(per_shard, mesh=mesh,
                      in_specs=PartitionSpec('data'),
                      out_specs=PartitionSpec('data'))(x)

jnp.allclose(sm_result, vmap_result)  # True

Common pitfalls

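  • shard_map needs a Mesh plus in_specs and out_specs; vmap takes only the function (and optionally in_axes/out_axes) and never needs a Mesh.
  • in_specs=PartitionSpec('data') shards axis 0 over the mesh axis named 'data'. For a 2-D array (B, d) this splits the row (batch) axis across devices, so B must be divisible by the number of devices on 'data'. The feature axis d is not partitioned, so every shard holds complete rows (equivalent to PartitionSpec('data', None)).
  • The per-shard function receives its whole local block, of shape (B / n, d) when the 'data' axis spans n devices, not one row at a time. Use vmap inside the per-shard function to process each row independently, just as you would outside shard_map.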
  • jnp.stack([a, b], axis=0) requires matching shapes. Both sm_result and vmap_result should be 1-D arrays of length B.
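The pitfalls above can be checked with a small sketch on a 1-device mesh. It compares two per-shard functions: one that uses vmap over the rows of the local block, and one that reduces along axis 1 of the block directly, passed an explicit PartitionSpec('data', None). Both should give the same per-row max. The helper run and the function names are hypothetical, and the try/except import is an assumption covering newer and older JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer JAX: public API
except ImportError:
    from jax.experimental.shard_map import shard_map   # older JAX

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

def per_shard_vmap(rows):
    # rows is the whole local block; vmap maps jnp.max over its leading axis.
    return jax.vmap(jnp.max)(rows)

def per_shard_axis(rows):
    # Equivalent without vmap: reduce along the feature axis of the block.
    return jnp.max(rows, axis=1)

def run(fn, spec):
    # Hypothetical helper: apply fn with the given input spec, rows sharded on output.
    return shard_map(fn, mesh=mesh, in_specs=spec, out_specs=P('data'))(x)

a = run(per_shard_vmap, P('data'))
b = run(per_shard_axis, P('data', None))   # explicit None: axis 1 not partitioned
stacked = jnp.stack([a, b], axis=0)        # both are 1-D of length B, so shape (2, B)
```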

Problem

Implement shard_map_vs_vmap(x) that computes the per-row sum-of-squares of a 2-D array x of shape (B, d) using both methods, then stacks the results:

  1. shard_map path: wrap a per-shard function that calls jax.vmap(lambda r: jnp.sum(r ** 2))(rows) inside shard_map with in_specs=PartitionSpec('data') and out_specs=PartitionSpec('data').
  2. vmap path: directly call jax.vmap(lambda r: jnp.sum(r ** 2))(x).
  3. Stack as jnp.stack([sm_result, vmap_result], axis=0).
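The three steps can be sketched as follows. This is one possible implementation under the single-device assumption stated below (the mesh uses only the first device), not the site's reference solution; the try/except import is an assumption covering newer and older JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                          # newer JAX: public API
except ImportError:
    from jax.experimental.shard_map import shard_map   # older JAX

def shard_map_vs_vmap(x):
    """Per-row sum of squares via shard_map and via vmap, stacked to shape (2, B)."""
    row_sq = lambda r: jnp.sum(r ** 2)      # one row of length d -> scalar

    # 1. shard_map path over a 1-device mesh (matches the single-device runner).
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

    def per_shard(rows):
        # rows: the local block of shape (B_local, d)
        return jax.vmap(row_sq)(rows)

    sm_result = shard_map(per_shard, mesh=mesh,
                          in_specs=P('data'), out_specs=P('data'))(x)

    # 2. vmap path: no mesh involved.
    vmap_result = jax.vmap(row_sq)(x)

    # 3. Both results are 1-D of length B; stack them into shape (2, B).
    return jnp.stack([sm_result, vmap_result], axis=0)

print(shard_map_vs_vmap(jnp.array([[1.0, 0.0], [0.0, 1.0]])))
```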

Single-device caveat: the test runner has exactly 1 CPU device. Both paths produce the same numerical result, so the output is a (2, B) array where both rows are identical.

  • x: 2-D JAX array of shape (B, d).

Returns: 2-D array of shape (2, B): [shard_map_result, vmap_result].

Examples (not from the test set):

  • shard_map_vs_vmap(jnp.array([[1.0, 0.0], [0.0, 1.0]])) → [[1.,1.],[1.,1.]]

Hints

jax shard-map vmap comparison
