NNX Orbax Sharded Save
Why this matters
A 70B-parameter LLM doesn’t just need sharded compute; it needs
a sharded checkpoint. You can’t gather() everything to a
single host (it doesn’t fit in host memory either!), and even if
you could, serializing 140 GB (70B parameters at 2 bytes each in
bf16) through one process and one filesystem makes for a
checkpoint that can take on the order of 30 minutes.
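The fix is distributed checkpointing: every host concurrently
writes the shards that live on its own devices. With P
parameters, total wall time ideally scales as O(P / num_shards)
instead of O(P), because no single process sits on the data
path. The remaining coordination (writing metadata, committing
the checkpoint once every host has finished) is small.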
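A back-of-envelope sketch of the numbers above. It assumes bf16 weights (2 bytes per parameter) and a hypothetical per-writer throughput of 78 MB/s, picked only because it reproduces the ~30-minute single-writer figure. It also assumes aggregate storage bandwidth grows with the number of writers, which is the idealized case behind O(P / num_shards).

```python
BYTES_PER_PARAM = 2              # bf16 weights
NUM_PARAMS = 70_000_000_000      # 70B-parameter model
WRITER_BANDWIDTH = 78e6          # ASSUMED bytes/s sustained by ONE writer

def checkpoint_seconds(num_params, num_writers,
                       bytes_per_param=BYTES_PER_PARAM,
                       bandwidth=WRITER_BANDWIDTH):
    """Idealized wall time: every writer streams an equal share in parallel
    and the storage system never becomes the bottleneck."""
    bytes_per_writer = num_params * bytes_per_param / num_writers
    return bytes_per_writer / bandwidth

total_bytes = NUM_PARAMS * BYTES_PER_PARAM
print(f"checkpoint size: {total_bytes / 1e9:.0f} GB")
for writers in (1, 8, 64):
    minutes = checkpoint_seconds(NUM_PARAMS, writers) / 60
    print(f"{writers:>3} writer(s): {minutes:5.1f} min")
```

Real clusters rarely reach perfect linear scaling, because shared storage saturates. Treat the multi-writer rows as an upper bound on the speedup.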
Orbax, Google’s checkpointing library for JAX and the one the Flax docs recommend, implements this out of the box. NNX modules plug into it once they are split into a pytree of arrays.
The Orbax + NNX recipe (production)
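import jax
import orbax.checkpoint as ocp
from flax import nnx
# Assumes `model` is an existing nnx.Module whose arrays are already sharded.
# 1. Split the model so we have an array-side pytree.
graphdef, state = nnx.split(model)
# 2. Build the per-leaf SaveArgs. Sharding comes from each array's own
#    .sharding attribute, so default SaveArgs() are enough here.
save_args = jax.tree_util.tree_map(lambda arr: ocp.SaveArgs(), state)
# 3. Save. Every process runs this call; each one writes only the
#    shards that live on its own devices.
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(
    "/tmp/dist_ckpt",
    args=ocp.args.PyTreeSave(state, save_args=save_args),
)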
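Conceptually, the on-disk layout looks like the tree below. Real file names depend on the Orbax version and storage format (for example, TensorStore Zarr chunks or an OCDBT database):
/tmp/dist_ckpt/
  _METADATA            <- tree structure + dtype + sharding
  kernel/
    shard_0.array      <- device 0's slice of kernel
    shard_1.array      <- device 1's slice
    ...
    shard_{N-1}.array
  bias/
    shard_0.array
    ...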
Each shard file holds about leaf_size / num_shards elements,
which is the slice its device held in memory. No cross-host
gather is needed: each host writes what lives on its own
devices, in parallel.
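The write pattern can be simulated with the standard library. This is a toy stand-in, not Orbax. Threads play the role of hosts, and the hypothetical helpers `shard_bounds`, `save_leaf` and `load_leaf` split one flat leaf into contiguous slices. Each writer then dumps its own slice to its own `shard_{i}.array` file as raw float32, and no writer ever touches another's data.

```python
import os
import tempfile
from array import array
from concurrent.futures import ThreadPoolExecutor

def shard_bounds(n, num_shards):
    """Contiguous [start, stop) ranges over n elements; when n is not
    divisible, the first n % num_shards shards get one extra element."""
    base, extra = divmod(n, num_shards)
    bounds, start = [], 0
    for i in range(num_shards):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds

def write_shard(directory, leaf, index, values):
    # Each writer opens only its own file: no gather, no shared handle.
    path = os.path.join(directory, leaf, f"shard_{index}.array")
    with open(path, "wb") as f:
        array("f", values).tofile(f)  # raw float32 values
    return path

def save_leaf(directory, leaf, flat_values, num_shards):
    os.makedirs(os.path.join(directory, leaf), exist_ok=True)
    bounds = shard_bounds(len(flat_values), num_shards)
    # Threads stand in for hosts. This demo holds the whole leaf in one
    # process for convenience; on a real mesh each host already owns its slice.
    with ThreadPoolExecutor(max_workers=num_shards) as pool:
        futures = [pool.submit(write_shard, directory, leaf, i, flat_values[a:b])
                   for i, (a, b) in enumerate(bounds)]
        return [fut.result() for fut in futures]

def load_leaf(paths):
    # Concatenate shards in index order to rebuild the full leaf.
    out = array("f")
    for path in paths:
        with open(path, "rb") as f:
            out.frombytes(f.read())
    return list(out)

kernel = [float(i) for i in range(10)]
with tempfile.TemporaryDirectory() as ckpt_dir:
    paths = save_leaf(ckpt_dir, "kernel", kernel, num_shards=4)
    itemsize = array("f").itemsize
    sizes = [os.path.getsize(p) // itemsize for p in paths]
    restored = load_leaf(paths)
print(sizes)  # [3, 3, 2, 2]
```

With 10 elements over 4 shards, the files hold 3, 3, 2 and 2 values. Each file's size is roughly leaf_size / num_shards, and concatenating the shards in order recovers the original leaf.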
Why even sharding matters
- I/O parallelism: each shard is written to a separate file simultaneously, so checkpoint time ideally scales as O(P / N_shards) instead of O(P).
- Memory ceiling: no single host needs to hold the full model.
- Resilience: if one shard’s file is lost, only that shard has to be rewritten from the still-running distributed model.
Restore semantics
On restore, you provide a target sharding tree so Orbax can redistribute shards onto the (possibly different) destination mesh:
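# target_sharding_tree mirrors `state`, holding a jax.sharding.Sharding
# (e.g. a NamedSharding on the destination mesh) at each leaf.
restore_args = jax.tree_util.tree_map(
    lambda sharding: ocp.ArrayRestoreArgs(sharding=sharding),
    target_sharding_tree,
)
# abstract_state has the same structure as `state` (e.g. from
# nnx.split(nnx.eval_shape(...))). Passing it as `item` keeps the
# nnx.State structure that nnx.merge expects.
state_loaded = ckptr.restore(
    "/tmp/dist_ckpt",
    args=ocp.args.PyTreeRestore(item=abstract_state, restore_args=restore_args),
)
model = nnx.merge(graphdef, state_loaded)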
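Common case: save on 8 hosts, restore on 16. When both meshes split the same axis evenly, each new shard is exactly half of a saved shard, so each destination device only needs data from one saved shard. Every leaf’s global shape and dtype are recorded in the on-disk metadata, so the restore can work out which slice of the saved data each destination shard needs.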
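The bookkeeping behind that re-sharding can be sketched in pure Python. The hypothetical `read_plan` helper below is not Orbax’s API; Orbax and TensorStore do the equivalent internally with real chunk layouts. It handles one axis with even splits. For each destination shard, it lists which saved shards overlap it and which global index range to read from each.

```python
def read_plan(n, old_shards, new_shards):
    """For each new shard, return the (old_shard, start, stop) pieces it must
    read, with start/stop as global element indices along one axis.
    Assumes both layouts split an axis of length n evenly."""
    if n % old_shards or n % new_shards:
        raise ValueError("this sketch only handles even splits")
    old_size, new_size = n // old_shards, n // new_shards
    plan = []
    for j in range(new_shards):
        lo, hi = j * new_size, (j + 1) * new_size
        pieces = []
        for i in range(old_shards):
            start, stop = max(lo, i * old_size), min(hi, (i + 1) * old_size)
            if start < stop:  # saved shard i overlaps new shard j
                pieces.append((i, start, stop))
        plan.append(pieces)
    return plan

# 8 -> 16: every new shard reads half of exactly one saved shard.
print(read_plan(32, 8, 16)[:3])   # [[(0, 0, 2)], [(0, 2, 4)], [(1, 4, 6)]]
# 8 -> 6: a new shard can straddle two saved shards.
print(read_plan(24, 8, 6)[0])     # [(0, 0, 3), (1, 3, 4)]
```

The 8 → 6 case shows why the reader needs the global shape from the metadata. A destination shard may need pieces from more than one saved shard.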
Why we don’t actually call Orbax
Distributed Orbax requires:
- A real multi-host JAX setup (processes connected via jax.distributed.initialize).
- A real filesystem with concurrent-write semantics.
- Coordinated processes.
None of these exist in the sandbox. Instead, we compute the checkpoint accounting: total parameter count plus the per-shard size, which is what an engineer planning a sharded save needs to know up front (file count, peak per-host I/O, recovery costs).
Memory math for nnx.Linear
For an nnx.Linear(in_dim, out_dim):
- kernel.size = in_dim * out_dim
- bias.size = out_dim
- total = (in_dim + 1) * out_dim
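Per shard: total / num_shards. We use float division so the
user sees the exact per-shard size, even when the split is
uneven. This treats every parameter as spread evenly over all
shards. In practice each array is partitioned along particular
mesh axes, and small leaves such as the bias are often
replicated. (In real Orbax, uneven shards are handled with
padding under the hood.)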
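The accounting above fits in a few lines of plain Python. The helper name `linear_checkpoint_accounting` and the `bytes_per_param` parameter are illustrative additions, not part of the problem spec. It assumes a bias is present and uses the same idealized even split, with float32 (4 bytes) as the default dtype.

```python
def linear_checkpoint_accounting(in_dim, out_dim, num_shards, bytes_per_param=4):
    """Checkpoint accounting for an nnx.Linear with a bias.

    Returns (total_params, per_shard_params, per_shard_bytes), assuming all
    parameters are spread evenly over num_shards. bytes_per_param=4 assumes
    float32 weights; pass 2 for bf16.
    """
    kernel_size = in_dim * out_dim
    bias_size = out_dim
    total = kernel_size + bias_size          # == (in_dim + 1) * out_dim
    per_shard = total / num_shards           # float division, as above
    return total, per_shard, per_shard * bytes_per_param

print(linear_checkpoint_accounting(3, 5, 4))   # (20, 5.0, 20.0)
print(linear_checkpoint_accounting(2, 3, 4))   # (9, 2.25, 9.0)
```

The second call shows an uneven split. Nine parameters over four shards gives a fractional 2.25 per shard, which is a planning figure rather than a physical file size.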
Common pitfalls
- Calling ckptr.save before nnx.split: Orbax operates on pytrees of arrays, not Modules. Always split first.
- Ignoring save_args: the defaults are usually fine, but for mixed precision (e.g. saving as bf16) or compression you need explicit args.
- Restoring with the wrong sharding template: Orbax can re-shard, but only when the target spec is correct. A mismatch with the actual Mesh shape is a runtime error.
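One way to catch a bad template before a long restore is a cheap pre-flight check per leaf. The sketch below is hypothetical: `check_target_spec` is not an Orbax or JAX function. It models a simple PartitionSpec as one mesh-axis name (or None) per dimension and the mesh as a name-to-size mapping, similar to what `Mesh.shape` reports. It flags unknown axis names. It also conservatively flags uneven splits, although whether an uneven split is actually rejected depends on the JAX API involved.

```python
def check_target_spec(shape, spec, mesh_shape):
    """Hypothetical pre-flight check for one leaf of a restore template.

    shape: global array shape, e.g. (1024, 4096).
    spec: one entry per leading dim, either a mesh-axis name or None
        (replicated), mimicking a simple jax.sharding.PartitionSpec.
    mesh_shape: mapping of mesh axis name -> size, like Mesh.shape.
    Returns a list of problems; an empty list means no mismatch was found.
    """
    problems = []
    if len(spec) > len(shape):
        problems.append(f"spec has {len(spec)} entries for a {len(shape)}-D array")
    for dim, (size, axis) in enumerate(zip(shape, spec)):
        if axis is None:
            continue
        if axis not in mesh_shape:
            problems.append(f"dim {dim}: mesh has no axis {axis!r}")
        elif size % mesh_shape[axis]:
            # Conservative: flag uneven splits as well.
            problems.append(
                f"dim {dim}: size {size} not divisible by {axis}={mesh_shape[axis]}")
    return problems

mesh_shape = {"data": 2, "model": 4}
print(check_target_spec((1024, 4096), ("data", "model"), mesh_shape))   # []
print(check_target_spec((1024, 4096), ("data", "tensor"), mesh_shape))
print(check_target_spec((10,), ("model",), mesh_shape))
```

Real PartitionSpecs can also map one dimension to a tuple of mesh axes. This sketch deliberately ignores that case.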
Problem
Implement orbax_save_count(seed, x, features, num_shards):
- Cast seed, features, num_shards to ints. Use x.shape[-1] as the input dim.
- Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)).
- total = kernel.size + bias.size. For nnx.Linear, that’s (in_dim + 1) * features.
- per_shard = total / num_shards (float division).
- Return jnp.array([float(total), per_shard]), a 1-D array of shape (2,).
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, in_dim).
- features: float (cast to int).
- num_shards: float (cast to int).
Output: 1-D array of shape (2,): [total_params, per_shard_size].