
NNX Data-Parallel Step

Why this matters

Data parallelism is the bread-and-butter of multi-device training: every device holds a full copy of the model, but only a shard of the batch. Each device computes its forward and backward independently; an all-reduce across devices averages the gradients before the optimizer step.

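Result: near-linear throughput scaling, minus the cost of the gradient all-reduce. With 8 devices and a fixed per-device batch, each step processes 8× the examples (8× the compute) in roughly the wall-clock time of a single-device step. This is the cheapest, simplest scaling lever, and it's what you reach for first when your model already fits on one device but you want more throughput.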
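To make "full copy of the model, shard of the batch" concrete, here is a minimal sketch using jax.sharding on whatever devices are visible (a single CPU device works). The 8-rows-per-device batch and the zero parameter vector are illustrative assumptions, not part of the recipe below.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n_dev = devices.size
mesh = Mesh(devices, axis_names=("data",))

# Global batch whose leading dim is a multiple of the device count.
x = jnp.arange(8 * n_dev * 3, dtype=jnp.float32).reshape(8 * n_dev, 3)
# Batch: split along axis 0 across the "data" mesh axis.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
# Params: P() means fully replicated, so every device holds the whole array.
w = jax.device_put(jnp.zeros(3), NamedSharding(mesh, P()))

batch_shard_shapes = [s.data.shape for s in x_sharded.addressable_shards]
param_shard_shapes = [s.data.shape for s in w.addressable_shards]
```

Each entry of batch_shard_shapes is one device's (8, 3) slice, while every entry of param_shard_shapes is the full (3,) vector.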

The production recipe (multi-device)

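import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

def train_step(model, optimizer, x, y):
    # Inside shard_map, x and y are this device's slice of the batch
    # (sharded along axis 0 by P("data", None)).
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # All-reduce mean across the "data" mesh axis.
    grads = jax.tree_util.tree_map(
        lambda g: jax.lax.pmean(g, axis_name="data"), grads,
    )
    # Average the loss too, so the replicated out_specs=P() is accurate.
    loss = jax.lax.pmean(loss, axis_name="data")
    optimizer.update(model, grads)
    return loss

# NNX's own transforms (nnx.shard_map, nnx.jit) accept nnx.Module and
# nnx.Optimizer arguments and write their in-place updates back out.
sharded_step = nnx.jit(nnx.shard_map(
    train_step, mesh=mesh,
    in_specs=(P(), P(), P("data", None), P("data", None)),
    out_specs=P(),
))

# Place each batch so every device holds its own slice along axis 0:
# x = jax.device_put(x, NamedSharding(mesh, P("data", None)))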

Four moving parts:

  1. Batch sharded along axis 0 (P("data", None)): each device's train_step body sees only its slice.
  2. Per-shard grads computed locally: each device runs the same loss/backward pass on its slice.
  3. All-reduce (pmean): grads averaged across devices, so every device holds the same global-average gradient.
  4. One optimizer step per device: same params (replicated) + same grads (after all-reduce) ⇒ params stay in lockstep.

Why all-reduce mean (pmean), not all-reduce sum (psum)

Each shard's loss is a per-shard mean (jnp.mean(...) over the local examples). When all shards have the same size, the global-batch loss is exactly the average of the per-shard losses. Differentiation is linear, so the same holds for the gradient: pmean sums across devices and divides by the number of devices, so the post-all-reduce gradient equals the gradient that a single-device run on the global batch would have produced.
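Spelled out, assuming N examples split evenly into n shards S_1, …, S_n of size B = N/n (notation introduced here for illustration), with per-example loss ℓ_j:

```latex
L(\theta) = \frac{1}{N}\sum_{j=1}^{N} \ell_j(\theta)
          = \frac{1}{n}\sum_{k=1}^{n}
            \underbrace{\frac{1}{B}\sum_{j \in S_k} \ell_j(\theta)}_{L_k(\theta)}
\qquad\Longrightarrow\qquad
\nabla_\theta L(\theta) = \frac{1}{n}\sum_{k=1}^{n} \nabla_\theta L_k(\theta)
```

The right-hand side is exactly what pmean computes from the per-device gradients ∇L_k.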

Use psum if your local loss was a sum (not a mean) — but that’s rare; means are easier to compose.
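A quick way to check both rules without multiple devices is jax.vmap with an axis_name, under which jax.lax.pmean and jax.lax.psum reduce over the mapped axis just as they would over a mesh axis. The sketch below uses a toy linear model in plain JAX (chosen for illustration) and compares each reduction with a full-batch gradient; the last case shows what happens if you psum per-shard means.

```python
import jax
import jax.numpy as jnp

def mean_loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def sum_loss(w, x, y):
    return jnp.sum((x @ w - y) ** 2)

n_dev, per_shard = 4, 2
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n_dev * per_shard, 3))
y = jax.random.normal(ky, (n_dev * per_shard,))
w = jax.random.normal(kw, (3,))
# Contiguous shards: row block k goes to "device" k.
xs, ys = x.reshape(n_dev, per_shard, 3), y.reshape(n_dev, per_shard)

def reduced_grad(loss, collective):
    # One "device": local gradient, then a collective over the "data" axis.
    def body(xi, yi):
        return collective(jax.grad(loss)(w, xi, yi), axis_name="data")
    # vmap stands in for the mesh; row k of the result is replica k's view.
    return jax.vmap(body, axis_name="data")(xs, ys)

g_mean_pmean = reduced_grad(mean_loss, jax.lax.pmean)  # mean loss + pmean
g_sum_psum = reduced_grad(sum_loss, jax.lax.psum)      # sum loss + psum
g_mean_psum = reduced_grad(mean_loss, jax.lax.psum)    # mismatch: over-counts

full_mean = jax.grad(mean_loss)(w, x, y)
full_sum = jax.grad(sum_loss)(w, x, y)
```

Every row of g_mean_pmean matches full_mean and every row of g_sum_psum matches full_sum, while g_mean_psum comes out n_dev times too large.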

Single-device simulation

With one device, we simulate via a Python for loop. Each iteration corresponds to “one device’s body of the shard_map.” At the end, we reduce the per-shard losses + grads in Python:

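import jax
import jax.numpy as jnp
from flax import nnx

# Assumes model, optimizer, loss_fn, x, y and num_devices are already defined,
# with x.shape[0] divisible by num_devices.
chunk = x.shape[0] // num_devices
per_shard_grads = []
per_shard_losses = []
for i in range(num_devices):
    # This "device's" contiguous chunk of the batch along axis 0.
    x_shard = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, axis=0)
    y_shard = jax.lax.dynamic_slice_in_dim(y, i * chunk, chunk, axis=0)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard)
    per_shard_losses.append(loss)
    per_shard_grads.append(grads)

# Python-side stand-in for pmean over the losses and the grads.
mean_loss = jnp.mean(jnp.stack(per_shard_losses))
avg_grads = jax.tree_util.tree_map(
    lambda *xs: sum(xs) / float(num_devices), *per_shard_grads,
)
# Exactly one update, after averaging.
optimizer.update(model, avg_grads)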

Each iteration is a “device” in the metaphor. The Python mean of grads is the simulation of pmean. The single optimizer.update matches the production “one update per replica” rule.

Why init the optimizer ONCE

The simulation runs one training step total. Build the model and the optimizer once at the top, then run the per-shard loop inside, then call optimizer.update once at the end.

Building the optimizer inside the loop would (a) waste compute and (b) reset optax’s internal state (momentum, step counter) every iteration. The exact same caveat applies in production: one optimizer per replica, lasting the whole training job.

Common pitfalls

  • Updating per shard: calling optimizer.update(model, grads_i) inside the loop applies n_dev separate updates, and every shard after the first computes its gradient at already-updated params. The result is n_dev small-batch SGD steps, not the single global-batch step that data parallelism is supposed to compute.
  • Summing instead of averaging grads: for plain SGD this multiplies the effective learning rate by n_dev. Easy to misdiagnose as instability.
  • Forgetting wrt=nnx.Param when constructing the nnx.Optimizer: required since Flax 0.11.
  • Uneven shard sizes: this problem assumes batch % num_devices == 0. In production, pad the batch to a multiple of the device count (and mask the padded examples out of the loss) or drop the remainder.
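The first two pitfalls can be reproduced on a toy linear model with hand-written SGD (plain JAX; the model, data, and learning rate are illustrative assumptions). Only the averaged update matches a single step on the global batch.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # Per-shard mean squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)

n_dev, lr = 4, 0.1
kx, ky = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(kx, (8, 3))
y = jax.random.normal(ky, (8,))
xs, ys = x.reshape(n_dev, 2, 3), y.reshape(n_dev, 2)
w0 = jnp.zeros(3)

# Reference: one SGD step on the whole batch.
w_full = w0 - lr * jax.grad(loss_fn)(w0, x, y)

# Correct: per-shard grads at the same params, averaged, one update.
shard_grads = [jax.grad(loss_fn)(w0, xs[i], ys[i]) for i in range(n_dev)]
w_avg = w0 - lr * sum(shard_grads) / n_dev

# Pitfall: summing. Same as the averaged step with lr * n_dev.
w_sum = w0 - lr * sum(shard_grads)

# Pitfall: updating per shard. Later grads are taken at already-moved params.
w_seq = w0
for i in range(n_dev):
    w_seq = w_seq - lr * jax.grad(loss_fn)(w_seq, xs[i], ys[i])
```

w_avg lands on w_full, w_sum moves n_dev times as far, and w_seq ends up somewhere else entirely.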

Problem

Implement data_parallel_step(seed, x_batch, y_batch, num_devices, lr):

  1. Cast seed, num_devices to int; lr to float.
  2. Build model = nnx.Linear(x_batch.shape[-1], y_batch.shape[-1], rngs=nnx.Rngs(seed)) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
  4. Loop i in range(num_devices):
    • Slice the i-th chunk (chunk = N // num_devices rows) of x_batch and y_batch along axis 0.
    • Compute loss_i, grads_i = nnx.value_and_grad(loss_fn)(model, x_shard, y_shard).
    • Append both to lists.
  5. mean_loss = jnp.mean(jnp.stack(per_shard_losses)).
  6. avg_grads = jax.tree_util.tree_map(lambda *xs: sum(xs)/float(num_devices), *per_shard_grads).
  7. optimizer.update(model, avg_grads).
  8. Return jnp.array([float(mean_loss)]), a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x_batch: 2-D (N, D_in), divisible by num_devices.
  • y_batch: 2-D (N, D_out), same N.
  • num_devices: float (cast to int).
  • lr: float.

Output: 1-D array of shape (1,) holding [mean_loss_across_shards].

Hints

flax nnx sharding data-parallel training
