hard primitives

shard_map with psum Collective

Why this matters

psum, pmean, and pmax are SPMD collectives: all-reduce operations that aggregate values across every shard of a named mesh axis. They are the building blocks of distributed gradient synchronisation in data-parallel training: each device accumulates a local gradient, then a single psum call produces the global gradient on every device simultaneously.

def per_shard(y):
    local_grad = jnp.sum(y)              # local partial result
    return jax.lax.psum(local_grad, axis_name='data')  # global all-reduce
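
As a rough illustration of how psum and pmean relate, the sketch below applies both to the same input on a 1-device mesh. The names sum_and_mean and reduce_both are hypothetical, and the try/except import is an assumption that newer JAX releases expose shard_map at the top level while older ones only provide the jax.experimental path used elsewhere in this section. pmax is left out to keep the sketch small.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # path used in this section

# 1-device mesh with a single axis named 'data'
mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def sum_and_mean(y):  # hypothetical helper name
    # psum adds the per-shard sums across the 'data' axis
    total = jax.lax.psum(jnp.sum(y), axis_name='data')
    # pmean averages the per-shard means across the 'data' axis
    mean = jax.lax.pmean(jnp.mean(y), axis_name='data')
    return total, mean

reduce_both = shard_map(sum_and_mean, mesh=mesh,
                        in_specs=PartitionSpec('data'),
                        out_specs=(PartitionSpec(), PartitionSpec()))

total, mean = reduce_both(jnp.array([1.0, 2.0, 3.0, 4.0]))
print(total, mean)
```

Averaging per-shard means only equals the global mean when every shard holds the same number of elements, which shard_map's even split along 'data' provides here.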

Worked mini-example

import jax, numpy as np, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ('data',))

def per_shard(y):
    return jax.lax.psum(jnp.sum(y), axis_name='data')

f = shard_map(per_shard, mesh=mesh,
              in_specs=PartitionSpec('data'),
              out_specs=PartitionSpec())
print(f(jnp.array([1.0, 2.0, 3.0])))  # 6.0

Common pitfalls

  • The collective is blocking. Every shard must reach the psum call before any can proceed. Conditionals that prevent some shards from reaching the collective cause a deadlock on multi-device hosts.
  • axis_name must match the mesh axis name exactly. Using 'batch' when the mesh axis is named 'data' raises a NameError.
  • Use out_specs=PartitionSpec() for a replicated (scalar) output. After psum, every shard holds the same global value; it is replicated, not sharded, so the empty spec is correct.
  • Without psum, each shard returns only its partial sum, not the global sum. On 1 device the result happens to be correct; on multiple devices it is wrong.
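
The axis-name pitfall above can be reproduced with a small sketch. The function name wrong_axis and the variable caught are hypothetical. The code catches a broad Exception and prints its type instead of assuming the exact error class, since the message wording may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # path used in this section

mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

def wrong_axis(y):  # hypothetical name
    # 'batch' is not an axis of this mesh; the only mesh axis is 'data'
    return jax.lax.psum(jnp.sum(y), axis_name='batch')

f_wrong = shard_map(wrong_axis, mesh=mesh,
                    in_specs=PartitionSpec('data'),
                    out_specs=PartitionSpec())

caught = None
try:
    f_wrong(jnp.array([1.0, 2.0, 3.0]))
except Exception as err:  # the text above reports a NameError for this case
    caught = err

print(type(caught).__name__, caught)
```

Renaming the axis in the psum call to 'data' makes the same function run without error.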

Problem

Implement shard_map_psum(x) that:

  1. Creates a 1-device Mesh with axis name 'data'.
  2. Defines a per-shard function that computes jnp.sum(y) locally and then calls jax.lax.psum(local, axis_name='data') to all-reduce.
  3. Wraps it with shard_map using in_specs=PartitionSpec('data') and out_specs=PartitionSpec() (replicated scalar output).
  4. Returns the result of calling the wrapped function on x.
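
One possible shape for the four steps above, offered as a minimal sketch and not as the reference solution. It mirrors the worked mini-example; the fallback import is an assumption about which shard_map path the installed JAX version provides.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # path used in this section

def shard_map_psum(x):
    # Step 1: 1-device mesh with axis name 'data'
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))

    # Step 2: local sum on each shard, then all-reduce across 'data'
    def per_shard(y):
        local = jnp.sum(y)
        return jax.lax.psum(local, axis_name='data')

    # Step 3: shard the input along 'data', replicate the scalar output
    wrapped = shard_map(per_shard, mesh=mesh,
                        in_specs=PartitionSpec('data'),
                        out_specs=PartitionSpec())

    # Step 4: call the wrapped function on x
    return wrapped(x)

result = shard_map_psum(jnp.array([1.0, 2.0, 3.0]))
print(result)
```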

Single-device caveat: the test runner has exactly 1 CPU device. On 1 device, psum is a no-op, so the result equals jnp.sum(x). On multiple devices, psum adds up the partial sums from every shard.
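
To see the multi-device behaviour without extra hardware, the sketch below assumes the XLA flag --xla_force_host_platform_device_count is used to expose several virtual CPU devices. The flag has to be in the environment before JAX initialises its backend, and the code adapts to however many devices actually appear. The names partial_only and with_psum are hypothetical.

```python
import os

# Assumption: request 4 virtual CPU devices. setdefault leaves an
# existing XLA_FLAGS value untouched, so fewer devices may appear.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

try:
    from jax import shard_map  # assumption: top-level export in newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # path used in this section

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ('data',))

# Two elements per shard, so the length always divides evenly by n
x = jnp.arange(2 * n, dtype=jnp.float32)

def partial_only(y):  # hypothetical: no collective, keeps a length-1 axis
    return jnp.sum(y, keepdims=True)

def with_psum(y):  # hypothetical: local sum followed by all-reduce
    return jax.lax.psum(jnp.sum(y), axis_name='data')

# Without psum the output stays sharded: one partial sum per shard
partials = jax.jit(shard_map(partial_only, mesh=mesh,
                             in_specs=PartitionSpec('data'),
                             out_specs=PartitionSpec('data')))(x)

# With psum the output is a single replicated scalar
total = jax.jit(shard_map(with_psum, mesh=mesh,
                          in_specs=PartitionSpec('data'),
                          out_specs=PartitionSpec()))(x)

print(n, partials, total)
```

With more than one device, partials holds one entry per shard and none of them equals the global sum on its own; total matches jnp.sum(x) regardless of the device count.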

  • x: 1-D JAX array.

Returns: scalar, the global sum of all elements of x.

Examples (not from the test set):

  • shard_map_psum(jnp.array([1.0, 2.0, 3.0])) → 6.0

Hints

jax shard-map psum collectives
