
NNX shard_map Simulation

Why this matters

jax.experimental.shard_map.shard_map (a.k.a. shmap) is JAX’s manual SPMD primitive: each device runs the function body on its local shard, with no automatic global awareness. You write the per-device logic — including any cross-device collectives — explicitly. JAX glues the per-device outputs back into a global array based on the out_specs.

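Compared to pmap, shard_map is the modern composable replacement. It works inside jit and supports arbitrary PartitionSpecs over multi-axis meshes (not just one mapped axis). Its body also receives the local block with the sharded axis shrunk, rather than removed as in pmap. Compared to jit with in_shardings / out_shardings, where the compiler decides how to partition the work and where to insert communication, shard_map is the manual escape hatch. Use it when you want explicit control over per-device computation and data movement.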

Three big use cases

  • Data parallelism (this problem): split the batch across devices; each device runs its own forward pass. shard_map(step, mesh, in_specs=P("data", None), out_specs=P("data", None)).
  • Tensor parallelism: split a Linear’s kernel along its in or out axis. shard_map runs the local matmul, and collectives bridge the shards: psum adds the partial products when the in axis is split, and all_gather reassembles the features when the out axis is split.
  • Pipeline parallelism: split the model depth-wise across devices; shard_map runs each device’s slice of the pipeline.
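The tensor-parallel case can be simulated on one device in the same way this problem simulates data parallelism. The sketch below rests on that assumption and is not a real multi-device implementation. row_parallel_linear and column_parallel_linear are hypothetical names, the Python loops stand in for devices, and sum / concatenate stand in for jax.lax.psum / jax.lax.all_gather.

```python
import jax
import jax.numpy as jnp


def row_parallel_linear(x, w, num_shards):
    # Hypothetical helper. Split the contraction (in) axis: "device" i holds
    # column block i of x and row block i of w, and computes a partial product.
    k = w.shape[0] // num_shards
    partials = [x[:, i * k:(i + 1) * k] @ w[i * k:(i + 1) * k, :] for i in range(num_shards)]
    # Under a real shard_map this sum would be jax.lax.psum over the model axis.
    return sum(partials)


def column_parallel_linear(x, w, num_shards):
    # Hypothetical helper. Split the output axis: "device" i holds column
    # block i of w and computes its own slice of the output features.
    k = w.shape[1] // num_shards
    parts = [x @ w[:, i * k:(i + 1) * k] for i in range(num_shards)]
    # Under a real shard_map this would be jax.lax.all_gather, or the result
    # would simply stay sharded via out_specs.
    return jnp.concatenate(parts, axis=1)


x = jax.random.normal(jax.random.key(0), (8, 6))
w = jax.random.normal(jax.random.key(1), (6, 4))
full = x @ w
row = row_parallel_linear(x, w, num_shards=3)
col = column_parallel_linear(x, w, num_shards=2)
```

Both variants reproduce x @ w. What differs is which collective a real shard_map body would need in order to finish the job.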

The MENTAL MODEL is what matters: each device’s body runs on only its local shard. There’s no automatic global reasoning — you have to write jax.lax.psum, jax.lax.all_gather, jax.lax.ppermute yourself when devices need to coordinate.
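Here is a minimal sketch of an explicit collective. Each device sums its own rows, then jax.lax.psum adds those partial sums across the "data" axis. The sketch runs on however many devices are visible. On a plain CPU host that is one device, where the psum is trivially the identity. global_column_sum is a hypothetical name, and the import fallback assumes that newer JAX releases expose shard_map as jax.shard_map.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                         # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = len(jax.devices())


def global_column_sum(local_x):
    # Hypothetical helper. local_x is this device's (rows_per_device, d) block.
    partial = jnp.sum(local_x, axis=0, keepdims=True)   # (1, d), local only
    # Without this psum each device would only know its own partial sum.
    return jax.lax.psum(partial, axis_name="data")      # (1, d), identical everywhere


summed = shard_map(global_column_sum, mesh=mesh,
                   in_specs=P("data", None),
                   out_specs=P())                        # output is replicated
x = jnp.arange(4 * n_dev * 3, dtype=jnp.float32).reshape(4 * n_dev, 3)
total = summed(x)
```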

What the real shard_map looks like

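import jax
import jax.numpy as jnp
from flax import nnx
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

num_shards = jax.device_count()                 # one mesh slot per available device
devices = mesh_utils.create_device_mesh((num_shards,))
mesh = Mesh(devices, axis_names=("data",))

model = nnx.Linear(4, 3, rngs=nnx.Rngs(0))      # params replicated on every device
global_x = jnp.ones((8 * num_shards, 4))       # (full_batch, in_dim)

def step(local_x):
    # Inside this body, local_x has shape (batch_per_device, in_dim).
    return model(local_x)        # nnx Module call; params are closed over

sharded_step = shard_map(step, mesh,
                         in_specs=P("data", None),
                         out_specs=P("data", None))
out = sharded_step(global_x)   # global_x is (full_batch, in_dim); out is (full_batch, 3)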

in_specs=P("data", None) says: “shard axis 0 along the data mesh axis; leave axis 1 unsplit.” Each device’s step receives its own (local_batch, ...) chunk, where local_batch = full_batch / num_shards.

out_specs=P("data", None) says: “concatenate the per-device outputs along axis 0 to reconstruct the global output.”
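You can check both specs directly by having the body report the shape it actually receives. The sketch below is illustrative only. report_local_shape is a hypothetical name, the sizes are arbitrary, and the import fallback assumes that newer JAX releases expose jax.shard_map.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                         # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = len(jax.devices())
x = jnp.ones((4 * n_dev, 5))   # global (N, 5) with N = 4 * n_dev


def report_local_shape(local_x):
    # Hypothetical helper. Shapes are static, so rows/cols are Python ints here.
    rows, cols = local_x.shape
    # One output row per local input row, each holding (rows, cols).
    return local_x[:, :2] * 0 + jnp.array([rows, cols], dtype=local_x.dtype)


probe = shard_map(report_local_shape, mesh=mesh,
                  in_specs=P("data", None),
                  out_specs=P("data", None))(x)
```

Every device reports a (4, 5) block: axis 0 is split and axis 1 arrives whole. The per-device (4, 2) results are stitched back into one (4 * n_dev, 2) array along axis 0.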

Single-device simulation

On a single device we can’t run a real shard_map over num_shards devices, because a one-device mesh holds only one shard. We can still simulate the SHAPE of the computation with a plain Python loop:

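import jax
import jax.numpy as jnp

def simulate(model, x, num_shards):
    chunk = x.shape[0] // num_shards
    outs = []
    for i in range(num_shards):
        # One iteration = one device's shard_map body, run sequentially here.
        x_shard = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, axis=0)
        outs.append(model(x_shard))
    return jnp.concatenate(outs, axis=0)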

The Python for i in range(num_shards) corresponds to “this loop runs concurrently across num_shards devices in production.” On our single device it’s a sequential loop that produces the same array.

Why init the model ONCE, outside the loop

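In production, all shards use the same params, replicated across the data-parallel axis. Building the model inside the loop would (a) waste compute by re-initializing the same layer num_shards times and (b) risk giving each shard different params. If the loop reuses one nnx.Rngs object, every construction draws a fresh key, which defeats the whole point.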

Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)) outside the loop and call model(shard) inside. Each call sees the same params; each call sees a different shard.

Why concatenate(axis=0) rebuilds the global output

Each shard’s output is (chunk, features). Stacking them along axis 0 gives (num_shards * chunk, features) = (N, features) — the global output, identical to what a single-device call on the whole x would produce.
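That equality depends on nnx.Linear acting on each row independently. An op that mixes rows, such as subtracting the batch mean, gives a different answer per shard unless the devices coordinate. The sketch below shows both cases. rowwise, center and sharded_apply are hypothetical names.

```python
import jax.numpy as jnp


def rowwise(batch):
    # Acts on each row independently, like nnx.Linear.
    return 2.0 * batch + 1.0


def center(batch):
    # Subtracts the batch mean: every output row depends on every input row.
    return batch - batch.mean(axis=0, keepdims=True)


def sharded_apply(fn, x, num_shards):
    # Hypothetical helper: apply fn per shard, then stitch along axis 0.
    chunk = x.shape[0] // num_shards
    return jnp.concatenate(
        [fn(x[i * chunk:(i + 1) * chunk]) for i in range(num_shards)], axis=0)


x = jnp.arange(8, dtype=jnp.float32).reshape(8, 1)
rowwise_matches = bool(jnp.allclose(sharded_apply(rowwise, x, 4), rowwise(x)))   # True
center_matches = bool(jnp.allclose(sharded_apply(center, x, 4), center(x)))      # False
```

A real data-parallel center would need the global mean, for example a jax.lax.psum of the local sums over the "data" axis divided by the global row count.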

In shard_map, out_specs=P("data", None) is the declarative version of this concatenate. JAX assembles the global array from the per-device blocks, and the function body only ever works on local shapes.

Common pitfalls

  • Re-initializing per shard: in production all shards use the same params. Don’t init inside the loop.
  • Forgetting axis=0: shard along the batch axis (axis 0). Concat back along the same axis.
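  • Slicing with a traced start index: with Python-int bounds, x[i*chunk:(i+1)*chunk] is a static slice and works fine, even under jit. jax.lax.dynamic_slice_in_dim is required when the start index is a tracer (for example inside jax.lax.fori_loop), and it also makes the intent explicit. The slice size must still be static.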
  • Forgetting that shard_map expects the global input: the function body operates on local shards, but you call the sharded function with the full tensor. JAX does the splitting.
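The traced-index case comes up as soon as the shard loop moves into jax.lax.fori_loop. There, i is a tracer, so basic slicing with i*chunk would fail. The sketch below shows the dynamic_slice_in_dim version. per_shard_sums is a hypothetical name that returns one sum per shard.

```python
import jax
import jax.numpy as jnp


def per_shard_sums(x, num_shards):
    # Hypothetical helper. chunk is a static Python int; only the start is traced.
    chunk = x.shape[0] // num_shards

    def body(i, acc):
        # x[i * chunk:(i + 1) * chunk] would fail here because i is traced.
        shard = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, axis=0)
        return acc.at[i].set(shard.sum())

    return jax.lax.fori_loop(0, num_shards, body, jnp.zeros((num_shards,), x.dtype))


sums = per_shard_sums(jnp.arange(12).reshape(6, 2), 3)
```

With three shards of two rows each, the shards hold 0..3, 4..7 and 8..11, so the sums are 6, 22 and 38.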

Problem

Implement shard_map_simulate(seed, x, num_shards, features):

  1. Cast seed, num_shards, features to ints. Use x.shape[-1] as the input dim.
  2. Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)) outside the loop.
  3. Compute chunk = x.shape[0] // num_shards.
  4. Loop i in range(num_shards): take jax.lax.dynamic_slice_in_dim(x, i*chunk, chunk, axis=0) and call model(x_shard); append to a list.
  5. out = jnp.concatenate(outs, axis=0).
  6. Return out.reshape(-1), the output flattened to 1-D.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim) where N is divisible by num_shards.
  • num_shards: float (cast to int).
  • features: float (cast to int).

Output: 1-D, length N * features.

