NNX Distributed Train Step
Why this matters
This problem composes everything from Phase 9 into one capstone: a single distributed training step that ties together sharding, per-shard forward/backward, all-reduce, and a single optimizer update.
The shape is the same as a single-device step you’ve already written:
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)
The difference: the gradients are an average over all devices’ grads, not just one device’s grad. Everything else in the step flows from that.
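To see why averaging per-device grads reproduces the full-batch gradient, here is a short derivation. It assumes (this notation is not from the original text) a global batch of N examples split into K equal shards, a per-example loss ℓ_i(θ), and a per-shard loss that is the mean over that shard's N/K examples:

```latex
g_k = \frac{K}{N} \sum_{i \in \text{shard } k} \nabla_\theta \ell_i(\theta),
\qquad
\frac{1}{K} \sum_{k=1}^{K} g_k
  = \frac{1}{K} \cdot \frac{K}{N} \sum_{k=1}^{K} \sum_{i \in \text{shard } k} \nabla_\theta \ell_i(\theta)
  = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell_i(\theta)
```

The equality relies on the shards being the same size; with unequal shards, a plain mean of per-shard grads would over-weight the smaller shards.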
The full pattern (production)
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
# `devices` is an array of JAX devices, e.g. np.array(jax.devices()).
mesh = Mesh(devices, axis_names=("data",))
@jax.jit
# shard_map takes the function as its first argument, so bind the
# keyword arguments with partial to use it as a decorator.
@partial(shard_map, mesh=mesh, in_specs=..., out_specs=P())
def train_step(model, optimizer, x, y):
    # x, y are sharded along axis 0; each device sees its slice.
    def loss_fn(model, x, y):
        return jnp.mean((model(x) - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # All-reduce mean across the data mesh axis.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.pmean(g, axis_name="data"), grads,
    )
    optimizer.update(model, grads)  # one update per device
    # Mean loss across devices for logging.
    loss = jax.lax.pmean(loss, axis_name="data")
    return loss
Three steps:
- Per-shard forward+backward: each device’s train_step body computes a local loss and grad on its slice of (x, y).
- All-reduce mean: pmean averages grads (and optionally loss) across devices. After this, every device’s grad is the global-batch gradient, the same as a single-device run on the full batch.
- Single optimizer update per device: same params + same grads ⇒ params stay in lockstep.
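The three steps can be checked on a single device. The sketch below is an illustration under stated assumptions, not the pattern above: it uses jax.vmap with axis_name="data" as a stand-in for shard_map (so pmean has a named axis to reduce over), and a plain parameter dict in place of an nnx model. The names per_device_step, params, x_shards and y_shards are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Linear model + MSE, standing in for the nnx model in the text.
    pred = x @ params["kernel"] + params["bias"]
    return jnp.mean((pred - y) ** 2)

def per_device_step(params, x_shard, y_shard):
    # Step 1: local forward+backward on this "device's" slice.
    loss, grads = jax.value_and_grad(loss_fn)(params, x_shard, y_shard)
    # Step 2: all-reduce mean over the named axis.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.pmean(g, axis_name="data"), grads
    )
    loss = jax.lax.pmean(loss, axis_name="data")
    return loss, grads

num_devices = 4
key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (8, 3))
y = jax.random.normal(key_y, (8, 2))
params = {
    "kernel": jax.random.normal(key_w, (3, 2)),
    "bias": jnp.zeros((2,)),
}

# Reshape the batch so the leading axis indexes the simulated devices.
x_shards = x.reshape(num_devices, -1, 3)
y_shards = y.reshape(num_devices, -1, 2)

# vmap over the leading axis; params are replicated (in_axes=None).
sim_loss, sim_grads = jax.vmap(
    per_device_step, in_axes=(None, 0, 0), axis_name="data"
)(params, x_shards, y_shards)

# Reference: one device, full batch.
full_loss, full_grads = jax.value_and_grad(loss_fn)(params, x, y)
```

Every row of sim_grads (one per simulated device) should match full_grads up to floating-point tolerance. That match is what lets step 3, the identical update on every device, keep the params in lockstep.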
The reduction order matters
pmean is logically psum / num_devices. Two equivalent forms:
# Form 1: psum then divide.
summed = jax.lax.psum(g, axis_name="data")
avg = summed / num_devices
# Form 2: pmean directly.
avg = jax.lax.pmean(g, axis_name="data")
Both produce identical results. The simulation here uses Form 1 explicitly to make the sum-then-divide structure visible. A sum reduction followed by a local divide is also how an averaging all-reduce is commonly carried out by collective libraries such as NCCL and RCCL.
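A small single-device check of the two forms follows. It assumes jax.vmap with axis_name="data" as a stand-in for the device axis, and the names form1, form2 and per_device_g are made up for this sketch.

```python
import jax
import jax.numpy as jnp

num_devices = 4
# One row per simulated device: [0,1,2], [3,4,5], [6,7,8], [9,10,11].
per_device_g = jnp.arange(12.0).reshape(num_devices, 3)

def form1(g):
    # Form 1: psum then divide.
    summed = jax.lax.psum(g, axis_name="data")
    return summed / num_devices

def form2(g):
    # Form 2: pmean directly.
    return jax.lax.pmean(g, axis_name="data")

avg1 = jax.vmap(form1, axis_name="data")(per_device_g)
avg2 = jax.vmap(form2, axis_name="data")(per_device_g)
```

Both results have one row per simulated device, and every row holds the same column-wise mean, [4.5, 5.5, 6.5].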
Single-device simulation
With one device, simulate via a Python loop. Each iteration plays the
role of one device running the shard_map body on its own shard:
chunk = x.shape[0] // num_devices
per_shard_grads = []
per_shard_losses = []
for i in range(num_devices):
    x_shard = jax.lax.dynamic_slice_in_dim(x, i*chunk, chunk, axis=0)
    y_shard = jax.lax.dynamic_slice_in_dim(y, i*chunk, chunk, axis=0)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard)
    per_shard_losses.append(loss)
    per_shard_grads.append(grads)
# Sum-then-divide all-reduce.
summed = jax.tree_util.tree_map(lambda *xs: sum(xs), *per_shard_grads)
avg = jax.tree_util.tree_map(lambda g: g / float(num_devices), summed)
optimizer.update(model, avg)
final_loss = jnp.mean(jnp.stack(per_shard_losses))
The for loop is the Python stand-in for “this body runs on
num_devices devices in parallel.” The two tree_maps explicitly
decompose the all-reduce into sum + divide.
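The variadic tree_map call is the least obvious line in the simulation, so here is a toy illustration. It assumes two hand-written gradient dicts in place of real nnx gradient states; the values are made up.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-shard gradients for two simulated devices.
per_shard_grads = [
    {"kernel": jnp.full((2, 2), 1.0), "bias": jnp.array([1.0, 2.0])},
    {"kernel": jnp.full((2, 2), 3.0), "bias": jnp.array([3.0, 4.0])},
]
num_devices = len(per_shard_grads)

# tree_map with several trees calls the function once per leaf position,
# passing the matching leaf from every tree: here, one leaf per shard.
summed = jax.tree_util.tree_map(lambda *xs: sum(xs), *per_shard_grads)

# Second pass: divide each summed leaf by the device count.
avg = jax.tree_util.tree_map(lambda g: g / float(num_devices), summed)
```

Here summed["bias"] is [4., 6.] and avg["bias"] is [2., 3.]; the kernel leaves are handled the same way, leaf by leaf.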
Building the model and optimizer outside the loop
Same caveat as nnx-data-parallel-step. Build once, run the
per-shard loop, then do one optimizer.update after the
reduction. Building inside the loop resets optax state every
iteration; calling optimizer.update inside the loop applies
multiple updates per logical step.
What this DOESN’T cover
Two patterns this problem deliberately doesn’t include:
- Optimizer state sharding (FSDP-style): see nnx-fsdp-style. The training step here keeps optimizer state replicated on every device, which is the simplest case.
- Tensor parallelism: see nnx-tensor-parallel-linear. The kernel here is replicated; the data is sharded. TP would shard the kernel as well, with collectives (psum for row-parallel, all_gather for column-parallel) inside the layer.
Real production training combines all three (DP + FSDP + TP). Each is a separate compositional layer over the basic step you write here.
Common pitfalls
- optimizer.update inside the loop: applies num_devices updates per step. The model drifts from what data parallelism should produce.
- Summing without dividing: produces an effective LR n_dev times what you set. Easy to misdiagnose as instability.
- Forgetting wrt=nnx.Param: the nnx.Optimizer needs to know which Variables are trainable params. Required since Flax 0.11.
- Returning the loss from the WRONG shard: average across shards; don’t return shard 0’s loss as if it represented the whole batch.
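The first two pitfalls are easy to reproduce numerically. The sketch below assumes plain SGD written by hand (w - lr * g) on a bare weight matrix, rather than nnx.Optimizer with optax, so the arithmetic is visible. All names are hypothetical.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

num_devices, lr = 4, 0.1
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8, 1))
w0 = jax.random.normal(kw, (3, 1))

# Equal-sized shards along axis 0.
x_shards = jnp.split(x, num_devices)
y_shards = jnp.split(y, num_devices)
shard_grads = [
    jax.grad(loss_fn)(w0, xs, ys) for xs, ys in zip(x_shards, y_shards)
]
summed = sum(shard_grads)
avg = summed / num_devices

# Correct: one update with the averaged gradient.
w_correct = w0 - lr * avg

# Pitfall: summing without dividing scales the step by num_devices.
w_no_divide = w0 - lr * summed

# Pitfall: updating inside the loop applies num_devices sequential updates.
w_loop = w0
for xs, ys in zip(x_shards, y_shards):
    w_loop = w_loop - lr * jax.grad(loss_fn)(w_loop, xs, ys)

# Reference: single-device step on the full batch.
w_full = w0 - lr * jax.grad(loss_fn)(w0, x, y)
```

Only w_correct should agree with w_full. The step taken by w_no_divide is num_devices times larger, and w_loop ends up at different params because each inner update moves the point where the next gradient is evaluated.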
Problem
Implement distributed_train_step(seed, x_batch, y_batch, num_devices, lr):
- Cast seed, num_devices to int; lr to float.
- Build model = nnx.Linear(x_batch.shape[-1], y_batch.shape[-1], rngs=nnx.Rngs(seed)) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param) outside the loop.
- Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
- Loop i in range(num_devices): slice the i-th chunk = batch // num_devices of x_batch, y_batch along axis 0; compute nnx.value_and_grad(loss_fn)(model, x_shard, y_shard); collect loss and grads.
- Sum-then-divide: summed = tree_map(lambda *xs: sum(xs), *grads_list), then avg = tree_map(lambda g: g / float(num_devices), summed).
- optimizer.update(model, avg).
- final_loss = jnp.mean(jnp.stack(per_shard_losses)).
- Return jnp.array([float(final_loss)]).
Inputs:
- seed: float (cast to int).
- x_batch: 2-D (N, D_in); N divisible by num_devices.
- y_batch: 2-D (N, D_out).
- num_devices: float (cast to int).
- lr: float.
Output: 1-D (1,) — [mean_loss_across_shards].
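The input contract (float-typed num_devices, batch size divisible by it) can be sketched as a small slicing helper. split_into_shards is a hypothetical name and is not part of the problem's required interface. The explicit divisibility check is also an addition, since the problem only states that N is divisible.

```python
import jax
import jax.numpy as jnp

def split_into_shards(x_batch, y_batch, num_devices):
    """Return a list of (x_shard, y_shard) pairs sliced along axis 0."""
    num_devices = int(num_devices)  # inputs arrive as floats
    n = x_batch.shape[0]
    if n % num_devices != 0:
        raise ValueError(
            f"batch size {n} is not divisible by {num_devices}"
        )
    chunk = n // num_devices
    shards = []
    for i in range(num_devices):
        x_shard = jax.lax.dynamic_slice_in_dim(x_batch, i * chunk, chunk, axis=0)
        y_shard = jax.lax.dynamic_slice_in_dim(y_batch, i * chunk, chunk, axis=0)
        shards.append((x_shard, y_shard))
    return shards

x_demo = jnp.arange(12.0).reshape(6, 2)
y_demo = jnp.arange(6.0).reshape(6, 1)
demo_shards = split_into_shards(x_demo, y_demo, 3.0)
```

With six rows and num_devices = 3.0, this yields three shards of two rows each; shard 1 holds rows 2 and 3 of the batch.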