NNX Data-Parallel Step
Why this matters
Data parallelism is the bread-and-butter of multi-device training: every device holds a full copy of the model but only a shard of the batch. Each device runs its forward and backward passes independently; an all-reduce across devices then averages the gradients before the optimizer step.
Result: near-linear throughput scaling. With 8 devices and the same per-device batch, the global batch is 8× larger, so each step processes 8× as many examples in roughly the same wall-clock time, minus the cost of the gradient all-reduce. This is the cheapest, simplest scaling lever, and it’s what you reach for first when your model already fits on one device but you want more throughput.
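To make "only a shard of the batch" concrete, the sketch below places one global batch on a 1-D device mesh with NamedSharding and inspects what each device actually holds. It assumes a recent JAX where arrays expose addressable_shards; the per-device batch of 4 rows and the feature width of 3 are illustrative choices, not values from this page.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over every visible device, with a single named axis "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
num_devices = mesh.devices.size

# Global batch: 4 rows per device, 3 features per row (illustrative sizes).
per_device = 4
x = jnp.arange(per_device * num_devices * 3, dtype=jnp.float32).reshape(-1, 3)

# Split axis 0 across the "data" axis; axis 1 (features) stays whole.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# Each device holds a (per_device, 3) slice of the (per_device * num_devices, 3) batch.
shard_shapes = [shard.data.shape for shard in x_sharded.addressable_shards]
print(x_sharded.shape, shard_shapes)
```

On a single CPU the mesh has one device and its "shard" is the whole batch; the same code yields one slice per device on a multi-device host.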
The production recipe (multi-device)
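```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # newer JAX also exposes jax.shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

def make_sharded_step(graphdef):
    def train_step(state, x, y):
        # Runs once per device; x, y are this device's slice under P("data", None).
        model, optimizer = nnx.merge(graphdef, state)
        loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
        # All-reduce mean across the "data" mesh axis.
        grads = jax.tree_util.tree_map(
            lambda g: jax.lax.pmean(g, axis_name="data"), grads,
        )
        optimizer.update(model, grads)
        loss = jax.lax.pmean(loss, axis_name="data")
        # Return the updated (still replicated) model and optimizer state.
        return loss, nnx.state((model, optimizer))

    return jax.jit(shard_map(
        train_step, mesh=mesh,
        in_specs=(P(), P("data", None), P("data", None)),
        out_specs=(P(), P()),
    ))

# graphdef, state = nnx.split((model, optimizer))
# loss, state = make_sharded_step(graphdef)(state, x, y)
# nnx.update((model, optimizer), state)
```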
Four moving parts:
- Batch sharded along axis 0 (P("data", None)): each device’s train_step body sees only its slice.
- Per-shard grads computed locally: each device runs the same loss and backward pass.
- All-reduce (pmean): grads are averaged across devices, so every device holds the same global-average gradient.
- Single optimizer step per device: same params (replicated) + same grads (after all-reduce) ⇒ params stay in lockstep.
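The moving parts above can be exercised without NNX. The sketch below is a plain-JAX stand-in (a hand-written linear model in a params dict, not nnx.Linear) that shards the batch, takes local grads, and all-reduces them with pmean inside shard_map; it then checks the result against the gradient of the full batch on one device. The import fallback assumes that newer JAX exposes jax.shard_map while older releases only have jax.experimental.shard_map.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = mesh.devices.size

def loss_fn(params, x, y):
    # Per-shard mean squared error of a linear model: pred = x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def per_device_grads(params, x, y):
    # Body run on every device with its local slice of x and y.
    grads = jax.grad(loss_fn)(params, x, y)
    # All-reduce mean: afterwards every device holds the same averaged gradient.
    return jax.lax.pmean(grads, axis_name="data")

dp_grads = jax.jit(shard_map(
    per_device_grads, mesh=mesh,
    in_specs=(P(), P("data", None), P("data", None)),  # params replicated, batch split
    out_specs=P(),                                      # result replicated
))

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8 * n_dev, 3))  # 8 examples per device
y = jax.random.normal(key_y, (8 * n_dev, 2))
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}

g_dp = dp_grads(params, x, y)
g_full = jax.grad(loss_fn)(params, x, y)  # single-device gradient on the global batch
```

Because every shard has the same size, g_dp matches g_full up to floating-point rounding, on one device or many.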
Why pmean, not psum
Each shard’s loss is a per-shard mean (jnp.mean(...) over the
local examples). Because every shard holds the same number of
examples, the average of the per-shard means equals the mean over
the global batch. The same logic holds for the gradient: pmean
divides the summed gradients by the number of devices, so the
post-all-reduce grad equals the gradient that a single-device run
on the global batch would have produced.
Use psum if your local loss is a sum (not a mean), but that’s
rare; means are easier to compose.
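Written out for D devices with B examples each (equal shard sizes, as this page assumes), where S_d is device d’s shard and ℓ_i the per-example loss; these symbols are introduced here for illustration:

```latex
% Per-shard mean loss and its all-reduce mean (what pmean computes):
\mathcal{L}_d(\theta) = \frac{1}{B}\sum_{i \in S_d} \ell_i(\theta),
\qquad
\frac{1}{D}\sum_{d=1}^{D} \nabla_\theta \mathcal{L}_d(\theta)
  = \frac{1}{D}\sum_{d=1}^{D} \frac{1}{B}\sum_{i \in S_d} \nabla_\theta \ell_i(\theta)
  = \frac{1}{DB}\sum_{i=1}^{DB} \nabla_\theta \ell_i(\theta).

% Per-shard sum loss and its all-reduce sum (what psum computes):
\mathcal{L}_d^{\mathrm{sum}}(\theta) = \sum_{i \in S_d} \ell_i(\theta),
\qquad
\sum_{d=1}^{D} \nabla_\theta \mathcal{L}_d^{\mathrm{sum}}(\theta)
  = \nabla_\theta \sum_{i=1}^{DB} \ell_i(\theta).
```

If shard sizes differ, the first identity breaks: each shard’s mean is weighted equally regardless of how many examples it holds.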
Single-device simulation
With one device, we simulate data parallelism with a Python for
loop: each iteration plays the role of one device’s shard_map body.
After the loop, we reduce the per-shard losses and grads in Python:
```python
# model, optimizer, loss_fn, x, y and num_devices are defined as in the Problem below.
chunk = x.shape[0] // num_devices
per_shard_grads = []
per_shard_losses = []
for i in range(num_devices):
    # Device i's contiguous slice of the global batch.
    x_shard = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, axis=0)
    y_shard = jax.lax.dynamic_slice_in_dim(y, i * chunk, chunk, axis=0)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard)
    per_shard_losses.append(loss)
    per_shard_grads.append(grads)

# Simulated all-reduce: average losses and grads over the "devices".
mean_loss = jnp.mean(jnp.stack(per_shard_losses))
avg_grads = jax.tree_util.tree_map(
    lambda *xs: sum(xs) / float(num_devices), *per_shard_grads,
)
optimizer.update(model, avg_grads)  # exactly one update, after the loop
```
Each iteration is a “device” in the metaphor. The Python mean
of grads stands in for pmean, and the single optimizer.update call
matches the production rule of one update per replica per step.
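As an alternative to the Python loop (not the method this problem asks for), jax.vmap with an axis_name lets jax.lax.pmean run unchanged on a single device, so the per-device body is the same code you would hand to shard_map. This sketch uses a plain-JAX linear model in a params dict instead of nnx.Linear to stay self-contained; num_devices = 4 and the batch shapes are illustrative.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error of a linear model on one shard.
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

def per_device(params, x_shard, y_shard):
    # Body of one simulated "device": local loss and grads, then all-reduce mean.
    loss, grads = jax.value_and_grad(loss_fn)(params, x_shard, y_shard)
    return jax.lax.pmean(loss, "data"), jax.lax.pmean(grads, "data")

num_devices = 4
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (16, 3))
y = jax.random.normal(key_y, (16, 2))
params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,))}

# Reshape to (num_devices, chunk, features): row block i is device i's slice.
xs = x.reshape(num_devices, -1, x.shape[-1])
ys = y.reshape(num_devices, -1, y.shape[-1])

# in_axes=(None, 0, 0): params are shared (replicated), the batch is split.
losses, grads = jax.vmap(per_device, in_axes=(None, 0, 0), axis_name="data")(params, xs, ys)

# After pmean every simulated device holds identical values; keep replica 0.
mean_loss = losses[0]
avg_grads = jax.tree_util.tree_map(lambda g: g[0], grads)
```

The trade-off: the loop is easier to read and debug, while the vmap version exercises the real collective and compiles to a single batched computation.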
Why init the optimizer ONCE
The simulation runs one training step total. Build the model
and the optimizer once at the top, then run the per-shard loop
inside, then call optimizer.update once at the end.
Building the optimizer inside the loop would (a) waste compute and (b) reset optax’s internal state (momentum, step counter) every iteration. The exact same caveat applies in production: one optimizer per replica, lasting the whole training job.
Common pitfalls
- Updating per shard: calling optimizer.update(model, grads_i) inside the loop applies n_dev separate updates, and each later shard’s gradient is taken at already-updated params, so the model drifts away from what data parallelism is supposed to compute.
- Summing instead of averaging grads: with plain SGD this gives an effective learning rate n_dev times what you set. Easy to misdiagnose as instability.
- Forgetting wrt=nnx.Param when constructing the nnx.Optimizer: required since Flax 0.11.
- Asymmetric shard sizes: this problem assumes batch % num_devices == 0. In production, pad the batch to a multiple of the device count and mask the padded examples out of the loss.
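A minimal sketch of padding plus masking, simulated on one device: the helper names pad_batch and masked_mse_across_shards are hypothetical, and the per-shard sums and counts stand in for what psum would combine across real devices. Averaging per-shard means would be wrong here, because shards containing padding hold fewer real examples.

```python
import jax.numpy as jnp

def pad_batch(x, num_devices):
    # Zero-pad axis 0 up to a multiple of num_devices; mask is 1.0 for real rows.
    n = x.shape[0]
    pad = (-n) % num_devices
    x_pad = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
    mask = jnp.concatenate([jnp.ones(n), jnp.zeros(pad)])
    return x_pad, mask

def masked_mse_across_shards(pred, y, mask, num_devices):
    # Per shard: sum of masked squared errors and number of real examples.
    per_row = jnp.mean((pred - y) ** 2, axis=-1) * mask
    shard_sums = per_row.reshape(num_devices, -1).sum(axis=1)
    shard_counts = mask.reshape(num_devices, -1).sum(axis=1)
    # Sum-then-divide (psum-style) gives the exact mean over the real examples.
    return shard_sums.sum() / shard_counts.sum()

# 10 real examples on 4 "devices": padded to 12 rows, 3 per shard.
pred = jnp.arange(10.0).reshape(10, 1)
y = jnp.zeros((10, 1))
pred_p, mask = pad_batch(pred, 4)
y_p, _ = pad_batch(y, 4)
loss = masked_mse_across_shards(pred_p, y_p, mask, 4)
```

Here loss equals the plain mean over the 10 real rows, (0² + 1² + … + 9²) / 10 = 28.5; the two padded rows contribute nothing.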
Problem
Implement data_parallel_step(seed, x_batch, y_batch, num_devices, lr):
- Cast seed, num_devices to int; lr to float.
- Build model = nnx.Linear(x_batch.shape[-1], y_batch.shape[-1], rngs=nnx.Rngs(seed)) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
- Loop i in range(num_devices):
  - Slice the i-th chunk (chunk = batch // num_devices) of x_batch, y_batch along axis 0.
  - Compute loss_i, grads_i = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard).
  - Append both to lists.
- mean_loss = jnp.mean(jnp.stack(per_shard_losses)).
- avg_grads = jax.tree_util.tree_map(lambda *xs: sum(xs)/float(num_devices), *per_shard_grads).
- optimizer.update(model, avg_grads).
- Return jnp.array([float(mean_loss)]), a 1-D array of shape (1,).
Inputs:
- seed: float (cast to int).
- x_batch: 2-D (N, D_in), with N divisible by num_devices.
- y_batch: 2-D (N, D_out), same N.
- num_devices: float (cast to int).
- lr: float.
Output: 1-D (1,): [mean_loss_across_shards].