NNX Orbax Sharded Load
Why this matters
Saving a sharded checkpoint is half the job. Loading it back requires reconstructing the sharded layout — the on-disk shards have to land on the right devices, in the right slices, before the next training step can run.
Get this wrong and one of two things happens:
- Each device loads the full param into HBM (wasting memory; you OOM on big models).
- Each device loads the wrong slice of each param, silently corrupting your model.
The fix is to provide Orbax with a target sharding template:
a pytree of PartitionSpecs describing where each leaf should
live in the new mesh. Orbax then re-streams the on-disk shards,
re-shards as needed, and places each slice directly on the
target device.
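As a rough illustration of what such a template looks like, the sketch below builds a pytree of PartitionSpecs, converts it to NamedShardings, and places a small parameter tree with jax.device_put. It uses plain JAX rather than Orbax; the params dict and its shapes are made up for the example, and the mesh simply spans whatever devices the current process can see.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named "model" spanning every device visible to this process.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
n_dev = mesh.shape["model"]

# Hypothetical parameter tree; the feature dim is a multiple of the device
# count so that it divides evenly across the "model" axis.
params = {
    "kernel": jnp.ones((4, 2 * n_dev)),
    "bias": jnp.zeros((2 * n_dev,)),
}

# The template: one PartitionSpec per leaf, same tree structure as params.
spec_tree = {
    "kernel": P(None, "model"),  # shard the output-feature axis
    "bias": P("model"),          # 1-D leaf, so a 1-entry spec
}

# Turn each spec into a concrete NamedSharding on the current mesh.
sharding_tree = jax.tree_util.tree_map(
    lambda spec: NamedSharding(mesh, spec),
    spec_tree,
    is_leaf=lambda node: isinstance(node, P),
)

# Place every leaf according to its sharding; each device holds one slice.
placed = jax.device_put(params, sharding_tree)
print(placed["kernel"].addressable_shards[0].data.shape)
```

On a single device the mesh has size 1 and every "shard" is the whole array; on more devices each device holds a (4, 2) slice of the kernel.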
The Orbax + NNX restore recipe (production)
import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# graphdef and state_template are assumed to come from nnx.split(...) on a
# model built with the restore-time architecture.

# 1. Build the target Mesh + sharding tree from the current job's devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
target_sharding = jax.tree_util.tree_map(
    # Shard 2-D kernels along their last axis; replicate everything else.
    # P(None, "model") has two entries, so it cannot apply to a 1-D bias.
    lambda leaf: NamedSharding(mesh, P(None, "model") if leaf.ndim == 2 else P()),
    state_template,
)

# 2. Build the per-leaf RestoreArgs with that sharding.
restore_args = jax.tree_util.tree_map(
    lambda sh: ocp.ArrayRestoreArgs(sharding=sh),
    target_sharding,
)

# 3. Restore. Each device receives only its slice; no host gather.
handler = ocp.PyTreeCheckpointHandler()
ckptr = ocp.Checkpointer(handler)
state_loaded = ckptr.restore(
    "/tmp/dist_ckpt", restore_args=restore_args,
)

# 4. Glue back into a Module.
model = nnx.merge(graphdef, state_loaded)
Three notable points:
- graphdef must come from somewhere. Either you saved it alongside the checkpoint (Orbax can do this) or you re-create the model architecture in Python and pull out a fresh graphdef. The graph is Python structure, not an array, so it doesn’t need to be sharded.
- target_sharding can differ from the save-time sharding. Orbax re-shards on the fly. Save on 8 hosts, restore on 16, and Orbax splits each saved shard into 2.
- No host gather. The whole point: on a distributed system, no single host should ever hold the full model. Restore writes directly to the target devices.
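The 8-to-16 case can be worked through with plain index arithmetic. The sketch below is not Orbax's implementation; it only assumes an even split along one axis (the helpers shard_bounds and sources_for are hypothetical names) and shows which saved shard each target slice would have to read from.

```python
def shard_bounds(dim_size, num_shards):
    """Return (start, stop) index pairs for an even split of one axis."""
    assert dim_size % num_shards == 0, "axis must divide evenly"
    width = dim_size // num_shards
    return [(i * width, (i + 1) * width) for i in range(num_shards)]


def sources_for(target_slice, saved_slices):
    """Indices of the saved shards that overlap the requested target slice."""
    t0, t1 = target_slice
    return [i for i, (s0, s1) in enumerate(saved_slices) if s0 < t1 and t0 < s1]


dim = 64
saved = shard_bounds(dim, 8)    # layout at save time: 8 shards of width 8
target = shard_bounds(dim, 16)  # layout at restore time: 16 shards of width 4

# For every target slice, list the saved shards it needs data from.
source_map = [sources_for(t, saved) for t in target]
for i, srcs in enumerate(source_map):
    print(f"target shard {i:2d} {target[i]} <- saved shard(s) {srcs}")
```

Under these assumptions every target slice lies inside exactly one saved shard, and each saved shard feeds two consecutive target slices, which is the "split into 2" behaviour described above.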
Idempotence: the test trick
On a single device, we can’t actually run Orbax’s distributed restore. But we can test the idempotence property that any correct save/load round-trip must satisfy:
save(model_seed_K) ; load(...) → model with the SAME params
Since nnx.Linear’s init is deterministic given the seed,
we simulate “load” by building a fresh nnx.Linear with the
same seed. This produces byte-identical params to the
“saved” model, just as a real Orbax restore would.
saver = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)) # 'save'
loader = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)) # 'load'
assert jnp.allclose(saver.kernel.value, loader.kernel.value) # idempotent
The function returns loader(x) — the forward pass through the
“loaded” model. The test verifies this matches the saver’s
forward, which is what a real round-trip would also satisfy.
Why we run forward, not just compare params
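Comparing raw parameter values on their own is fragile (atol/rtol choices differ across backends, for example) and never exercises the restored model. A forward pass is a stronger check: parameter mismatches generally show up as different outputs, and it confirms the restored parameters are wired into a usable module. By requiring the function to reconstruct and use the model, we test the full “restore + use” round-trip you’d run in production.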
What to do when seeds aren’t enough
In production, your params come from a training run, not from init. The save/load round-trip’s correctness depends entirely on Orbax’s serialization fidelity — the seed-trick above is just our sandbox approximation.
For real validation, save a model, restore it, and compare a forward pass on a fixed input. If they don’t match bit-for-bit, something in your save/restore pipeline is wrong (most commonly: dtype mismatch, missing leaf, or a sharding spec that doesn’t match the actual mesh).
Common pitfalls
- Forgetting restore_args: Orbax then loads everything to the calling host (replicated). On a 70B model that’s an OOM.
- Mismatched graphdef: if the save-time model and the restore-time model have different architectures (different attribute names, layer counts, etc.), nnx.merge raises.
- Hard-coding the save mesh: production code should let the target mesh come from the current job, not from the save-time job. Orbax handles re-sharding; trust it.
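One way to catch an architecture or dtype mismatch before calling nnx.merge is to diff the restored pytree against the template leaf by leaf. The helper below is a hypothetical sketch in plain JAX (describe_mismatches and leaf_table are not Orbax or Flax APIs), and the two dicts stand in for real state trees.

```python
import jax
import jax.numpy as jnp


def leaf_table(tree):
    """Map each leaf's path string to its (shape, dtype)."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return {
        jax.tree_util.keystr(path): (leaf.shape, leaf.dtype)
        for path, leaf in leaves
    }


def describe_mismatches(template, restored):
    """Return human-readable differences between two pytrees of arrays."""
    want, got = leaf_table(template), leaf_table(restored)
    problems = []
    for name in sorted(want.keys() - got.keys()):
        problems.append(f"missing leaf: {name}")
    for name in sorted(got.keys() - want.keys()):
        problems.append(f"unexpected leaf: {name}")
    for name in sorted(want.keys() & got.keys()):
        if want[name] != got[name]:
            problems.append(f"{name}: expected {want[name]}, got {got[name]}")
    return problems


# Stand-in trees: the restored one lacks the bias and has the wrong dtype.
template = {
    "kernel": jnp.zeros((4, 3), jnp.float32),
    "bias": jnp.zeros((3,), jnp.float32),
}
restored = {"kernel": jnp.zeros((4, 3), jnp.bfloat16)}

problems = describe_mismatches(template, restored)
for line in problems:
    print(line)
```

Running a check like this right after restore turns a confusing merge error, or a silent dtype change, into an explicit list of offending leaves.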
Problem
Implement orbax_load_simulate(seed, x, features):
- Cast seed, features to ints. Use x.shape[-1] as the input dim.
- Build a loader = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)). This stands in for the nnx.merge(graphdef, restored_state) step. Same seed ⇒ same params ⇒ idempotent restore.
- Run out = loader(x).
- Return out.reshape(-1), the flattened output.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, in_dim).
- features: float (cast to int).
Output: 1-D, length N * features.