
shard_map Basics

Why this matters

shard_map is the modern replacement for pmap. Where pmap maps a function over a leading array axis (one slice per device), shard_map makes the sharding explicit: you choose a Mesh of devices with named axes, describe how each array axis is split across those mesh axes with a PartitionSpec, and the function you pass runs once per shard on that shard's local data.

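import jax
import numpy as np
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()[:1])   # 1-D array of devices backing the mesh
mesh = Mesh(devices, ('data',))         # name the single mesh axis 'data'
f = shard_map(
    lambda y: y * 2.0,                  # runs once per shard on local data
    mesh=mesh,
    in_specs=PartitionSpec('data'),     # split input axis 0 over 'data'
    out_specs=PartitionSpec('data'),    # reassemble outputs along axis 0
)
x = jnp.arange(4.0)                     # example input
result = f(x)   # each device doubles its shard; the shards form one global array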
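To make the contrast with pmap concrete, the sketch below runs the same doubling both ways on one device. It assumes jax.pmap is still available in your JAX version, and it also assumes that newer JAX releases export shard_map at the top level (falling back to the jax.experimental.shard_map path used above); the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: newer JAX exports shard_map at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # import path used in the text

n = 1                    # single device, matching the text's examples
x = jnp.arange(4.0)      # global array of shape (4,)

# pmap: the caller must add a leading axis whose size is the device count.
pmapped = jax.pmap(lambda y: y * 2.0)(x.reshape(n, -1))   # shape (1, 4)

# shard_map: the caller passes the global array; PartitionSpec('data') splits axis 0.
mesh = Mesh(np.array(jax.devices()[:n]), ('data',))
sharded = shard_map(lambda y: y * 2.0, mesh=mesh,
                    in_specs=PartitionSpec('data'),
                    out_specs=PartitionSpec('data'))(x)   # shape (4,)

print(pmapped.shape, sharded.shape)                 # (1, 4) (4,)
print(jnp.allclose(pmapped.reshape(-1), sharded))   # True
```

The design difference shows up in the shapes: pmap puts the device axis into the array the caller builds, while shard_map keeps the caller's global shape and moves the split into the PartitionSpec.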

Worked mini-example

import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
f = shard_map(lambda y: y + 1.0, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec('data'))
print(f(jnp.array([3.0, 7.0])))  # [4. 8.]
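A function produced by shard_map is an ordinary JAX callable, so it can be wrapped in jax.jit like any other. The sketch below reuses the mini-example and checks that the eager and jitted calls agree; the top-level shard_map import is an assumption about newer JAX versions, with the text's import path as fallback.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: newer JAX exports shard_map at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # import path used in the text

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
f = shard_map(lambda y: y + 1.0, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec('data'))

x = jnp.array([3.0, 7.0])
eager = f(x)                 # runs shard_map directly
compiled = jax.jit(f)(x)     # same function, traced and compiled by jit
print(compiled)                            # [4. 8.]
print(jnp.array_equal(eager, compiled))    # True
```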

Common pitfalls

  • in_specs/out_specs must fit the array rank. A PartitionSpec may have at most one entry per array dimension (omitted trailing entries leave those dimensions unsplit), so a 1-D array uses PartitionSpec('data'), not PartitionSpec('data', None).
  • The inner function sees only its LOCAL shard. Inside the shard_map body, array shapes are per-shard shapes, so do not index as if the full array were present.
  • out_specs must match what the inner function returns. If the inner function returns a 1-D array of local-shard size, use PartitionSpec('data'); the shards are then concatenated along axis 0 into the global result.
  • On 1 device the shard is the full array, so the behaviour is identical to calling the inner function directly.
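The pitfalls above can be seen in a small sketch: a 2-D input with a one-entry PartitionSpec (axis 1 stays whole), a body that reports its local row count, and a global sum that needs a collective because a plain jnp.sum only covers the local shard. It assumes jax.lax.psum over the 'data' axis and out_specs=PartitionSpec() for a replicated scalar result; the helper names local_rows and global_sum are hypothetical, and the top-level shard_map import is an assumption about newer JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # assumption: newer JAX exports shard_map at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # import path used in the text

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def local_rows(y):
    # Hypothetical helper: y is the LOCAL block, so y.shape[0] is the per-shard row count.
    return jnp.zeros_like(y[:, :1]) + y.shape[0]

def global_sum(y):
    # Hypothetical helper: jnp.sum sees one shard; psum over 'data' adds all shards.
    return jax.lax.psum(jnp.sum(y), 'data')

x2d = jnp.arange(6.0).reshape(2, 3)   # rank 2, one spec entry: axis 1 is not split

rows = shard_map(local_rows, mesh=mesh, in_specs=P('data'), out_specs=P('data'))(x2d)
total = shard_map(global_sum, mesh=mesh, in_specs=P('data'), out_specs=P())(x2d)

print(rows.ravel())   # [2. 2.]  (on 1 device the local shard is the whole array)
print(total)          # 15.0
```

With more devices on the 'data' axis, local_rows would report a smaller per-shard count while total would stay 15.0.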

Problem

Implement shard_map_double(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Uses shard_map with in_specs=PartitionSpec('data') and out_specs=PartitionSpec('data') to apply lambda y: y * 2.0.
  3. Calls the resulting function on x and returns the result.

Single-device caveat: the test runner has exactly 1 CPU device. On 1 device the result is identical to x * 2. The lesson is the API pattern, which scales transparently to multiple devices.

Arguments:

  • x: 1-D JAX array.

Returns: 1-D array of the same shape as x, with every element doubled.

Examples (not from the test set):

  • shard_map_double(jnp.array([1.0, 2.0])) → [2. 4.]
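The three steps map directly onto a few lines of code. This is a minimal sketch, not the site's reference solution; it assumes that newer JAX versions export shard_map at the top level and falls back to the jax.experimental.shard_map import used earlier.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: newer JAX exports shard_map at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # import path used in the text

def shard_map_double(x):
    # 1. A 1-device mesh whose only axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # 2. Split axis 0 of input and output over 'data'; the body doubles its shard.
    doubled = shard_map(lambda y: y * 2.0, mesh=mesh,
                        in_specs=PartitionSpec('data'),
                        out_specs=PartitionSpec('data'))
    # 3. Apply to x; on one device this equals x * 2.
    return doubled(x)

print(shard_map_double(jnp.array([1.0, 2.0])))  # [2. 4.]
```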
