NNX shard_map Simulation
Why this matters
jax.experimental.shard_map.shard_map (a.k.a. shmap) is JAX’s
manual SPMD primitive: each device runs the function body on
its local shard, with no automatic global awareness. You write
the per-device logic — including any cross-device collectives —
explicitly. JAX glues the per-device outputs back into a global
array based on the out_specs.
Compared to pmap, shard_map is the modern composable
replacement: it works inside jit, supports arbitrary
PartitionSpecs (not just one axis), and the function body sees
only its own shard. Compared to jit with in_shardings /
out_shardings, shard_map is the manual escape hatch: reach for it
when you want explicit control over data movement instead of
leaving the partitioning to the compiler.
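For contrast, here is a minimal sketch of the automatic path that the paragraph above compares against. It runs jax.jit over an input placed with a NamedSharding, and the compiler decides how to partition the matmul. The sketch assumes a CPU host. It uses the XLA flag --xla_force_host_platform_device_count to fake several devices. The sizes and the function name forward are illustrative.

```python
import os
# Must be set before JAX is imported: ask XLA for 4 virtual CPU devices.
# (This overwrites any XLA_FLAGS already set; fine for a sketch.)
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices exist (4 virtual CPUs if the flag took effect).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = jax.device_count()

x = jnp.arange(8 * n * 3, dtype=jnp.float32).reshape(8 * n, 3)
w = jnp.ones((3, 2), dtype=jnp.float32)

# Automatic path: place x with a batch sharding and let jit partition the rest.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def forward(x, w):
    # Written against GLOBAL shapes: no per-device logic, no collectives.
    return x @ w

y = forward(x_sharded, w)
print(y.shape, y.sharding)
```

The body never mentions shards. With shard_map, the body is written per shard instead.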
Three big use cases
- Data parallelism (this problem): split the batch across devices; each device runs its own forward pass. shard_map(step, mesh, in_specs=P("data", None), out_specs=P("data", None)).
- Tensor parallelism: split a Linear’s kernel along its in or out axis; shard_map runs the local matmul, with collectives bridging the shards (psum after an in-axis split, all_gather after an out-axis split).
- Pipeline parallelism: split the model depth-wise across devices; shard_map runs each device’s slice of the pipeline.
The MENTAL MODEL is what matters: each device’s body runs on
only its local shard. There’s no automatic global reasoning —
you have to write jax.lax.psum, jax.lax.all_gather,
jax.lax.ppermute yourself when devices need to coordinate.
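You can see psum at work without several devices. Let jax.vmap with an axis_name stand in for the device axis; collectives such as jax.lax.psum accept that name. This is a single-device simulation, not shard_map itself, and the names and sizes are illustrative. The sketch splits a Linear-style kernel along its input axis. Each simulated "device" computes a partial product, and psum adds the partials up.

```python
import jax
import jax.numpy as jnp

num_shards = 4
batch, d_in, d_out = 2, 8, 3
c = d_in // num_shards  # slice of the input dim owned by each shard

x = jax.random.normal(jax.random.key(0), (batch, d_in))
w = jax.random.normal(jax.random.key(1), (d_in, d_out))

# Shard k holds columns [k*c, (k+1)*c) of x and the matching rows of w.
x_shards = x.reshape(batch, num_shards, c).transpose(1, 0, 2)  # (S, B, c)
w_shards = w.reshape(num_shards, c, d_out)                     # (S, c, d_out)

def local_matmul(x_local, w_local):
    partial = x_local @ w_local                    # (B, d_out): this shard's partial sum
    return jax.lax.psum(partial, axis_name="tp")   # add the partials from all shards

# vmap over the leading axis plays the role of num_shards devices.
y_per_shard = jax.vmap(local_matmul, axis_name="tp")(x_shards, w_shards)  # (S, B, d_out)

print(jnp.allclose(y_per_shard[0], x @ w, atol=1e-4))
```

After the psum, every shard holds the full x @ w. That is the "bridging" the tensor-parallel bullet refers to.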
What the real shard_map looks like
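```python
import jax
import jax.numpy as jnp
from flax import nnx
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

num_shards = jax.device_count()        # one shard per available device
devices = mesh_utils.create_device_mesh((num_shards,))
mesh = Mesh(devices, axis_names=("data",))
in_dim, features = 3, 4                # illustrative sizes
model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(0))  # built once; params are replicated

def step(local_x):
    # Inside this body, local_x has shape (batch_per_device, in_dim).
    return model(local_x)  # nnx Module call on the local shard only

sharded_step = shard_map(step, mesh,
                         in_specs=P("data", None),
                         out_specs=P("data", None))
global_x = jnp.ones((8 * num_shards, in_dim))  # full batch, divisible by num_shards
out = sharded_step(global_x)  # out is (8 * num_shards, features)
```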
in_specs=P("data", None) says: “shard axis 0 along the data
mesh axis; leave axis 1 unsplit, so every device holds all of it.”
Each device’s step receives its own (local_batch, ...) chunk, where
local_batch = full_batch / num_shards.
out_specs=P("data", None) says: “concatenate the per-device
outputs along axis 0 to reconstruct the global output.”
Single-device simulation
On a single device we can’t run a multi-device shard_map, but we
can simulate the SHAPE of the computation with a plain Python loop:
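```python
import jax
import jax.numpy as jnp

def simulate_shards(model, x, num_shards):  # illustrative name
    chunk = x.shape[0] // num_shards  # rows per simulated device
    outs = []
    for i in range(num_shards):  # one iteration per simulated device
        x_shard = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, axis=0)
        outs.append(model(x_shard))  # same model, same params, different shard
    return jnp.concatenate(outs, axis=0)  # rebuild the (N, features) output
```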
The Python loop for i in range(num_shards) stands in for “this
body runs concurrently on num_shards devices in production.” On
our single device it runs sequentially, but it produces the same
array the parallel version would.
Why init the model ONCE, outside the loop
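In production, all shards use the same params: they’re replicated across the data-parallel axis. Building the model inside the loop would (a) redo the initialization num_shards times, and (b) if each iteration drew a fresh RNG (for example nnx.Rngs(seed + i)), give each shard different params, which defeats the whole point of data parallelism.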
Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed))
outside the loop and call model(shard) inside. Each call
sees the same params; each call sees a different shard.
Why concatenate(axis=0) rebuilds the global output
Each shard’s output is (chunk, features). Stacking them along
axis 0 gives (num_shards * chunk, features) = (N, features), the
global output. A Linear layer maps each row independently, so this
matches a single call on the whole x (up to floating-point rounding).
In shard_map, out_specs=P("data", None) is the declarative
version of this concatenate: JAX handles the data movement, and
the function body only ever works on local shapes.
Common pitfalls
- Re-initializing per shard: in production all shards use the same params. Don’t init inside the loop.
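- Slicing or concatenating along the wrong axis: shard along the batch axis (axis 0) and concatenate back along the same axis. Both dynamic_slice_in_dim and jnp.concatenate default to axis=0, so spelling it out is mainly for readability.
- Using Python slicing where dynamic_slice_in_dim is better: x[i*chunk:(i+1)*chunk] works under jit as long as i is a Python int (it is a static slice). dynamic_slice_in_dim makes the intent explicit and also accepts a traced start index, for example inside jax.lax.fori_loop. The slice size must still be static.
- Forgetting that shard_map expects the global input: the function body operates on local shards, but you call the sharded function with the full tensor. JAX does the splitting.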
Problem
Implement shard_map_simulate(seed, x, num_shards, features):
- Cast seed, num_shards, features to ints. Use x.shape[-1] as the input dim.
- Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)) outside the loop.
- Compute chunk = x.shape[0] // num_shards.
- Loop i in range(num_shards): take jax.lax.dynamic_slice_in_dim(x, i*chunk, chunk, axis=0) and call model(x_shard); append to a list.
- out = jnp.concatenate(outs, axis=0).
- Return out.reshape(-1) (flattened).
Inputs:
- seed: float (cast to int).
- x: 2-D (N, in_dim), where N is divisible by num_shards.
- num_shards: float (cast to int).
- features: float (cast to int).
Output: 1-D, length N * features.