shard_map with psum Collective
Why this matters
psum, pmean, and pmax are SPMD collectives: all-reduce operations
that aggregate values across every shard of a named mesh axis. They are the
building block of distributed gradient synchronisation in data-parallel training:
each device accumulates a local gradient, then a single psum call produces the
global gradient on every device simultaneously.
def per_shard(y):
    local_grad = jnp.sum(y)  # local partial result
    return jax.lax.psum(local_grad, axis_name='data')  # global all-reduce
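To make the difference between the collectives concrete, the sketch below runs psum and pmean side by side on a 1-device mesh. The helper name per_shard_stats and the sample values are illustrative assumptions, not part of the exercise. The import fallback to jax.shard_map is likewise an assumption, for newer JAX releases where the experimental module may no longer be available.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:  # assumption: newer JAX exposes it as jax.shard_map
    from jax import shard_map

# 1-device mesh with a single axis named 'data'.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def per_shard_stats(y):  # hypothetical helper, for illustration only
    local_sum = jnp.sum(y)    # partial sum of this shard
    local_mean = jnp.mean(y)  # mean of this shard
    total = jax.lax.psum(local_sum, axis_name='data')   # sum over shards
    mean = jax.lax.pmean(local_mean, axis_name='data')  # average over shards
    return total, mean

stats = shard_map(per_shard_stats, mesh=mesh,
                  in_specs=PartitionSpec('data'),
                  out_specs=(PartitionSpec(), PartitionSpec()))

total, mean = stats(jnp.array([1.0, 2.0, 3.0, 4.0]))
print(total, mean)  # 10.0 2.5
```

Averaging the per-shard means with pmean gives the global mean here only because shard_map splits the input into equally sized shards.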
Worked mini-example
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

# Build a 1-device mesh whose single axis is named 'data'.
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))

def per_shard(y):
    # Sum the local shard, then all-reduce across the 'data' axis.
    return jax.lax.psum(jnp.sum(y), axis_name='data')

f = shard_map(per_shard, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec())
print(f(jnp.array([1.0, 2.0, 3.0])))  # 6.0
Common pitfalls
- Collectives are blocking. Every shard must reach the psum call before any can proceed. Conditionals that prevent some shards from reaching the collective cause a deadlock on multi-device hosts.
- axis_name must match the mesh axis name exactly. Using 'batch' when the mesh axis is named 'data' raises a NameError.
- Use out_specs=PartitionSpec() for a replicated (scalar) output. After psum, every shard holds the same global value: it is replicated, not sharded, so the empty spec is correct.
- Forgetting psum returns a per-shard partial sum, not the global sum. On 1 device it happens to be correct; on multiple devices it is wrong.
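The axis-name pitfall can be reproduced on a single device. The following is a minimal sketch: the function name wrong_axis is hypothetical, and the code catches a broad Exception and only reports the type it sees, because the exact exception class may differ between JAX versions. The import fallback is an assumption about newer JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:  # assumption: newer JAX exposes it as jax.shard_map
    from jax import shard_map

# The mesh axis is named 'data'.
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def wrong_axis(y):  # hypothetical name, for illustration only
    # Bug on purpose: 'batch' is not an axis of the mesh above.
    return jax.lax.psum(jnp.sum(y), axis_name='batch')

f_bad = shard_map(wrong_axis, mesh=mesh,
                  in_specs=PartitionSpec('data'),
                  out_specs=PartitionSpec())

raised = None
try:
    f_bad(jnp.array([1.0, 2.0, 3.0]))
except Exception as err:  # broad on purpose: only the type name is recorded
    raised = type(err).__name__

print(raised)
```

Renaming the axis in the psum call to 'data' makes the same function return the expected sum.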
Problem
Implement shard_map_psum(x) that:
- Creates a 1-device Mesh with axis name 'data'.
- Defines a per-shard function that computes jnp.sum(y) locally and then calls jax.lax.psum(local, axis_name='data') to all-reduce.
- Wraps it with shard_map using in_specs=PartitionSpec('data') and out_specs=PartitionSpec() (replicated scalar output).
- Returns the result of calling the wrapped function on x.
Single-device caveat: the test runner has exactly 1 CPU device. On 1
device, psum is a no-op, so the result equals jnp.sum(x). On multiple
devices, psum adds together the partial sums from every shard.
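One way to check this caveat is to build the mesh from however many devices are present and compare against a plain jnp.sum. This is a sketch under assumptions: it departs from the exercise by using all available devices rather than exactly one, the input length is chosen as a multiple of the device count so it shards evenly, and the import fallback assumes newer JAX releases expose jax.shard_map.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax.experimental.shard_map import shard_map
except ImportError:  # assumption: newer JAX exposes it as jax.shard_map
    from jax import shard_map

devices = np.array(jax.devices())  # all devices, not just the first
n = devices.size
mesh = Mesh(devices, ('data',))

def per_shard(y):
    # Each shard sums its own slice; psum combines the partial sums.
    return jax.lax.psum(jnp.sum(y), axis_name='data')

f = shard_map(per_shard, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec())

# 4 elements per device so the array divides evenly across the mesh.
x = jnp.arange(4 * n, dtype=jnp.float32)
result = f(x)
expected = jnp.sum(x)
print(n, result, expected)
```

On a CPU-only machine, setting XLA_FLAGS=--xla_force_host_platform_device_count=8 in the environment before importing jax simulates several devices, which exercises the multi-shard path.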
- x: 1-D JAX array.
Returns: scalar, the global sum of all elements of x.
Examples (not from the test set):
- shard_map_psum(jnp.array([1.0, 2.0, 3.0])) → 6.0