NNX FSDP-Style Step

Why this matters

Plain data parallelism replicates the whole model on every device: 8 devices = 8× the parameter memory. For a 70B-parameter LLM, that’s 140 GB per device in bf16, more than a single 80 GB A100 or H100 can hold, and that is before counting activations or KV cache.

Fully Sharded Data Parallel (FSDP), also called ZeRO-3 in DeepSpeed lingo, fixes this: each device holds only a slice of every parameter. The full param is reconstructed on demand by an all-gather just before the layer needs it, then immediately freed after the matmul.

Result: parameter memory per device is total_params / num_shards instead of total_params. For a 70B model on 8 devices, that’s ~17.5 GB per device — fits on an A100, with room to spare.
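
The arithmetic behind those two figures can be reproduced in a few lines. This is a sketch only: the helper name param_gb_per_device is made up for illustration, and it counts parameters alone (no gradients, optimizer state, or activations).

```python
def param_gb_per_device(num_params, bytes_per_param, num_shards=1):
    """Parameter memory held by one device, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / num_shards / 1e9

# 70B parameters in bf16 (2 bytes each).
replicated_gb = param_gb_per_device(70e9, 2)               # plain data parallelism
sharded_gb = param_gb_per_device(70e9, 2, num_shards=8)    # FSDP over 8 devices

print(f"replicated: {replicated_gb:.1f} GB, sharded: {sharded_gb:.1f} GB")
# prints "replicated: 140.0 GB, sharded: 17.5 GB"
```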

The FSDP recipe (production)

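from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # exported as jax.shard_map in newer JAX releases

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Shard every param along its leading axis (the "fsdp" direction); other axes stay whole.
# state_template is the pytree of param arrays, so each leaf's rank is p.ndim.
sharding = jax.tree_util.tree_map(
    lambda p: P("fsdp", *([None] * (p.ndim - 1))), state_template
)

# shard_map takes the function as its first argument, so use partial to decorate.
@partial(shard_map, mesh=mesh, in_specs=..., out_specs=P())
def step(model, optimizer, x, y):
    # Inside each device, params are LOCAL slices (local_params below).
    # Forward: all_gather to materialize each layer's full kernel.
    # tiled=True concatenates the slices instead of stacking them on a new axis.
    gathered_params = jax.tree_util.tree_map(
        lambda p: jax.lax.all_gather(p, axis_name="fsdp", tiled=True),
        local_params,
    )
    # ... forward + backward computes loss and per-device grad pieces (full_grads) ...
    # reduce_scatter to keep only this shard's grad slice.
    # psum_scatter SUMS across devices; divide by the shard count for a mean.
    local_grads = jax.lax.psum_scatter(
        full_grads, axis_name="fsdp", scatter_dimension=0, tiled=True,
    )
    # Optimizer update applied LOCALLY: each shard updates its slice.
    optimizer.update(model, local_grads)
    # out_specs=P() expects a replicated value, e.g. the loss after jax.lax.pmean.
    return loss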

Three ingredients (two collectives plus one sharding rule):

  1. All-gather before each layer’s forward: (slice,) ⇒ (full,). Free the gathered tensor after the matmul to bound memory.
  2. Reduce-scatter at backward: each shard’s grad is summed (across devices) AND scattered, so each device keeps only its slice.
  3. No optimizer-state replication: optimizer state (momentum, Adam moments) is also sharded, same layout as params. Saves 2-3× again over plain data parallelism.
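
The recipe above can be made runnable on a bare weight matrix. What follows is a minimal sketch, not the production implementation: the names local_step, sharded_step, LR and the shapes are assumptions chosen for illustration, a plain SGD rule stands in for the NNX optimizer, and the mesh uses however many devices JAX can see (on a single device the collectives degenerate to no-ops).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map

NUM_SHARDS = jax.device_count()
LR = 0.1  # illustrative learning rate
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

def local_step(w_local, x_local, y_local):
    # 1. All-gather: rebuild the full kernel from the per-device row slices.
    w_full = jax.lax.all_gather(w_local, "fsdp", axis=0, tiled=True)

    def loss_fn(w):
        return jnp.mean((x_local @ w - y_local) ** 2)

    loss, g_full = jax.value_and_grad(loss_fn)(w_full)
    # 2. Reduce-scatter: sum grads across devices, keep only this device's rows.
    #    Dividing by the shard count turns the sum into a mean over the global batch.
    g_local = jax.lax.psum_scatter(
        g_full, "fsdp", scatter_dimension=0, tiled=True
    ) / NUM_SHARDS
    # 3. Local update: each device touches only its own slice.
    return w_local - LR * g_local, jax.lax.pmean(loss, "fsdp")

sharded_step = jax.jit(shard_map(
    local_step,
    mesh=mesh,
    in_specs=(P("fsdp", None), P("fsdp", None), P("fsdp", None)),
    out_specs=(P("fsdp", None), P()),
))

# Shapes are multiples of the shard count so every slice has the same size.
d_in, d_out, batch = 4 * NUM_SHARDS, 3, 2 * NUM_SHARDS
w = jax.random.normal(jax.random.PRNGKey(0), (d_in, d_out))
x = jax.random.normal(jax.random.PRNGKey(1), (batch, d_in))
y = jax.random.normal(jax.random.PRNGKey(2), (batch, d_out))

new_w, loss = sharded_step(w, x, y)
print(new_w.shape, float(loss))
```

Both the kernel rows and the batch are split along the same "fsdp" axis here, which is what makes this data parallelism with sharded parameters rather than tensor parallelism.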

Why we don’t simulate every collective

The all-gather + reduce-scatter dance is purely a memory-shuffling optimization. The observable behavior of an FSDP step (the update applied to the params) is identical to that of an unsharded data-parallel step on the same global batch. That’s the whole correctness invariant: FSDP doesn’t change what you compute, only where each piece lives at each instant.

So the single-device simulation just runs ONE normal training step:

model = nnx.Linear(in_dim, out_dim, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)

This produces the same loss any FSDP-instrumented run would compute on this (x, y) pair. The interesting properties of FSDP — peak memory bound, communication overlap, optimizer-state sharding — are runtime ones that don’t show up in the numerics.
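
The invariant can also be checked numerically on one device by emulating the shards with array slicing. This is a sketch under stated assumptions: a bias-free linear model with a bare weight matrix, plain SGD, and equal-sized batch slices per emulated device. The function names are invented for the illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def unsharded_step(w, x, y, lr):
    """One ordinary SGD step on the full kernel and the full batch."""
    return w - lr * jax.grad(loss_fn)(w, x, y)

def emulated_fsdp_step(w, x, y, lr, num_shards):
    """Same step, with the kernel and batch split into num_shards pieces."""
    w_shards = jnp.split(w, num_shards, axis=0)   # what each "device" stores
    x_parts = jnp.split(x, num_shards, axis=0)    # each device's batch slice
    y_parts = jnp.split(y, num_shards, axis=0)

    # All-gather: every device rebuilds the full kernel.
    w_full = jnp.concatenate(w_shards, axis=0)

    # Each device computes a full-size grad from its own batch slice.
    grads = [jax.grad(loss_fn)(w_full, xp, yp) for xp, yp in zip(x_parts, y_parts)]

    # Reduce-scatter: average across devices, then each keeps only its rows.
    g_mean = sum(grads) / num_shards
    g_shards = jnp.split(g_mean, num_shards, axis=0)

    # Local update on each slice; concatenate only to compare with the baseline.
    new_shards = [ws - lr * gs for ws, gs in zip(w_shards, g_shards)]
    return jnp.concatenate(new_shards, axis=0)

w = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 8))
y = jax.random.normal(jax.random.PRNGKey(2), (4, 3))

same = jnp.allclose(
    unsharded_step(w, x, y, 0.1),
    emulated_fsdp_step(w, x, y, 0.1, num_shards=4),
    atol=1e-5,
)
print(bool(same))
```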

Memory math

For an nnx.Linear(in_dim, out_dim) with total = (in_dim+1)*out_dim params (kernel plus bias), FSDP across num_shards devices reduces per-device param memory to total / num_shards. Optimizer state is sharded the same way: plain SGD keeps none, SGD with momentum keeps one buffer the size of the params, and Adam keeps two (the m and v moments), i.e. 2× the params.

The peak memory during forward includes one gathered layer’s full params (since you all-gather just before the matmul). For transformer blocks, the dominant term is the largest single layer, not the sum.
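
As a rough worked example of this accounting, the sketch below counts resident (sharded) bytes and adds one fully gathered layer for the forward peak. It is an upper-bound estimate under simplifying assumptions: fp32 values, no gradients or activations, and opt_slots chosen as the number of param-sized optimizer buffers (0 for plain SGD, 1 for momentum, 2 for Adam). The helper names are hypothetical.

```python
def linear_param_count(in_dim, out_dim):
    """Kernel (in_dim * out_dim) plus bias (out_dim)."""
    return (in_dim + 1) * out_dim

def per_device_bytes(layer_param_counts, num_shards, bytes_per_param=4, opt_slots=0):
    """Return (resident_bytes, peak_forward_bytes) for one device."""
    total = sum(layer_param_counts)
    # Params and optimizer buffers are all sharded across num_shards devices.
    resident = total * (1 + opt_slots) * bytes_per_param / num_shards
    # During forward, one layer at a time is gathered in full.
    gathered = max(layer_param_counts) * bytes_per_param
    return resident, resident + gathered

layers = [linear_param_count(1024, 1024)] * 4   # four identical Linear layers
resident, peak = per_device_bytes(layers, num_shards=8, opt_slots=2)  # Adam
print(f"resident: {resident / 1e6:.2f} MB, forward peak: {peak / 1e6:.2f} MB")
```

Adding more layers raises the resident term but not the gathered term, which is why the largest single layer dominates the peak.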

Why FSDP is the default for big models

  • Single-knob: just shard everything.
  • No need to manually split kernels (vs tensor parallelism).
  • Works with arbitrary model code, no rewrite.
  • Throughput is comparable to data parallelism if the communication is overlapped with compute — modern frameworks (PyTorch FSDP, JAX shard_map + all_gather) handle this.

The downside: you pay extra communication. Tensor parallelism avoids the parameter all-gather (each device computes locally on its kernel slice) but communicates activations instead, and it constrains your model and parallel-axis sizes more.

Common pitfalls

  • Confusing FSDP with TP: TP shards the kernel for compute (each device computes a slice of the output). FSDP shards the kernel for memory (each device gathers the full kernel for compute, then frees it).
  • Forgetting to shard optimizer state: with Adam, m and v together are twice the size of the params, so leaving them replicated gives up most of the savings. PyTorch’s FullyShardedDataParallel does this automatically; in JAX you wire it up explicitly.
  • Asymmetric slice sizes: the sharded param dim must be divisible by num_shards. Production handles padding inside the all-gather.
  • Building optimizer per step: same caveat as data-parallel — build once, reuse, otherwise momentum/step state resets.
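
For the slice-size pitfall, one common workaround is to zero-pad the sharded axis up to a multiple of num_shards and drop the padding after gathering. The sketch below emulates that on one device; pad_leading_axis is a hypothetical helper, and real frameworks may pad differently (for example on a flattened parameter buffer).

```python
import jax.numpy as jnp

def pad_leading_axis(param, num_shards):
    """Zero-pad axis 0 so its length becomes a multiple of num_shards."""
    pad = (-param.shape[0]) % num_shards
    widths = [(0, pad)] + [(0, 0)] * (param.ndim - 1)
    return jnp.pad(param, widths)

kernel = jnp.arange(30.0).reshape(10, 3)          # 10 rows do not split into 4 equal shards
padded = pad_leading_axis(kernel, num_shards=4)   # now 12 rows
shards = jnp.split(padded, 4, axis=0)             # four equal slices, one per device

# Emulated all-gather, followed by trimming the padding rows.
restored = jnp.concatenate(shards, axis=0)[: kernel.shape[0]]
print(padded.shape, shards[0].shape, restored.shape)
```

The padded rows carry no model weights, so any gradient or update computed for them has to be ignored or masked.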

Problem

Implement fsdp_step(seed, x, y, num_shards, lr):

  1. Cast seed, num_shards to int; lr to float.
  2. Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(seed)).
  3. Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  4. Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
  5. loss, grads = nnx.value_and_grad(loss_fn)(model, x, y).
  6. optimizer.update(model, grads).
  7. Return jnp.array([float(loss)]).

The num_shards argument is informational only in the simulation — it documents the layout the production system would use; the numerics don’t depend on it.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D_in).
  • y: 2-D (N, D_out).
  • num_shards: float (informational).
  • lr: float.

Output: 1-D array of shape (1,) holding [loss] (the pre-update loss for the step).
