
NNX Orbax Sharded Load

Why this matters

Saving a sharded checkpoint is half the job. Loading it back requires reconstructing the sharded layout — the on-disk shards have to land on the right devices, in the right slices, before the next training step can run.

Get this wrong and one of two things happens:

  • Each device loads the full param into HBM (wasting memory; you OOM on big models).
  • Each device loads the wrong slice of each param, silently corrupting your model.

The fix is to give Orbax a target sharding template: a pytree with one sharding per leaf (typically a NamedSharding built from a Mesh and a PartitionSpec) describing where that leaf should live on the new mesh. Orbax then reads the on-disk shards, re-shards as needed, and places each slice directly on its target device.
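Below is a minimal sketch of such a template. It uses only JAX and a hypothetical two-leaf `state_template` that mimics a Linear layer's params. The per-leaf rule is an assumption: 2-D kernels are sharded on their output dim, and everything else is replicated. A 1-D bias cannot take a rank-2 spec like P(None, "model"), so the rule has to look at each leaf's rank.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical template: shapes/dtypes of a Linear-like layer's params.
state_template = {
    "kernel": jax.ShapeDtypeStruct((8, 16), np.float32),
    "bias": jax.ShapeDtypeStruct((16,), np.float32),
}

# 1-D mesh over all local devices (a single device on a CPU-only machine).
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

def spec_for(leaf):
    # Assumed rule: shard a 2-D kernel's output dim over "model",
    # replicate everything else (biases, scalars).
    return P(None, "model") if len(leaf.shape) == 2 else P()

# One NamedSharding per leaf: this pytree is the target sharding template.
target_sharding = jax.tree_util.tree_map(
    lambda leaf: NamedSharding(mesh, spec_for(leaf)), state_template
)

# Placing an array with its target sharding: each device holds a column slice.
kernel = jax.device_put(np.ones((8, 16), np.float32), target_sharding["kernel"])
print(kernel.addressable_shards[0].data.shape)  # (8, 16) on a single device
```

On N devices, each shard of `kernel` has shape (8, 16 // N). This assumes 16 is divisible by N.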

The Orbax + NNX restore recipe (production)

import jax
import orbax.checkpoint as ocp
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed to exist: `devices` (e.g. np.array(jax.devices())), plus `graphdef`
# and `state_template` from nnx.split(nnx.eval_shape(...)) of your model.

# 1. Build the target Mesh + sharding tree.
mesh = Mesh(devices, axis_names=("model",))
target_sharding = jax.tree_util.tree_map(
    # Shard 2-D kernels on their output dim; replicate biases and scalars
    # (a rank-2 spec would be invalid for a 1-D bias).
    lambda leaf: NamedSharding(
        mesh, P(None, "model") if len(leaf.shape) == 2 else P()
    ),
    state_template,
)

# 2. Build the per-leaf RestoreArgs with that sharding.
restore_args = jax.tree_util.tree_map(
    lambda sh: ocp.ArrayRestoreArgs(sharding=sh),
    target_sharding,
)

# 3. Restore. Each device reads only its slice; no host gather.
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
state_loaded = ckptr.restore(
    "/tmp/dist_ckpt",
    args=ocp.args.PyTreeRestore(item=state_template, restore_args=restore_args),
)

# 4. Glue back into a Module.
model = nnx.merge(graphdef, state_loaded)

Three notable points:

  • graphdef must come from somewhere. The usual route is to re-create the model architecture in Python and take a fresh graphdef from nnx.split. Wrapping construction in nnx.eval_shape avoids allocating real params. The graph is Python structure, not an array, so it doesn’t need to be sharded.
  • target_sharding can differ from the save-time sharding. Orbax re-shards on the fly. Save sharded over 8 devices, restore over 16, and each new device reads only the half of a saved shard that its slice covers.
  • No host gather. The whole point: on a distributed system, no single host should ever hold the full model. Restore writes directly to the target devices.

Idempotence: the test trick

On a single device, we can’t actually run Orbax’s distributed restore. But we can test the idempotence property that any correct save/load round-trip must satisfy:

save(model_seed_K) ; load(...) → model with the SAME params

Since nnx.Linear’s init is deterministic given the seed, we simulate “load” by building a fresh nnx.Linear with the same seed. On the same backend this produces byte-identical params to the “saved” model, which is exactly what a correct Orbax restore must also produce.

import jax.numpy as jnp
from flax import nnx

in_dim, features, seed = 3, 4, 0  # example sizes
saver = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed))   # 'save'
loader = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed))  # 'load'
assert jnp.array_equal(saver.kernel.value, loader.kernel.value)  # idempotent

The function returns loader(x), the forward pass through the “loaded” model. The test checks that this matches the saver’s forward pass, a property any real round-trip must also satisfy.

Why we run forward, not just compare params

Comparing params alone checks the arrays, but not whether they were reassembled into a usable module. A forward pass also exercises nnx.merge and the call path, and any param mismatch that reaches the output shows up as a different result. Use a non-trivial input, because a zero input column can hide a bad weight. By requiring the function to reconstruct and use the model, we test the full “restore + use” round-trip you’d run in production.

What to do when seeds aren’t enough

In production, your params come from a training run, not from init. The save/load round-trip’s correctness depends entirely on Orbax’s serialization fidelity — the seed-trick above is just our sandbox approximation.

For real validation, save a model, restore it, and compare a forward pass on a fixed input. If they don’t match bit-for-bit, something in your save/restore pipeline is wrong (most commonly: dtype mismatch, missing leaf, or a sharding spec that doesn’t match the actual mesh).
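When the forward pass does not match, a structural diff usually points at the culprit faster than staring at numbers. The sketch below is a hypothetical `diagnose_roundtrip` helper for plain param pytrees. It reports missing or extra leaves, shape and dtype changes, and only then value mismatches.

```python
import jax
import jax.numpy as jnp

def diagnose_roundtrip(saved, restored):
    """Hypothetical helper: list differences between two param pytrees."""
    def flatten(tree):
        # Map a readable key path (e.g. "['kernel']") to each leaf.
        return {
            jax.tree_util.keystr(path): leaf
            for path, leaf in jax.tree_util.tree_leaves_with_path(tree)
        }

    a_leaves, b_leaves = flatten(saved), flatten(restored)
    problems = []
    for name in sorted(a_leaves.keys() - b_leaves.keys()):
        problems.append(f"missing leaf {name}")
    for name in sorted(b_leaves.keys() - a_leaves.keys()):
        problems.append(f"unexpected leaf {name}")
    for name in sorted(a_leaves.keys() & b_leaves.keys()):
        a, b = a_leaves[name], b_leaves[name]
        if a.shape != b.shape:
            problems.append(f"shape mismatch {name}: {a.shape} vs {b.shape}")
        elif a.dtype != b.dtype:
            problems.append(f"dtype mismatch {name}: {a.dtype} vs {b.dtype}")
        elif not bool(jnp.array_equal(a, b)):
            problems.append(f"value mismatch {name}")
    return problems

saved = {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}
print(diagnose_roundtrip(saved, saved))                   # []
print(diagnose_roundtrip(saved, {"kernel": jnp.ones((3, 4))}))
```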

Common pitfalls

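  • Forgetting restore_args: Orbax has no target layout for the current mesh. Depending on the version, it either falls back to the save-time sharding recorded in the checkpoint (which may not fit the new mesh) or materializes unsharded arrays. On a 70B model, a full copy per device or host is an OOM.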
  • Mismatched graphdef: if the save-time model and the restore-time model have different architectures (different attribute names, layer counts, etc.), nnx.merge raises.
  • Hard-coding the save mesh: production code should let the target mesh come from the current job, not from the save-time job. Orbax handles re-sharding; trust it.
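One cheap guard against the last pitfall is to check, before restoring, that every PartitionSpec fits the mesh the current job actually built. `spec_fits_mesh` below is a hypothetical helper, not an Orbax or JAX API. It rejects specs that are longer than the array's rank, that name mesh axes the current mesh lacks, or that split a dimension unevenly.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

def spec_fits_mesh(shape, spec, mesh):
    """Hypothetical pre-restore check that `spec` is usable on `mesh`."""
    if len(spec) > len(shape):
        return False  # more partitioned dims than the array has
    for dim, axes in zip(shape, spec):
        if axes is None:
            continue  # replicated dimension
        axes = axes if isinstance(axes, tuple) else (axes,)
        n_shards = 1
        for name in axes:
            if name not in mesh.shape:  # mesh.shape: axis name -> size
                return False
            n_shards *= mesh.shape[name]
        if dim % n_shards != 0:
            return False  # uneven split across devices
    return True

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
n = mesh.shape["model"]
print(spec_fits_mesh((8, 16 * n), P(None, "model"), mesh))  # True
print(spec_fits_mesh((8, 16), P(None, "data"), mesh))       # False: no "data" axis
```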

Problem

Implement orbax_load_simulate(seed, x, features):

  1. Cast seed, features to ints. Use x.shape[-1] as the input dim.
  2. Build loader = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)). This stands in for the nnx.merge(graphdef, restored_state) step: same seed ⇒ same params ⇒ idempotent restore.
  3. Run out = loader(x).
  4. Return out.reshape(-1) flattened.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • features: float (cast to int).

Output: 1-D, length N * features.

