
Distributed Checkpoint (Sharding Math)

Why this matters

Models that don't fit on one device, such as LLaMA-405B, GPT-3, or Switch Transformers, are necessarily sharded: each device holds a fraction of each parameter (or whole, disjoint parameters). When you save such a model, you can't simply gather() everything onto one host, because the full model may not fit in host memory either. You need a distributed checkpoint: every process writes only the shards it holds.

Orbax’s distributed model

import jax
import orbax.checkpoint as ocp
from jax.sharding import PartitionSpec as P

# `mesh` and `params` are assumed to be defined earlier in the program.
# In a multi-host JAX program, the params are sharded via PartitionSpec.
# P("model", None) is a rank-2 spec, so every leaf must be at least 2-D.
sharding = jax.sharding.NamedSharding(mesh, P("model", None))
sharded_params = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, sharding),
    params,
)

handler = ocp.PyTreeCheckpointHandler()
ckptr = ocp.Checkpointer(handler)
ckptr.save("/tmp/dist_ckpt", sharded_params)
# Each process writes its slice of each tensor to a per-shard file.
# Orbax coordinates so the global metadata is consistent.

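The on-disk layout looks roughly like the following (illustrative only; actual file names depend on the Orbax version and storage format):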

/tmp/dist_ckpt/
  _METADATA              <- tree structure + dtype info
  Dense_0.kernel/
    shard_0.array        <- process 0's slice
    shard_1.array        <- process 1's slice
    ...
    shard_{N-1}.array
  Dense_0.bias/
    shard_0.array
    ...

Each shard is roughly total_param_bytes / num_shards bytes, assuming even sharding. Uneven shards occur when the sharded dimension is not divisible by num_shards, but Orbax handles that with padding.
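As a rough illustration of the per-shard size, the sketch below splits one axis of a tensor into num_shards contiguous chunks and reports each chunk's size in bytes. The helper name per_shard_bytes and the ceiling-based split are assumptions made for this example; they are not Orbax's actual on-disk chunking.

```python
import math

def per_shard_bytes(shape, itemsize, num_shards, axis=0):
    """Bytes held by each shard when `axis` is split into contiguous chunks.

    Hypothetical helper: uses a ceiling-based split, so trailing shards may
    be smaller (or empty) when shape[axis] is not divisible by num_shards.
    """
    rows = shape[axis]
    chunk = -(-rows // num_shards)  # ceiling division
    # Number of elements covered by a single index along `axis`.
    other = math.prod(shape) // rows if rows else 0
    sizes = []
    for i in range(num_shards):
        start = min(i * chunk, rows)
        stop = min(start + chunk, rows)
        sizes.append((stop - start) * other * itemsize)
    return sizes

# float32 (4 bytes per element), sharded 4 ways along axis 0.
even = per_shard_bytes((8, 4), itemsize=4, num_shards=4)
uneven = per_shard_bytes((10, 4), itemsize=4, num_shards=4)
```

In the even case every shard holds exactly total bytes / num_shards. In the uneven case the shards still sum to the tensor's full size, and padding each shard up to the chunk size would cost the difference.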

Why even sharding matters

  • I/O parallelism: each shard is written to a separate file simultaneously, so checkpoint time scales as O(P / N_shards) instead of O(P), provided storage bandwidth scales with the number of writers.
  • Memory ceiling: no single host needs to hold the full model.
  • Resilience: if you lose one shard's disk, you can re-stream just that shard from the original distributed model.
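The O(P / N_shards) scaling can be sketched with a back-of-the-envelope estimate. The function name and the bandwidth figure are hypothetical, and the model assumes perfectly parallel writers with no shared bottleneck, which real storage rarely delivers.

```python
def estimated_write_seconds(total_bytes, num_shards, bytes_per_second_per_writer):
    """Idealized wall-clock time when all shards are written in parallel."""
    per_shard = total_bytes / num_shards
    return per_shard / bytes_per_second_per_writer

# 8 GB checkpoint, 1 GB/s per writer (illustrative numbers only).
single = estimated_write_seconds(8_000_000_000, 1, 1_000_000_000)
sharded = estimated_write_seconds(8_000_000_000, 8, 1_000_000_000)
```

Under these assumptions, writing 8 shards in parallel takes one eighth of the single-writer time.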

Restore semantics: ArrayRestoreArgs

On restore, you also need to provide a sharding template:

restore_args = jax.tree_util.tree_map(
    lambda spec: ocp.ArrayRestoreArgs(sharding=spec),
    target_sharding_tree,
)
loaded = ckptr.restore("/tmp/dist_ckpt", restore_args=restore_args)

The restore args tell Orbax how to redistribute the shards if the new mesh differs from the save-time mesh. Common case: save on 8 hosts, restore on 16 — Orbax re-shards on the fly.
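To make the re-sharding idea concrete, here is a minimal sketch of the index arithmetic involved: given a new shard, which saved shards cover its range along the sharded axis? This is not Orbax code. The helper names are hypothetical, and the sketch assumes one sharded axis that divides evenly by both shard counts.

```python
def shard_range(index, num_shards, dim):
    """Half-open [start, stop) range owned by shard `index` (even split)."""
    size = dim // num_shards
    return index * size, (index + 1) * size

def saved_shards_needed(new_index, saved_shards, new_shards, dim):
    """Indices of saved shards that overlap the range of one new shard."""
    lo, hi = shard_range(new_index, new_shards, dim)
    needed = []
    for i in range(saved_shards):
        start, stop = shard_range(i, saved_shards, dim)
        if start < hi and stop > lo:  # the two ranges intersect
            needed.append(i)
    return needed

# Saved with 8 shards over a dimension of 64.
finer = saved_shards_needed(5, saved_shards=8, new_shards=16, dim=64)
coarser = saved_shards_needed(1, saved_shards=8, new_shards=4, dim=64)
```

Restoring onto more shards means each new shard reads part of one saved shard; restoring onto fewer means each new shard stitches together several saved shards.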

Why we don’t actually call Orbax

Distributed Orbax requires:

  • A real multi-host JAX setup (initialized with jax.distributed.initialize).
  • A real filesystem with concurrent-write semantics.
  • Coordinated processes.

None of these exist in the sandbox. Instead, we compute the checkpoint accounting: total param count, plus the per-shard size, which is what an engineer planning a sharded save needs to know up front.

Problem

Init nn.Dense(features) with PRNGKey(seed) and x. Compute:

  1. total_params = kernel.size + bias.size.
  2. per_shard = total_params / num_shards (float division).
  3. Return [total_params, per_shard] as a 1-D (2,) array.

For nn.Dense(features) applied to (N, in_dim) input:

  • kernel has shape (in_dim, features), so kernel.size == in_dim * features.
  • bias has shape (features,), so bias.size == features.
  • total = in_dim * features + features = (in_dim + 1) * features.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, in_dim).
  • num_shards: float (cast to int).
  • features: float (cast to int).

Output: 1-D (2,) array [total_params, per_shard].
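The arithmetic can be checked with a small pure-Python sketch. It skips the actual nn.Dense initialization and takes in_dim directly (in the problem it would come from the last dimension of x), and it returns a plain list rather than a (2,) array. The function name is hypothetical.

```python
def dense_checkpoint_accounting(in_dim, features, num_shards):
    """Return [total_params, per_shard] for a Dense layer, from shapes alone."""
    # The problem passes these as floats, so cast to int first.
    features = int(features)
    num_shards = int(num_shards)
    kernel_size = in_dim * features   # kernel shape: (in_dim, features)
    bias_size = features              # bias shape: (features,)
    total_params = kernel_size + bias_size
    per_shard = total_params / num_shards  # float division
    return [float(total_params), per_shard]

# x of shape (N, 8), features = 16, num_shards = 4.
result = dense_checkpoint_accounting(in_dim=8, features=16.0, num_shards=4.0)
```

Here total_params is (8 + 1) * 16 = 144, split across 4 shards.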
