shard_map Simulation — Manual SPMD
Why this matters
jax.experimental.shard_map.shard_map (a.k.a. shmap) is JAX’s
manual SPMD primitive. It says: “I have an array sharded
across N devices; I want to run this function on EACH SHARD
SEPARATELY, with the local shard’s data, and then have JAX glue
the per-device outputs back into a global array.”
Compare to pmap: pmap is the older “data-parallel” wrapper
that handles a single batch axis. shard_map is the modern,
composable replacement — it works inside jit, supports arbitrary
sharding specs (not just one axis), and the function body sees
only its own shard, not the whole array.
Real production uses:
- Data parallelism (the case in this problem): split the batch across devices, each runs its own forward/backward. shard_map(step_fn, mesh, in_specs=P("data", None), out_specs=P("data", None)).
- Tensor parallelism: split a Dense’s kernel along its input or output axis; shard_map runs the local-shape matmul on each device, with psum-style collectives bridging.
- Pipeline parallelism: split the model depth-wise; each device runs its slice of the pipeline.
The MENTAL MODEL is what matters: each device’s body runs on
only its local shard. There’s no automatic global-aware
behavior — you have to write collectives (jax.lax.psum,
jax.lax.all_gather, jax.lax.ppermute) yourself when devices
need to coordinate.
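To make the "no automatic global-aware behavior" point concrete, here is a minimal single-device sketch. It does not call the real jax.lax.psum; `simulated_psum` is a hypothetical helper that only mimics what a sum-collective hands back to every shard.

```python
import jax.numpy as jnp

def simulated_psum(local_values):
    """Hypothetical helper: mimic a sum-collective by adding up the
    per-shard values and returning the same total to every shard."""
    total = sum(local_values)
    return [total for _ in local_values]

x = jnp.arange(8.0)                 # global array with 8 elements
num_shards = 4
shards = jnp.split(x, num_shards)   # what each simulated device would see

# Purely local work: each shard only knows the sum of its own 2 elements.
local_sums = [s.sum() for s in shards]

# Explicit coordination step: every shard now holds the global sum.
global_sums = simulated_psum(local_sums)
```

Without the explicit collective step, no shard ever learns the global total; each one only has its own local sum.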
What the real shard_map looks like
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D device mesh with a single named axis, "data".
devices = mesh_utils.create_device_mesh((num_shards,))
mesh = Mesh(devices, axis_names=("data",))

def step(local_x):
    # Inside this body, local_x has shape (batch_per_device, ...).
    return model.apply(params, local_x)

sharded_step = shard_map(step, mesh,
                         in_specs=P("data", None),
                         out_specs=P("data", None))
out = sharded_step(global_x)  # global_x is (full_batch, ...)
in_specs=P("data", None) says: “shard axis 0 along the data
mesh axis; replicate axis 1.” Each device’s step receives its
own (local_batch, ...) chunk.
out_specs=P("data", None) says: “concatenate the per-device
outputs along axis 0 to reconstruct the global output.”
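The shape arithmetic implied by a spec such as P("data", None) can be sketched in plain Python. `local_shape` and the `mesh_sizes` dictionary are hypothetical names used only for illustration; they are not part of the JAX API, and the sketch simply rejects uneven splits.

```python
def local_shape(global_shape, spec, mesh_sizes):
    """Hypothetical helper: shape of one shard under a PartitionSpec-like tuple.

    spec[i] is a mesh axis name (split that array axis across the mesh axis)
    or None (replicate that array axis, so its size is unchanged).
    """
    shape = []
    for dim, axis_name in zip(global_shape, spec):
        if axis_name is None:
            shape.append(dim)
        else:
            n = mesh_sizes[axis_name]
            if dim % n != 0:
                raise ValueError(f"dim {dim} is not divisible by mesh axis size {n}")
            shape.append(dim // n)
    return tuple(shape)

# A (8, 16) global batch on a 4-device "data" axis: each device sees (2, 16).
print(local_shape((8, 16), ("data", None), {"data": 4}))
```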
What we simulate here
On a single-device test machine, we can’t run a real shard_map
that spans multiple physical accelerators. But we can simulate
the SHAPE of the computation in pure Python:
chunk_size = x.shape[0] // num_shards
outs = []
for i in range(num_shards):
    shard = jax.lax.dynamic_slice_in_dim(x, i * chunk_size, chunk_size, axis=0)
    outs.append(model.apply(params, shard))
return jnp.concatenate(outs, axis=0)
The Python for i in range(num_shards) corresponds to “this loop
runs concurrently across num_shards devices in production.” On
our single device it’s a sequential loop that produces the same
array.
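A quick way to check the "same array" claim is to compare the sharded loop against one full-batch call. The sketch below assumes a plain affine function, `dense`, as a stand-in for model.apply so that it needs only JAX; `simulate` is an illustrative name, not something from the problem statement.

```python
import jax
import jax.numpy as jnp

def dense(params, x):
    # Stand-in for model.apply: an affine map applied independently to each row.
    return x @ params["kernel"] + params["bias"]

def simulate(params, x, num_shards):
    # Sequential loop over shards; in production each iteration is one device.
    chunk_size = x.shape[0] // num_shards
    outs = []
    for i in range(num_shards):
        shard = jax.lax.dynamic_slice_in_dim(x, i * chunk_size, chunk_size, axis=0)
        outs.append(dense(params, shard))
    return jnp.concatenate(outs, axis=0)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 5))
params = {"kernel": jax.random.normal(key_w, (5, 3)), "bias": jnp.zeros((3,))}

sharded = simulate(params, x, num_shards=4)   # shape (8, 3)
full = dense(params, x)                       # shape (8, 3)
```

The two results agree (up to floating-point tolerance) because the function treats each row independently, so splitting the batch does not change any row's output.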
Why init on a single shard
To know which parameter shapes nn.Dense(features).init(rng, x) should create, Flax
needs to see the shape of the input the layer will receive at
apply. Since each shard’s apply receives a (chunk_size, in_dim)
array, we init using a representative slice with that shape:
chunk_size = x.shape[0] // num_shards
first_shard = jax.lax.dynamic_slice_in_dim(x, 0, chunk_size, axis=0)
params = model.init(rng, first_shard)
For Dense, kernel’s shape depends only on the LAST dim of the input; it
doesn’t actually depend on chunk_size. But for layers that DO
care (BatchNorm with a fixed batch axis would), initing on a
correctly-shaped slice matters. We follow the safer pattern.
dynamic_slice_in_dim(x, start, size, axis) is the JAX-friendly
way to take a contiguous slice; alternatively
x[i*chunk_size:(i+1)*chunk_size] works on a concrete array.
Why concatenate(axis=0) rebuilds the global output
Each shard’s output is (chunk_size, features). Concatenating them
along axis 0 gives (num_shards * chunk_size, features) = (N, features), the global output, identical to what a
single-device apply would produce.
The whole point of shard_map is that out_specs=P("data", None)
is the declarative version of this concatenate. JAX figures out
the data movement; the function body works on local shapes.
Common pitfalls
- Forgetting axis=0: shard along the batch axis (axis 0). Concat back along the same axis.
- Re-initializing per shard: in production all shards use the same params (they’re replicated). Don’t init inside the loop. Init once (outside the loop), apply the same params to each shard.
- Splitting unevenly: if x.shape[0] isn’t divisible by num_shards, you’d need padding. The test cases ensure clean division.
- Real shard_map needs a mesh: the simulation skips this. Production code must construct a Mesh and pass it explicitly.
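For the uneven-split pitfall, one possible approach is to zero-pad the batch axis up to the next multiple of num_shards and trim the extra rows afterwards. `pad_to_multiple` is a hypothetical helper, it assumes a 2-D input, and the trick is only safe when the per-shard function treats rows independently (as Dense does).

```python
import jax.numpy as jnp

def pad_to_multiple(x, num_shards):
    """Hypothetical helper: zero-pad axis 0 of a 2-D array so that its
    length divides evenly by num_shards. Returns (padded, original_length)."""
    n = x.shape[0]
    pad = (-n) % num_shards              # rows needed to reach the next multiple
    padded = jnp.pad(x, ((0, pad), (0, 0)))
    return padded, n

x = jnp.ones((10, 4))                    # 10 rows do not split evenly into 4 shards
padded, n_valid = pad_to_multiple(x, num_shards=4)

# Run the per-shard loop on `padded`, concatenate, then drop the padded rows:
trimmed = padded[:n_valid]
```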
Problem
Implement shard_map_simulate(seed, x, num_shards, features):
- Cast num_shards and features to Python int.
- Build model = nn.Dense(features).
- Compute chunk_size = x.shape[0] // num_shards.
- Init the model on the first shard (a (chunk_size, in_dim) slice).
- Loop num_shards times: slice the i-th chunk, run model.apply(params, shard), collect the output.
- Concatenate with jnp.concatenate(outs, axis=0), then flatten the result.
Inputs:
- seed: float (cast to int), the PRNG seed.
- x: 2-D (N, in_dim) where N is divisible by num_shards.
- num_shards: float (cast to int), the number of simulated shards.
- features: float (cast to int), the Dense output dim.
Output: 1-D, length N * features.