shard_map Basics
Why this matters
shard_map is the modern replacement for pmap. Where pmap implicitly
parallelises a batch axis, shard_map makes the sharding explicit:
you choose a Mesh, describe the per-axis sharding with PartitionSpec,
and the function you pass runs once per device, on that device's local shard of the data.
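```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()[:1])  # 1-device mesh
mesh = Mesh(devices, ('data',))
f = shard_map(
    lambda y: y * 2.0,                # runs on each device's local shard
    mesh=mesh,
    in_specs=PartitionSpec('data'),   # split the input along 'data'
    out_specs=PartitionSpec('data'),  # reassemble the outputs along 'data'
)
x = jnp.array([1.0, 2.0, 3.0])
result = f(x)  # each device doubles its shard; shards are concatenated on exit
```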
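To see that the body really works on local shards, the sketch below builds the mesh from every available device (the snippet above uses only the first) and has each shard fill itself with its own length. It assumes len(x) is a multiple of the device count; on the 1-device test runner the single shard is the whole array. The fallback import of jax.shard_map is an assumption about newer JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map  # path used in this lesson
except ImportError:
    from jax import shard_map  # assumption: newer JAX exposes it at top level

# Mesh over every visible device.
devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('data',))

# Assumption: len(x) is divisible by n, so 'data' splits it evenly.
x = jnp.arange(4.0 * n)

# y.shape is static at trace time, so each shard sees its LOCAL length.
# Deriving the result from y keeps it sharded along 'data'.
local_len = shard_map(
    lambda y: y * 0.0 + y.shape[0],
    mesh=mesh,
    in_specs=PartitionSpec('data'),
    out_specs=PartitionSpec('data'),
)(x)

print(n, local_len)  # every entry is 4.0: each shard holds 4 elements
```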
Worked mini-example
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))
f = shard_map(lambda y: y + 1.0, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec('data'))
print(f(jnp.array([3.0, 7.0])))  # [4. 8.]
```
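An elementwise body like y + 1.0 never needs data from other shards. A reduction does: each shard can only sum what it holds, so the partial sums have to be combined with a collective over the 'data' axis. A minimal sketch using jax.lax.psum, with out_specs=PartitionSpec() because the combined result is identical on every device; sharded_sum is a hypothetical helper, not part of the lesson.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    from jax import shard_map  # assumption: newer JAX exposes it at top level

# Hypothetical helper: global sum of a 1-D array, computed shard by shard.
def sharded_sum(x, mesh):
    def body(y):
        partial = jnp.sum(y)                  # sum of the local shard only
        return jax.lax.psum(partial, 'data')  # add the partials across 'data'
    return shard_map(body, mesh=mesh,
                     in_specs=PartitionSpec('data'),
                     out_specs=PartitionSpec())(x)  # replicated scalar

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
print(sharded_sum(jnp.array([3.0, 7.0]), mesh))  # 10.0
```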
Common pitfalls
- in_specs/out_specs must align with array rank. A 1-D array maps to PartitionSpec('data') (one entry), not PartitionSpec('data', None).
- The inner function sees only its LOCAL shard. Don't try to index the full array inside the shard_map body; only the local shard is visible.
- out_specs must match what the inner function returns. If the inner function returns a 1-D array of local shard size, use PartitionSpec('data').
- On 1 device, the shard is the full array, so the behaviour is identical to calling the inner function directly.
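The rank and out_specs rules show up together on a 2-D input. In this sketch the matrix is split by rows, so in_specs needs one entry per dimension (PartitionSpec('data', None) leaves the columns unsplit), while the body returns one value per local row, a 1-D result, so out_specs has a single entry. The row-sum body is just an illustrative choice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    from jax import shard_map  # assumption: newer JAX exposes it at top level

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

m = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])  # shape (2, 3): rows, cols

row_sums = shard_map(
    lambda y: jnp.sum(y, axis=1),          # local (rows, 3) -> (rows,)
    mesh=mesh,
    in_specs=PartitionSpec('data', None),  # 2 entries for a rank-2 input
    out_specs=PartitionSpec('data'),       # 1 entry for the rank-1 output
)(m)

print(row_sums)  # [ 6. 15.]
```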
Problem
Implement shard_map_double(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Uses shard_map with in_specs=PartitionSpec('data') and out_specs=PartitionSpec('data') to apply lambda y: y * 2.0.
- Calls the resulting function on x and returns the result.
Single-device caveat: the test runner has exactly 1 CPU device. On 1
device the result is identical to x * 2. The lesson is the API pattern, which
scales to multiple devices by building the Mesh from more devices, provided the length of x is divisible by the number of devices.
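To watch the same pattern run on several devices without extra hardware, XLA can split the host CPU into multiple logical devices via the --xla_force_host_platform_device_count flag. The flag only takes effect if it is set before JAX initialises its backend, so this sketch is meant to run as a fresh script; if JAX has already started, it falls back to however many devices exist. The choice of 4 devices and of len(x) = 2 * n is arbitrary.

```python
import os

# Must be set before JAX initialises; 4 logical CPU devices is an arbitrary choice.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    from jax import shard_map  # assumption: newer JAX exposes it at top level

devices = np.array(jax.devices())  # 4 devices in a fresh CPU-only process
n = devices.size
mesh = Mesh(devices, ('data',))

x = jnp.arange(2.0 * n)  # length divisible by n
doubled = shard_map(lambda y: y * 2.0, mesh=mesh,
                    in_specs=PartitionSpec('data'),
                    out_specs=PartitionSpec('data'))(x)

print(n, len(doubled.addressable_shards))  # one output shard per device
print(doubled)                             # same values as x * 2
```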
- x: 1-D JAX array.
- Returns: 1-D array of the same shape as x, with every element doubled.
Examples (not from the test set):
- shard_map_double(jnp.array([1.0, 2.0])) → [2. 4.]