
NNX Orbax Sharded Save

Why this matters

A 70B-parameter LLM doesn’t just need sharded compute; it needs a sharded checkpoint. At 2 bytes per bf16 parameter the weights alone are 140 GB, so gathering everything onto a single host often won’t fit in its memory. Even when it does, serializing 140 GB through one process and one filesystem connection can stretch each checkpoint to tens of minutes.

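The fix is distributed checkpointing: every host writes the shards that live on its own devices, concurrently. Ideally, total wall time then scales as O(P / num_shards) instead of O(P), where P is the number of parameters, and the only coordination is a barrier plus a small metadata write once all hosts are done.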
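The arithmetic behind these numbers, as a back-of-the-envelope sketch. It assumes 2 bytes per parameter (bf16) and a hypothetical sustained write bandwidth of 100 MB/s per writer; real bandwidth depends entirely on the storage backend, and once the shared storage saturates, adding writers stops helping.

```python
# Back-of-the-envelope checkpoint sizing. BYTES_PER_PARAM assumes bf16 weights;
# WRITER_BW is a hypothetical per-writer bandwidth, not a measured number.
PARAMS = 70e9
BYTES_PER_PARAM = 2      # bf16
WRITER_BW = 100e6        # bytes/s per writer (assumption)


def checkpoint_bytes(params: float, bytes_per_param: int) -> float:
    """Total bytes for a checkpoint holding `params` parameters."""
    return params * bytes_per_param


def ideal_write_seconds(total_bytes: float, num_writers: int, bw: float) -> float:
    """Ideal wall time when bytes are split evenly across writers that each
    sustain `bw` bytes/s and the storage never saturates."""
    return total_bytes / num_writers / bw


total = checkpoint_bytes(PARAMS, BYTES_PER_PARAM)
print(f"checkpoint size: {total / 1e9:.0f} GB")
for n in (1, 8, 64):
    minutes = ideal_write_seconds(total, n, WRITER_BW) / 60
    print(f"{n:>3} writer(s): {minutes:6.2f} min, {total / n / 1e9:7.2f} GB each")
```

With the assumed 100 MB/s per writer, a single writer needs roughly 23 minutes; eight writers bring that under 3 minutes, provided the storage can absorb the aggregate rate.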

Orbax, Google’s checkpointing library for JAX, implements this out of the box. It operates on pytrees of jax.Arrays, so it works with an NNX module once the module is split into an nnx.State.

The Orbax + NNX recipe (production)

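import jax
import orbax.checkpoint as ocp
from flax import nnx

# `model` is an existing nnx.Module whose parameters are already sharded jax.Arrays.

# 1. Split the model into a static graphdef and an array-side pytree (nnx.State).
graphdef, state = nnx.split(model)

# 2. Optional per-leaf SaveArgs (e.g. a dtype cast). No sharding goes here:
#    Orbax reads the layout from each jax.Array's .sharding attribute.
save_args = jax.tree_util.tree_map(lambda _: ocp.SaveArgs(), state)

# 3. Save. Every process must make this call; each writes only its own shards.
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
ckptr.save(
    "/tmp/dist_ckpt",
    args=ocp.args.PyTreeSave(state, save_args=save_args),
)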

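Conceptually, each leaf is stored as one global array written in independent pieces. An illustrative layout follows; the real file names depend on the Orbax version and storage format (recent versions default to OCDBT, which packs pieces into a few per-process data files rather than one file per shard):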

/tmp/dist_ckpt/
  _METADATA              <- tree structure + dtype + sharding
  kernel/
    shard_0.array        <- slice 0 of kernel, written by the host that owns it
    shard_1.array        <- slice 1
    ...
    shard_{N-1}.array
  bias/
    shard_0.array
    ...

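Each shard holds roughly leaf_size / num_shards elements: the slice one device held in memory. No cross-host gather is needed; each process writes the shards that live on its own devices, in parallel. When a slice is replicated across several devices, only one copy (the shard with replica id 0) is written.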
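To make the rule that each process writes only the shards on its own devices concrete, here is a framework-free simulation. It models a hypothetical 2×4 mesh (axes "data" and "model") spread over two hosts, with a kernel sharded along "model" only, so every slice is replicated across "data". Only the copy with replica id 0 is written, mirroring how Orbax skips duplicate jax.Array shards by default; the device numbering and the `shard_plan` helper are illustrative assumptions, not Orbax APIs.

```python
from itertools import product

# Hypothetical setup: 8 devices on 2 hosts (4 each), arranged as a mesh with
# DATA=2 rows and MODEL=4 columns. Device id = m * DATA + d, so each host holds
# two complete "model" columns (both data-axis copies of each).
DATA, MODEL = 2, 4
DEVICES_PER_HOST = 4
KERNEL_SHAPE = (1024, 4096)  # (in_dim, out_dim); out_dim is split over "model"


def shard_plan(shape, data, model, devices_per_host):
    """Return {host: [(device, (col_start, col_stop)), ...]} for shards written.

    Device (d, m) holds columns [m * cols, (m + 1) * cols) of the kernel, all
    rows. Devices that differ only in d hold identical copies, so the copy's
    replica id is d and only replica 0 is written.
    """
    cols = shape[1] // model
    plan = {}
    for d, m in product(range(data), range(model)):
        replica_id = d
        if replica_id != 0:
            continue  # duplicate of a slice that replica 0 already writes
        device = m * data + d
        host = device // devices_per_host
        plan.setdefault(host, []).append((device, (m * cols, (m + 1) * cols)))
    return plan


plan = shard_plan(KERNEL_SHAPE, DATA, MODEL, DEVICES_PER_HOST)
for host, writes in sorted(plan.items()):
    print(f"host {host} writes {writes}")
```

Both hosts write half of the kernel and the data-axis copies are skipped, so every column lands on disk exactly once.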

Why sharded checkpoints matter

  • I/O parallelism: hosts write their shards simultaneously, so ideal checkpoint time scales as O(P / N_shards) instead of O(P), until the storage backend’s aggregate bandwidth saturates.
  • Memory ceiling: no single host needs to hold the full model.
  • Resilience: if one shard’s file is lost while the job is still running, in principle only that slice has to be rewritten from the live, still-sharded model rather than the whole checkpoint.

Restore semantics

On restore, you provide a target sharding tree so Orbax can redistribute shards onto the (possibly different) destination mesh:

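# abstract_state: nnx.State with the saved tree structure (e.g. nnx.split of an
#   nnx.eval_shape'd model). target_shardings: a matching tree of
#   jax.sharding.Sharding objects (e.g. NamedSharding) on the destination mesh.
restore_args = jax.tree_util.tree_map(
    lambda sharding: ocp.ArrayRestoreArgs(sharding=sharding),
    target_shardings,
)
state_loaded = ckptr.restore(
    "/tmp/dist_ckpt",
    args=ocp.args.PyTreeRestore(item=abstract_state, restore_args=restore_args),
)
model = nnx.merge(graphdef, state_loaded)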

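A common case is saving on 8 hosts and restoring on 16. Orbax re-shards on the fly: each leaf is stored as a global array with its global shape, so every destination device reads exactly the index range its target sharding assigns to it. When the new sharding cuts the same axis twice as finely, that amounts to splitting each saved shard in two, but the old and new shardings do not have to line up at all.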
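A sketch of why the old and new shardings need not line up: treat a saved leaf as one global 1-D axis cut into equal saved shards, and let each destination shard read whatever index range its new sharding assigns, overlapping one or more saved shards. The `read_plan` helper is hypothetical and only illustrates the idea; Orbax’s real read path goes through TensorStore chunks.

```python
def even_ranges(length: int, parts: int) -> list[tuple[int, int]]:
    """Split [0, length) into `parts` contiguous ranges (assumes length % parts == 0)."""
    step = length // parts
    return [(i * step, (i + 1) * step) for i in range(parts)]


def read_plan(length: int, saved_parts: int, target_parts: int):
    """For each target shard, list the (saved_shard, start, stop) pieces it reads."""
    saved = even_ranges(length, saved_parts)
    plan = []
    for t_start, t_stop in even_ranges(length, target_parts):
        pieces = []
        for s_idx, (s_start, s_stop) in enumerate(saved):
            lo, hi = max(t_start, s_start), min(t_stop, s_stop)
            if lo < hi:  # the target range overlaps this saved shard
                pieces.append((s_idx, lo, hi))
        plan.append(pieces)
    return plan


# 8 -> 16: every saved shard is split in two.
print(read_plan(4096, 8, 16)[:2])
# 3 -> 4: boundaries don't line up; some targets stitch two saved shards.
print(read_plan(12, 3, 4))
```

In the 8 → 16 case each destination shard reads half of one saved shard; in the 3 → 4 case the two middle destinations assemble their range from two saved shards each.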

Why we don’t actually call Orbax

Distributed Orbax requires:

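  • A real multi-host JAX setup, with every process calling jax.distributed.initialize.
  • A shared filesystem or object store that all hosts can write to concurrently.
  • Coordinated processes: every host must call save, and the checkpoint is finalized only after all of them finish writing.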

None of these exist in the sandbox. Instead, we compute the checkpoint accounting: total parameter count plus the per-shard size, which is what an engineer planning a sharded save needs to know up front (file count, peak per-host I/O, recovery costs).

Memory math for nnx.Linear

For an nnx.Linear(in_dim, out_dim):

  • kernel.size = in_dim * out_dim
  • bias.size = out_dim
  • total = (in_dim + 1) * out_dim

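Per shard: total / num_shards. We use float division so the result is the exact average per shard, even when the split is uneven. (This is an idealized figure: in a real job each leaf is sharded along its own axes, small leaves such as bias are often replicated instead of split, and an uneven split yields shards of whole elements with unequal or padded sizes, never fractional ones.)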
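The accounting above fits in a few lines of plain Python. This framework-free sketch does not build an nnx.Linear; it only applies the kernel and bias size formulas. For contrast, `largest_shard` shows the biggest shard under an assumed layout where kernel and bias are both split along out_dim into whole columns; all function names are hypothetical.

```python
import math


def linear_param_count(in_dim: int, out_dim: int) -> int:
    """Parameters in a dense layer: kernel (in_dim x out_dim) plus bias (out_dim)."""
    kernel = in_dim * out_dim
    bias = out_dim
    return kernel + bias  # == (in_dim + 1) * out_dim


def accounting(in_dim: int, out_dim: int, num_shards: int) -> tuple[float, float]:
    """(total_params, average params per shard), using float division."""
    total = linear_param_count(in_dim, out_dim)
    return float(total), total / num_shards


def largest_shard(in_dim: int, out_dim: int, num_shards: int) -> int:
    """Params on the biggest shard if kernel and bias are split along out_dim
    into whole columns (an assumed layout, for illustration only)."""
    cols = math.ceil(out_dim / num_shards)
    return (in_dim + 1) * cols


print(accounting(3, 4, 3))     # total 16, average 16 / 3
print(largest_shard(3, 4, 3))  # 2 columns x 4 params each
```

For in_dim=3, out_dim=4 and three shards the average is about 5.33 parameters, but with whole-column splits at least one shard must hold 2 of the 4 columns, i.e. 8 parameters.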

Common pitfalls

  • Calling ckptr.save before nnx.split: Orbax operates on pytrees of arrays, not Modules. Always split first.
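  • Ignoring save_args: the defaults are usually fine, but to change how a leaf is stored (for example, casting float32 weights to bf16 on save with ocp.SaveArgs(dtype=jnp.bfloat16), or setting other per-leaf storage options) you need explicit per-leaf args.
  • Restoring with the wrong sharding template: Orbax can re-shard, but only onto shardings that are valid for the destination mesh. A sharding built for a different Mesh shape, or one naming mesh axes that do not exist, fails at runtime.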

Problem

Implement orbax_save_count(seed, x, features, num_shards):

  1. Cast seed, features, num_shards to ints. Use x.shape[-1] as the input dim.
  2. Build model = nnx.Linear(in_dim, features, rngs=nnx.Rngs(seed)).
  3. total = kernel.size + bias.size (in NNX, model.kernel.value.size + model.bias.value.size). For a dense layer like nnx.Linear, that’s (in_dim + 1) * features.
  4. per_shard = total / num_shards (float division).
  5. Return jnp.array([float(total), per_shard]), a 1-D array of shape (2,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • features: float (cast to int).
  • num_shards: float (cast to int).

Output: 1-D array of shape (2,): [total_params, per_shard_size].

Hints

flax nnx orbax checkpoint sharding
