NNX FSDP-Style Step
Why this matters
Plain data parallelism replicates the whole model on every device, so 8 devices store 8 full copies of the parameters. For a 70B-parameter LLM in bf16 (2 bytes per parameter), that is 140 GB per device. That is more than a single 80 GB A100 or H100 holds, before leaving any room for activations, gradients, or optimizer state.
Fully Sharded Data Parallel (FSDP), also called ZeRO-3 in DeepSpeed lingo, fixes this: each device holds only a slice of every parameter. The full param is reconstructed on demand by an all-gather just before the layer needs it, then immediately freed after the matmul.
Result: parameter memory per device is total_params / num_shards
instead of total_params. For a 70B model on 8 devices, that’s
~17.5 GB per device — fits on an A100, with room to spare.
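The arithmetic above fits in a small back-of-the-envelope helper. This sketch counts parameter bytes only (no gradients, optimizer state, or activations) and assumes 1 GB = 10^9 bytes. The function name is illustrative.

```python
def per_device_param_gb(num_params: float, bytes_per_param: int, num_shards: int = 1) -> float:
    """Parameter memory per device in GB (1 GB = 1e9 bytes); params only."""
    return num_params * bytes_per_param / num_shards / 1e9


replicated = per_device_param_gb(70e9, 2)             # plain data parallelism, bf16
sharded = per_device_param_gb(70e9, 2, num_shards=8)  # FSDP over 8 devices
print(f"replicated: {replicated:.1f} GB, sharded: {sharded:.1f} GB")
# replicated: 140.0 GB, sharded: 17.5 GB
```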
The FSDP recipe (production)
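```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("fsdp",))
LR = 0.1

# Params as a plain pytree (an NNX model would first be nnx.split into
# graphdef + state). Every param is sharded along its leading axis, so
# in_dim, out_dim and the batch size must divide by the device count.
param_specs = {"kernel": P("fsdp", None), "bias": P("fsdp")}

def loss_fn(params, x, y):
    return jnp.mean((x @ params["kernel"] + params["bias"] - y) ** 2)

@partial(shard_map, mesh=mesh,
         in_specs=(param_specs, P("fsdp", None), P("fsdp", None)),
         out_specs=(param_specs, P()))
def step(local_params, x, y):
    # Inside each device, params and the batch are LOCAL slices.
    # Forward: all_gather to materialize each full param.
    full_params = jax.tree_util.tree_map(
        lambda p: jax.lax.all_gather(p, "fsdp", axis=0, tiled=True), local_params)
    loss, full_grads = jax.value_and_grad(loss_fn)(full_params, x, y)
    # reduce_scatter: sum grads across devices, keep only this shard's slice,
    # then divide by the shard count to average over the global batch.
    local_grads = jax.tree_util.tree_map(
        lambda g: jax.lax.psum_scatter(g, "fsdp", scatter_dimension=0, tiled=True)
        / mesh.shape["fsdp"],
        full_grads)
    # Optimizer update applied LOCALLY: each shard updates only its slice.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, local_params, local_grads)
    # pmean makes the returned loss identical on every device (out_specs=P()).
    return new_params, jax.lax.pmean(loss, "fsdp")
```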
Two collectives and one layout rule:
- All-gather before each layer's forward: (slice,) ⇒ (full,). Free the gathered tensor after the matmul to bound memory. The backward pass all-gathers it again when it needs the full weights.
- Reduce-scatter in the backward pass: each gradient is summed across devices AND scattered, so each device keeps only its slice.
- No optimizer-state replication: optimizer state (momentum, Adam moments) is sharded too, with the same layout as the params. For Adam, m and v together are 2× the parameter size, so this shards the largest block of per-device training state.
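The following single-machine sketch shows what the two collectives do to shapes. It uses jax.vmap with an axis_name as a stand-in for a real device mesh, since JAX supports lax collectives over named vmap axes. It simulates 4 devices and fakes the per-device gradients as scaled copies of the kernel. These are illustrative assumptions, not part of the recipe above.

```python
import jax
import jax.numpy as jnp

NUM_SHARDS = 4
kernel = jnp.arange(8 * 3, dtype=jnp.float32).reshape(8, 3)

# Each simulated device holds a (2, 3) row slice of the (8, 3) kernel.
slices = kernel.reshape(NUM_SHARDS, 2, 3)


def gather(local_slice):
    # (slice,) => (full,): concatenate every device's slice along axis 0.
    return jax.lax.all_gather(local_slice, "fsdp", axis=0, tiled=True)


def reduce_scatter(local_grad):
    # Sum full-size grads across devices, then keep only this device's rows.
    return jax.lax.psum_scatter(local_grad, "fsdp", scatter_dimension=0, tiled=True)


# vmap over a named axis stands in for the device mesh, so this runs on one CPU.
gathered = jax.vmap(gather, axis_name="fsdp")(slices)             # (4, 8, 3)
grads = jnp.stack([kernel * (i + 1) for i in range(NUM_SHARDS)])  # fake per-device grads
grad_slices = jax.vmap(reduce_scatter, axis_name="fsdp")(grads)   # (4, 2, 3)

print(gathered.shape, grad_slices.shape)  # (4, 8, 3) (4, 2, 3)
```

Every simulated device ends up with the full kernel after the all-gather. After the reduce-scatter, it holds only its own 2-row slice of the summed gradient.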
Why we don’t simulate every collective
The all-gather + reduce-scatter dance is purely a memory-shuffling optimization. The observable behavior of an FSDP step, the update applied to the params, is identical to an unsharded data-parallel step on the same global batch, up to floating-point summation order. That is the whole correctness invariant: FSDP doesn't change what you compute, only where each piece lives at each instant.
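The invariant can be checked numerically on one device. This sketch makes two assumptions: the batch splits into equal-sized shards, and a plain params dict stands in for the nnx.Linear. It computes per-shard gradients and averages them, which is what a reduce-scatter sum followed by division by the shard count yields, slice by slice. It then compares the result against the full-batch gradient. The helper name averaged_shard_grads is hypothetical.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Same MSE as the problem's loss_fn, on a plain {"kernel", "bias"} dict.
    pred = x @ params["kernel"] + params["bias"]
    return jnp.mean((pred - y) ** 2)


def averaged_shard_grads(params, x, y, num_shards):
    # Each simulated device sees an equal slice of the batch; averaging its
    # grads matches "sum across devices, then divide by num_shards".
    shard_grads = [
        jax.grad(loss_fn)(params, xs, ys)
        for xs, ys in zip(jnp.split(x, num_shards), jnp.split(y, num_shards))
    ]
    return jax.tree_util.tree_map(lambda *g: sum(g) / num_shards, *shard_grads)


key_w, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"kernel": jax.random.normal(key_w, (5, 3)), "bias": jnp.zeros(3)}
x = jax.random.normal(key_x, (16, 5))
y = jax.random.normal(key_y, (16, 3))

full = jax.grad(loss_fn)(params, x, y)
sharded = averaged_shard_grads(params, x, y, num_shards=4)
for name in ("kernel", "bias"):
    print(name, bool(jnp.allclose(full[name], sharded[name], atol=1e-5)))
```

Each shard's loss is a mean over its own slice. With equal shard sizes, the average of those means is exactly the global-batch mean, which is why the division by num_shards is needed.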
So the single-device simulation just runs ONE normal training step:
```python
import optax
from flax import nnx

# loss_fn is the MSE defined in the Problem section.
model = nnx.Linear(in_dim, out_dim, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)
```
This produces the same loss any FSDP-instrumented run would compute
on this (x, y) pair. The interesting properties of FSDP — peak
memory bound, communication overlap, optimizer-state sharding — are
runtime ones that don’t show up in the numerics.
Memory math
For an nnx.Linear(in_dim, out_dim) with total = (in_dim + 1) * out_dim params (kernel plus bias), FSDP across num_shards devices reduces per-device param memory to total / num_shards. Plain SGD (optax.sgd with no momentum) keeps no per-parameter optimizer state. SGD with momentum keeps one buffer the size of the params, sharded the same way. For Adam, optimizer state is 2× the params (the m and v moments), also sharded.
The peak memory during the forward pass also includes one gathered layer's full params, since each layer is all-gathered just before its matmul. For a stack of transformer blocks, that extra term is set by the largest single layer, not by the sum over all layers.
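The memory math above fits in a short sketch that counts elements rather than bytes. It assumes only one layer is gathered at a time (no prefetching), that the gathered buffer is separate from the local shard, and that gradient and activation buffers are ignored. The function names and the opt_slots parameter are illustrative.

```python
def linear_params(in_dim: int, out_dim: int) -> int:
    # nnx.Linear holds an (in_dim, out_dim) kernel plus an (out_dim,) bias.
    return (in_dim + 1) * out_dim


def fsdp_per_device(layer_sizes, num_shards, opt_slots):
    """Per-device element counts (multiply by bytes per element for bytes).

    opt_slots: optimizer buffers per param (0 = plain SGD, 1 = momentum, 2 = Adam).
    """
    total = sum(layer_sizes)
    resident = total / num_shards * (1 + opt_slots)  # param shard + optimizer-state shard
    peak = resident + max(layer_sizes)               # plus one fully gathered layer
    return resident, peak


resident, peak = fsdp_per_device([linear_params(127, 64)], num_shards=8, opt_slots=2)
print(resident, peak)  # 3072.0 11264.0

# Four equal layers: the peak adds one layer (8192), not all four (32768).
print(fsdp_per_device([8192] * 4, num_shards=8, opt_slots=0))  # (4096.0, 12288.0)
```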
Why FSDP is the default for big models
- Single-knob: just shard everything.
- No need to manually split kernels (vs tensor parallelism).
- Works with arbitrary model code, no rewrite.
- Throughput is comparable to data parallelism if the communication is overlapped with compute. Modern frameworks (PyTorch FSDP, JAX shard_map + all_gather) are built to handle this.
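The downside: you pay extra communication, because every step all-gathers the parameters on top of reducing the gradients. Tensor parallelism avoids the weight all-gather: each device computes locally on its kernel slice and communicates activations instead. The cost is tighter constraints on your model architecture and parallel-axis sizes.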
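A single-device sketch makes the difference concrete. FSDP's row slices must be gathered into the full kernel before the matmul. Column-parallel TP instead multiplies each column slice locally and stitches the output slices together. Here jnp.concatenate stands in for the collectives, and the shapes are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

NUM_SHARDS = 4
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8))
w = jax.random.normal(key_w, (8, 12))
reference = x @ w

# FSDP: each device STORES a row slice but computes with the full kernel.
fsdp_slices = jnp.split(w, NUM_SHARDS, axis=0)           # 4 x (2, 12)
fsdp_full_kernel = jnp.concatenate(fsdp_slices, axis=0)  # stand-in for all_gather
fsdp_out = x @ fsdp_full_kernel                          # full (2, 12) output

# TP (column-parallel): each device COMPUTES only its slice of the output.
tp_slices = jnp.split(w, NUM_SHARDS, axis=1)             # 4 x (8, 3)
tp_out_slices = [x @ w_i for w_i in tp_slices]           # 4 x (2, 3), no weight gather
tp_out = jnp.concatenate(tp_out_slices, axis=1)          # stand-in for gathering activations
```

Both paths reproduce x @ w. FSDP moves weights (2 × 12 per device here), while TP moves output activations (2 × 3 per device here).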
Common pitfalls
- Confusing FSDP with TP: TP shards the kernel for compute (each device computes a slice of the output). FSDP shards the kernel for memory (each device gathers the full kernel for compute, then frees it).
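- Forgetting to shard optimizer state: with Adam, m and v together are twice the size of the params, so they make up about half of the params + grads + optimizer-state footprint you are trying to shard. PyTorch's FullyShardedDataParallel does this automatically. In JAX you wire it up explicitly by giving the optimizer state the same sharding as the params.
- Uneven slice sizes: each sharded param dim must be divisible by num_shards. Production systems handle this by padding inside the all-gather.
- Building the optimizer per step: same caveat as data parallelism. Build it once and reuse it; otherwise the momentum and step-count state resets every step.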
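For the uneven-slice pitfall, here is one simple padding scheme: zero-pad dim 0 up to a multiple of num_shards, shard it, and drop the padding after the gather. Real frameworks may pad differently, for example on flattened parameters. Both helper names are hypothetical, and jnp.concatenate stands in for the all-gather.

```python
import jax.numpy as jnp


def pad_and_shard(param, num_shards):
    """Zero-pad dim 0 up to a multiple of num_shards, then split into equal slices."""
    rows = param.shape[0]
    padded_rows = -(-rows // num_shards) * num_shards  # ceil to a multiple of num_shards
    pad_width = [(0, padded_rows - rows)] + [(0, 0)] * (param.ndim - 1)
    return jnp.split(jnp.pad(param, pad_width), num_shards), rows


def gather_and_unpad(slices, rows):
    # Stand-in for the all-gather: concatenate slices, then drop the padding.
    return jnp.concatenate(slices, axis=0)[:rows]


bias = jnp.arange(10.0)  # 10 elements: not divisible by 4 shards
slices, rows = pad_and_shard(bias, num_shards=4)
print([s.shape for s in slices])  # [(3,), (3,), (3,), (3,)]
restored = gather_and_unpad(slices, rows)
```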
Problem
Implement fsdp_step(seed, x, y, num_shards, lr):
- Cast seed and num_shards to int; cast lr to float.
- Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(seed)).
- Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Define loss_fn(model, x, y) -> jnp.mean((model(x) - y) ** 2).
- Compute loss, grads = nnx.value_and_grad(loss_fn)(model, x, y).
- Call optimizer.update(model, grads).
- Return jnp.array([float(loss)]).
The num_shards argument is informational only in the
simulation — it documents the layout the production system would
use; the numerics don’t depend on it.
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D_in).
- y: 2-D, shape (N, D_out).
- num_shards: float (informational).
- lr: float.
Output: 1-D array of shape (1,) holding [loss], the pre-update loss for the step.