
NNX Distributed Train Step

Why this matters

This problem composes everything from Phase 9 into one capstone: a single distributed training step that ties together sharding, per-shard forward/backward, all-reduce, and a single optimizer update.

The shape is the same as a single-device step you’ve already written:

loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)

The difference: the gradients are an average over every device’s grads, not just one device’s grad. Everything else in the step follows from that.

The full pattern (production)

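from functools import partial

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# devices: the accelerators to shard over, e.g. jax.devices().
mesh = Mesh(devices, axis_names=("data",))

# shard_map takes the function as its first argument, so the keyword
# arguments are bound with partial before it is used as a decorator.
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=..., out_specs=P())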
def train_step(model, optimizer, x, y):
    # x, y are sharded along axis 0; each device sees its slice.
    def loss_fn(model, x, y):
        return jnp.mean((model(x) - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # All-reduce mean across the data mesh axis.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.pmean(g, axis_name="data"), grads,
    )
    optimizer.update(model, grads)        # one update per device
    # Mean loss across devices for logging.
    loss = jax.lax.pmean(loss, axis_name="data")
    return loss

Three steps:

  1. Per-shard forward+backward: each device’s train_step body computes a local loss and grad on its slice of (x, y).
  2. All-reduce mean: pmean averages grads (and optionally loss) across devices. After this, every device’s grad is the global-batch gradient — same as a single-device run on the full batch.
  3. Single optimizer update per device: same params + same grads ⇒ params stay in lockstep.
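The claim in step 2 can be checked numerically. The following is a minimal sketch in plain JAX, assuming a linear model stored as a dict (`params`), equal-sized shards, and a mean-reduced loss; the names `params`, `full_grads`, and `max_diff` are illustrative and not part of the problem.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean-squared error of a linear model; params is a plain dict here.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

num_devices = 4
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8, 2))
params = {"w": jax.random.normal(kw, (3, 2)), "b": jnp.zeros((2,))}

# Reference: gradient of the loss over the full batch.
full_grads = jax.grad(loss_fn)(params, x, y)

# Per-shard gradients on equal-sized slices, then averaged across shards.
x_shards = jnp.split(x, num_devices, axis=0)
y_shards = jnp.split(y, num_devices, axis=0)
shard_grads = [
    jax.grad(loss_fn)(params, xs, ys) for xs, ys in zip(x_shards, y_shards)
]
avg_grads = jax.tree_util.tree_map(
    lambda *gs: sum(gs) / num_devices, *shard_grads
)

# Largest elementwise gap between the averaged and the full-batch gradient.
max_diff = max(
    float(jnp.max(jnp.abs(a - f)))
    for a, f in zip(
        jax.tree_util.tree_leaves(avg_grads),
        jax.tree_util.tree_leaves(full_grads),
    )
)
```

The equality relies on every shard having the same number of examples; with unequal shards, a plain average of shard means would weight examples unevenly.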

The reduction order matters

pmean is logically psum / num_devices. Two equivalent forms:

# Form 1: psum then divide.
summed = jax.lax.psum(g, axis_name="data")
avg = summed / num_devices

# Form 2: pmean directly.
avg = jax.lax.pmean(g, axis_name="data")

Both produce the same result. The simulation here uses Form 1 explicitly to make the sum-then-divide structure visible, which is how all-reduce is commonly implemented in collective libraries (NCCL, RCCL, etc.): a sum reduction followed by a scale.
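The two forms can be compared on a single device by letting `jax.vmap` with an `axis_name` stand in for the mesh axis. This is a sketch for illustration only: `per_device_g`, `form1`, and `form2` are made-up names, and each row of `per_device_g` plays the role of one device's gradient.

```python
import jax
import jax.numpy as jnp

num_devices = 4
# One row per simulated device; vmap's axis_name stands in for the mesh axis.
per_device_g = jnp.arange(12.0).reshape(num_devices, 3)

def form1(g):
    # psum, then divide by the number of participants.
    return jax.lax.psum(g, axis_name="data") / num_devices

def form2(g):
    # pmean directly.
    return jax.lax.pmean(g, axis_name="data")

avg1 = jax.vmap(form1, axis_name="data")(per_device_g)
avg2 = jax.vmap(form2, axis_name="data")(per_device_g)
# Every row of avg1 and avg2 holds the column-wise mean [4.5, 5.5, 6.5].
```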

Single-device simulation

With one device, simulate the devices with a Python loop. Each iteration plays the role of one device executing the shard_map body:

chunk = x.shape[0] // num_devices
per_shard_grads = []
per_shard_losses = []
for i in range(num_devices):
    x_shard = jax.lax.dynamic_slice_in_dim(x, i*chunk, chunk, axis=0)
    y_shard = jax.lax.dynamic_slice_in_dim(y, i*chunk, chunk, axis=0)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard)
    per_shard_losses.append(loss)
    per_shard_grads.append(grads)

# Sum-then-divide all-reduce.
summed = jax.tree_util.tree_map(lambda *xs: sum(xs), *per_shard_grads)
avg = jax.tree_util.tree_map(lambda g: g / float(num_devices), summed)

optimizer.update(model, avg)

final_loss = jnp.mean(jnp.stack(per_shard_losses))

The for loop is the Python stand-in for “this body runs on num_devices devices in parallel.” The two tree_maps explicitly decompose the all-reduce into sum + divide.
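To see what the two tree_maps do, here is a small worked example on hand-written gradients. The dict keys and values are invented for illustration; real NNX gradients are a nested state object, but the tree_map calls have the same shape.

```python
import jax
import jax.numpy as jnp

# Two "devices", each holding a gradient pytree with the same structure.
per_shard_grads = [
    {"kernel": jnp.array([1.0, 2.0]), "bias": jnp.array(0.5)},
    {"kernel": jnp.array([3.0, 6.0]), "bias": jnp.array(1.5)},
]
num_devices = len(per_shard_grads)

# Sum: tree_map receives one leaf per shard and adds them leaf-by-leaf.
summed = jax.tree_util.tree_map(lambda *xs: sum(xs), *per_shard_grads)
# Divide: scale every leaf by the number of shards.
avg = jax.tree_util.tree_map(lambda g: g / float(num_devices), summed)

# summed -> kernel [4., 8.], bias 2.0
# avg    -> kernel [2., 4.], bias 1.0
```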

Building the model and optimizer outside the loop

Same caveat as nnx-data-parallel-step. Build once, run the per-shard loop, then do one optimizer.update after the reduction. Building inside the loop resets optax state every iteration; calling optimizer.update inside the loop applies multiple updates per logical step.
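The effect of updating inside the loop can be seen on a toy problem. This sketch uses a scalar weight and a hand-written SGD step (`sgd` is a hypothetical helper standing in for optimizer.update), with two one-example shards.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((w * x - y) ** 2)

def sgd(w, g, lr):
    # Hypothetical stand-in for a plain SGD optimizer update.
    return w - lr * g

lr = 0.1
x_shards = [jnp.array([1.0]), jnp.array([2.0])]
y_shards = [jnp.array([1.0]), jnp.array([2.0])]
w0 = jnp.array(0.0)

# Correct: every shard gradient is taken at the same w0, averaged, applied once.
grads = [jax.grad(loss_fn)(w0, xs, ys) for xs, ys in zip(x_shards, y_shards)]
w_correct = sgd(w0, sum(grads) / len(grads), lr)

# Pitfall: updating inside the loop makes shard 1 see already-modified weights.
w_drift = w0
for xs, ys in zip(x_shards, y_shards):
    w_drift = sgd(w_drift, jax.grad(loss_fn)(w_drift, xs, ys), lr)

# w_correct is 0.5; w_drift is 0.84.
```

The second loop is sequential SGD over two mini-batches, which is a valid algorithm but not what a data-parallel step computes.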

What this DOESN’T cover

Two patterns this problem deliberately doesn’t include:

  • Optimizer state sharding (FSDP-style): see nnx-fsdp-style. The training step here keeps optimizer state replicated on every device — the simplest case.
  • Tensor parallelism: see nnx-tensor-parallel-linear. The kernel here is replicated; the data is sharded. TP would shard the kernel as well, with collectives (psum for row-parallel, all_gather for column-parallel) inside the layer.

Real production training combines all three (DP + FSDP + TP). Each is a separate compositional layer over the basic step you write here.

Common pitfalls

  • optimizer.update inside the loop: applies num_devices updates per step. The model drifts from what data parallelism should produce.
  • Summing without dividing: produces an effective LR num_devices times what you set. Easy to misdiagnose as instability.
  • Forgetting wrt=nnx.Param: the nnx.Optimizer needs to know which Variables are trainable params. Required since Flax 0.11.
  • Returning the loss from the WRONG shard: average across shards; don’t return shard 0’s loss as if it represented the whole batch.
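The sum-without-divide pitfall can be made concrete for plain SGD. In this sketch the gradients and weights are made-up numbers; it shows that applying the summed gradient matches applying the averaged gradient with the learning rate multiplied by num_devices.

```python
import jax.numpy as jnp

lr = 0.01
num_devices = 4
# Made-up per-shard gradients: [1, -2], [2, -4], [3, -6], [4, -8].
per_shard_grads = [jnp.array([1.0, -2.0]) * (i + 1) for i in range(num_devices)]
w = jnp.array([0.5, 0.5])

summed = sum(per_shard_grads)      # [10., -20.]
avg = summed / num_devices         # [2.5, -5.]

# Forgetting the division...
w_sum_only = w - lr * summed
# ...is the same SGD step as using the average with a num_devices-times larger LR.
w_scaled_lr = w - (lr * num_devices) * avg
```

For optimizers that normalize gradient scale, such as Adam, the effect is not a simple learning-rate multiple, so the equivalence above is specific to SGD.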

Problem

Implement distributed_train_step(seed, x_batch, y_batch, num_devices, lr):

  1. Cast seed, num_devices to int; lr to float.
  2. Build model = nnx.Linear(x_batch.shape[-1], y_batch.shape[-1], rngs=nnx.Rngs(seed)) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param) outside the loop.
  3. Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
  4. Loop i in range(num_devices): with chunk = x_batch.shape[0] // num_devices, slice the i-th chunk of x_batch and y_batch along axis 0; compute nnx.value_and_grad(loss_fn)(model, x_shard, y_shard); collect loss and grads.
  5. Sum-then-divide: summed = tree_map(lambda *xs: sum(xs), *grads_list), then avg = tree_map(lambda g: g / float(num_devices), summed).
  6. optimizer.update(model, avg).
  7. final_loss = jnp.mean(jnp.stack(per_shard_losses)).
  8. Return jnp.array([float(final_loss)]).

Inputs:

  • seed: float (cast to int).
  • x_batch: 2-D (N, D_in); N is divisible by num_devices.
  • y_batch: 2-D (N, D_out).
  • num_devices: float (cast to int).
  • lr: float.

Output: 1-D array of shape (1,): [mean_loss_across_shards].
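The input handling (the float-to-int cast and the per-shard slicing) might be sketched as follows. `shard_batch` is a hypothetical helper, not part of the problem statement or its reference solution, and the divisibility check is an added assumption about how to fail on bad input.

```python
import jax
import jax.numpy as jnp

def shard_batch(x_batch, y_batch, num_devices):
    # Hypothetical helper: split a batch into num_devices equal slices.
    num_devices = int(num_devices)  # the problem passes this in as a float
    n = x_batch.shape[0]
    if n % num_devices != 0:
        raise ValueError(f"batch size {n} not divisible by {num_devices}")
    chunk = n // num_devices
    shards = []
    for i in range(num_devices):
        x_shard = jax.lax.dynamic_slice_in_dim(x_batch, i * chunk, chunk, axis=0)
        y_shard = jax.lax.dynamic_slice_in_dim(y_batch, i * chunk, chunk, axis=0)
        shards.append((x_shard, y_shard))
    return shards

# Usage: 6 examples over 3 simulated devices gives three (2, ...) slices.
x_batch = jnp.arange(18.0).reshape(6, 3)
y_batch = jnp.arange(12.0).reshape(6, 2)
shards = shard_batch(x_batch, y_batch, 3.0)
```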

