hard primitives

shard_map Simulation — Manual SPMD

Why this matters

jax.experimental.shard_map.shard_map (a.k.a. shmap) is JAX’s manual SPMD primitive; recent JAX releases also export it as jax.shard_map. It says: “I have an array sharded across N devices; I want to run this function on EACH SHARD SEPARATELY, with the local shard’s data, and then have JAX glue the per-device outputs back into a global array.”

Compare to pmap: pmap is the older “data-parallel” wrapper that maps over a single leading axis, one slice per device. shard_map is the modern, composable replacement: it works inside jit, supports multi-axis meshes and arbitrary PartitionSpecs (not just one axis), and the function body sees only its own shard, not the whole array.

Real production uses:

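  • Data parallelism (the case in this problem): split the batch across devices; each runs its own forward/backward pass, and gradients are averaged across devices (e.g. with jax.lax.pmean) before the update. shard_map(step_fn, mesh, in_specs=P("data", None), out_specs=P("data", None)).
  • Tensor parallelism: split a Dense’s kernel along its input or output axis; shard_map runs the local-shape matmul on each device. An input-axis split produces partial sums that jax.lax.psum combines; an output-axis split produces column blocks that stay sharded or are reassembled with jax.lax.all_gather.
  • Pipeline parallelism: split the model depth-wise; each device runs its stage of the pipeline and passes activations to the next device (e.g. with jax.lax.ppermute).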
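The tensor-parallel case can be sketched with a kernel split along its input axis. This is a minimal illustration, not production code: the mesh axis name "model" and all shapes are our own choices, and on a single CPU device the mesh has size 1, so the psum is trivial but the code path is the same. Each device multiplies its slice of the contraction axis and jax.lax.psum adds the partial products.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it in jax.experimental
    from jax.experimental.shard_map import shard_map

# One mesh axis named "model" spanning every visible device.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
n_dev = len(jax.devices())

in_dim, out_dim, batch = 4 * n_dev, 3, 5
x = jax.random.normal(jax.random.PRNGKey(0), (batch, in_dim))
w = jax.random.normal(jax.random.PRNGKey(1), (in_dim, out_dim))

def row_parallel_matmul(local_x, local_w):
    # local_x: (batch, in_dim / n_dev), local_w: (in_dim / n_dev, out_dim).
    # Each device contracts only its slice of in_dim, giving a partial
    # (batch, out_dim) product ...
    partial = local_x @ local_w
    # ... and psum adds the partials so every device holds the full product.
    return jax.lax.psum(partial, "model")

tp_matmul = shard_map(
    row_parallel_matmul, mesh=mesh,
    in_specs=(P(None, "model"), P("model", None)),  # split the shared in_dim axis
    out_specs=P(),                                  # psum result is replicated
)
y = tp_matmul(x, w)
print(y.shape, jnp.allclose(y, x @ w, atol=1e-4))
```

Splitting the kernel along its output axis instead would need no psum: with x replicated, each device computes its own column block of the result outright.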

The MENTAL MODEL is what matters: each device’s body runs on only its local shard. There’s no automatic global-aware behavior — you have to write collectives (jax.lax.psum, jax.lax.all_gather, jax.lax.ppermute) yourself when devices need to coordinate.
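A small sketch of that mental model, assuming a 1-D mesh named "data" over whatever devices JAX can see (the shapes and the helper name body are illustrative). The body reports how many rows it actually sees, and a per-shard sum only becomes the global sum after an explicit jax.lax.psum.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it in jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = len(jax.devices())
x = jnp.arange(8 * n_dev * 3, dtype=jnp.float32).reshape(8 * n_dev, 3)

def body(local_x):
    # The body only sees its own rows: local_x.shape[0] is 8, not 8 * n_dev.
    rows_here = jnp.full((1,), local_x.shape[0])
    # A per-shard sum is a partial answer; psum combines the partials
    # across the "data" axis so every device gets the global total.
    total = jax.lax.psum(jnp.sum(local_x), "data")
    return rows_here, total

rows_per_device, total = shard_map(
    body, mesh=mesh,
    in_specs=P("data", None),
    out_specs=(P("data"), P()),  # one row count per device; total replicated
)(x)
print(rows_per_device, total)
```

To see more than one shard on a CPU-only machine, you can set XLA_FLAGS=--xla_force_host_platform_device_count=4 in the environment before JAX is imported; the code itself does not change.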

What the real shard_map looks like

from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map   # newer JAX also exports jax.shard_map
from jax.sharding import Mesh, PartitionSpec as P

# model, params, num_shards and global_x are assumed to be defined elsewhere.
devices = mesh_utils.create_device_mesh((num_shards,))
mesh = Mesh(devices, axis_names=("data",))

def step(local_x):
    # Inside this body, local_x has shape (batch_per_device, ...).
    return model.apply(params, local_x)

sharded_step = shard_map(step, mesh,
                         in_specs=P("data", None),
                         out_specs=P("data", None))
out = sharded_step(global_x)   # global_x is (full_batch, ...)

in_specs=P("data", None) says: “split axis 0 across the data mesh axis; leave axis 1 unsplit, so every device holds all of it.” Each device’s step receives its own (full_batch // num_shards, ...) chunk.

out_specs=P("data", None) says: “concatenate the per-device outputs along axis 0 to reconstruct the global output.”

What we simulate here

On a single-device test machine, we can’t run a real shard_map that spans multiple physical accelerators. But we can simulate the SHAPE of the computation in pure Python:

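import jax
import jax.numpy as jnp

def run_shards(model, params, x, num_shards):   # illustrative helper name
    chunk_size = x.shape[0] // num_shards
    outs = []
    for i in range(num_shards):
        shard = jax.lax.dynamic_slice_in_dim(x, i * chunk_size, chunk_size, axis=0)
        outs.append(model.apply(params, shard))   # same params for every shard
    return jnp.concatenate(outs, axis=0)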

The Python loop for i in range(num_shards) stands in for num_shards devices: in production, each iteration is one device’s body, and they all run concurrently. On our single device the iterations run sequentially, but the concatenated result is the same array (up to floating-point rounding).

Why init on a single shard

Flax creates a module’s parameters at init time, sizing them from the input it is given, so nn.Dense(features).init(rng, x) should see an input shaped like the one apply will receive. Each shard’s apply receives a (chunk_size, in_dim) array, so we init on a representative slice with that shape:

chunk_size = x.shape[0] // num_shards
first_shard = jax.lax.dynamic_slice_in_dim(x, 0, chunk_size, axis=0)
params = model.init(rng, first_shard)

For Dense, the kernel has shape (in_dim, features): it depends only on the LAST input dimension, not on chunk_size. Layers whose parameter shapes read the leading axis do care; for example, nn.BatchNorm(axis=0) would size its scale and bias by that axis. Initializing on a correctly shaped slice covers those cases too, so we follow the safer pattern.

jax.lax.dynamic_slice_in_dim(x, start, size, axis) takes a contiguous slice whose start index may be a traced value (the size must be static). Because i is a plain Python int in this loop, x[i*chunk_size:(i+1)*chunk_size] works equally well, even under jit.
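The difference between the two slicing styles only shows up when the start index is traced. In this sketch (toy shapes, hypothetical helper take_chunk), both agree for a Python-int index; under jit with i as an array argument, only dynamic_slice_in_dim works, since basic slicing needs static bounds.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(6, 2)
chunk_size = 2

# With a Python int start, dynamic_slice_in_dim and basic slicing agree.
i = 1
a = jax.lax.dynamic_slice_in_dim(x, i * chunk_size, chunk_size, axis=0)
b = x[i * chunk_size:(i + 1) * chunk_size]

# With a traced start (a jit argument), basic slicing would fail; the
# slice size must still be a static Python int.
@jax.jit
def take_chunk(x, i):
    return jax.lax.dynamic_slice_in_dim(x, i * chunk_size, chunk_size, axis=0)

c = take_chunk(x, jnp.int32(2))  # rows 4 and 5
print(a, b, c)
```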

Why concatenate(axis=0) rebuilds the global output

Each shard’s output is (chunk_size, features). Stacking them along axis 0 gives (num_shards * chunk_size, features) = (N, features) — the global output, identical to what a single-device apply would produce.

The whole point of shard_map is that out_specs=P("data", None) is the declarative version of this concatenate. JAX figures out the data movement; the function body works on local shapes.

Common pitfalls

  • Forgetting axis=0: shard along the batch axis (axis 0). Concat back along the same axis.
  • Re-initializing per shard: in production all shards use the same params (they’re replicated). Don’t init inside the loop. Init once (outside the loop), apply the same params to each shard.
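  • Splitting unevenly: if x.shape[0] isn’t divisible by num_shards, x.shape[0] // num_shards rounds down and the loop silently drops the trailing rows, while real shard_map rejects such inputs outright. You’d need padding. The test cases ensure clean division.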
  • Real shard_map needs a mesh: the simulation skips this. Production code must construct a Mesh and pass it explicitly.
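The padding pitfall above can be handled with a sketch like the following. The helper pad_to_multiple is hypothetical, and x * 2.0 stands in for a per-shard computation. Zero-padding is only safe when each row is processed independently (as in Dense); anything that reduces over the batch, such as a mean or BatchNorm statistics, would see the padded rows.

```python
import jax.numpy as jnp

def pad_to_multiple(x, num_shards):
    """Zero-pad axis 0 so its length divides evenly by num_shards.

    Returns the padded array and the original row count, so the caller
    can slice the padding back off after the per-shard work.
    """
    n = x.shape[0]
    pad = (-n) % num_shards  # rows needed to reach the next multiple
    widths = [(0, pad)] + [(0, 0)] * (x.ndim - 1)
    return jnp.pad(x, widths), n

x = jnp.arange(10.0).reshape(5, 2)  # 5 rows, 2 shards: not divisible
num_shards = 2
padded, n = pad_to_multiple(x, num_shards)  # padded has 6 rows
chunk_size = padded.shape[0] // num_shards  # 3

# Stand-in per-shard computation: a row-wise op, so padding cannot leak.
outs = [padded[i * chunk_size:(i + 1) * chunk_size] * 2.0
        for i in range(num_shards)]
result = jnp.concatenate(outs, axis=0)[:n]  # drop the padded rows
print(padded.shape, result.shape)
```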

Problem

Implement shard_map_simulate(seed, x, num_shards, features):

  1. Cast seed, num_shards, and features to Python int.
  2. Build model = nn.Dense(features).
  3. Compute chunk_size = x.shape[0] // num_shards.
  4. Init the model on the first shard (a (chunk_size, in_dim) slice).
  5. Loop num_shards times: slice the i-th chunk, run model.apply(params, shard), collect the output.
  6. Concatenate the outputs with jnp.concatenate(outs, axis=0), then flatten the result.

Inputs:

  • seed: float (cast to int) — PRNG seed.
  • x: 2-D (N, in_dim) where N is divisible by num_shards.
  • num_shards: float (cast to int) — number of simulated shards.
  • features: float (cast to int) — Dense output dim.

Output: 1-D, length N * features.

