Distributed Checkpoint (Sharding Math)
Why this matters
Models that don’t fit on one device — LLaMA-405B, GPT-3, Switch
Transformers — are necessarily sharded: each device holds a
fraction of each parameter (or whole disjoint params). When you
save such a model, you can’t simply gather() everything to one
host (it doesn’t fit in host memory either!). You need a
distributed checkpoint: every replica writes its own shard.
Orbax’s distributed model
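import jax
from jax.sharding import PartitionSpec as P
import orbax.checkpoint as ocp
# `mesh` (a jax.sharding.Mesh with a "model" axis) and `params` (a pytree of
# arrays) are assumed to be defined earlier in the program.
# In a multi-host JAX program, the params are sharded via PartitionSpec.
# Note: a two-entry spec only fits leaves of rank >= 2; a 1-D leaf such as a
# bias needs a shorter spec, e.g. P("model") or P().
sharding = jax.sharding.NamedSharding(mesh, P("model", None))
sharded_params = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, sharding),
    params,
)
handler = ocp.PyTreeCheckpointHandler()
ckptr = ocp.Checkpointer(handler)
ckptr.save("/tmp/dist_ckpt", sharded_params)
# Each replica writes its slice of each tensor to a per-shard file.
# Orbax coordinates so the global metadata is consistent.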
The on-disk layout looks roughly like:
/tmp/dist_ckpt/
_METADATA <- tree structure + dtype info
Dense_0.kernel/
shard_0.array <- replica 0's slice
shard_1.array <- replica 1's slice
...
shard_{N-1}.array
Dense_0.bias/
shard_0.array
...
Each shard is roughly total_param_size / num_shards bytes (assuming
even sharding — uneven shards happen when the param dim doesn’t divide
by num_shards, but Orbax handles that with padding).
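To make the size estimate concrete, the sketch below computes the per-shard row count and byte size for a single 2-D parameter split along its first axis. It assumes that an uneven split is padded so every shard holds ceil(rows / num_shards) rows, which is one simple reading of the padding mentioned above and not a description of Orbax's actual on-disk format. The helper name shard_bytes and the float32 default are illustrative choices.

```python
import math

def shard_bytes(rows, cols, num_shards, itemsize=4):
    """Return (rows_per_shard, bytes_per_shard, padded_rows) for a row-sharded matrix.

    Assumption: every shard is padded to the same row count,
    ceil(rows / num_shards). itemsize=4 corresponds to float32.
    """
    rows_per_shard = math.ceil(rows / num_shards)
    padded_rows = rows_per_shard * num_shards - rows
    return rows_per_shard, rows_per_shard * cols * itemsize, padded_rows

# Even case: 1024 rows over 8 shards, no padding needed.
print(shard_bytes(1024, 512, 8))   # (128, 262144, 0)
# Uneven case: 10 rows over 4 shards, 2 padding rows in total.
print(shard_bytes(10, 512, 4))     # (3, 6144, 2)
```

In the even case the per-shard size is exactly total size divided by num_shards; in the uneven case the shards together are slightly larger than the original tensor.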
Why even sharding matters
- I/O parallelism: each shard writes to a separate file simultaneously — checkpoint time scales as O(P / N_shards) instead of O(P).
- Memory ceiling: no single host needs to hold the full model.
- Resilience: lose one shard’s disk, you can re-stream just that shard from the original distributed model.
Restore semantics: ArrayRestoreArgs
On restore, you also need to provide a sharding template:
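# target_sharding_tree: a pytree of jax.sharding.Sharding objects, one per
# parameter, with the same structure as the saved tree (assumed to be built
# from the restore-time mesh).
restore_args = jax.tree_util.tree_map(
    lambda spec: ocp.ArrayRestoreArgs(sharding=spec),
    target_sharding_tree,
)
loaded = ckptr.restore("/tmp/dist_ckpt", restore_args=restore_args)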
The restore args tell Orbax how to redistribute the shards if the new mesh differs from the save-time mesh. Common case: save on 8 hosts, restore on 16 — Orbax re-shards on the fly.
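The index arithmetic behind such a re-shard can be sketched without Orbax. The example below only illustrates which save-time slices overlap a given restore-time slice for an evenly divisible axis; it is not how Orbax implements restoration, and the helper names shard_ranges and source_shards are hypothetical.

```python
def shard_ranges(dim, num_shards):
    """Half-open [start, stop) index ranges for an even split of `dim`."""
    assert dim % num_shards == 0, "sketch covers the evenly divisible case only"
    size = dim // num_shards
    return [(i * size, (i + 1) * size) for i in range(num_shards)]

def source_shards(new_range, old_ranges):
    """Indices of the save-time shards that overlap one restore-time shard."""
    start, stop = new_range
    return [i for i, (s, e) in enumerate(old_ranges) if s < stop and start < e]

old = shard_ranges(4096, 8)    # layout at save time
new = shard_ranges(4096, 16)   # layout at restore time

# Restore-time shard 5 covers rows 1280..1535, all inside save-time shard 2.
print(new[5], source_shards(new[5], old))   # (1280, 1536) [2]
# Going the other way (16 -> 8), one new shard is assembled from two old ones.
print(source_shards(old[2], new))           # [4, 5]
```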
Why we don’t actually call Orbax
Distributed Orbax requires:
- A real multi-host JAX setup (or jax.distributed.initialize).
- A real filesystem with concurrent-write semantics.
- Coordinated processes.
None of these exist in the sandbox. Instead, we compute the checkpoint accounting: total param count, plus the per-shard size, which is what an engineer planning a sharded save needs to know up front.
Problem
Init nn.Dense(features) with PRNGKey(seed) and x. Compute:
- total_params = kernel.size + bias.size.
- per_shard = total_params / num_shards (float division).
- Return [total_params, per_shard] as a 1-D (2,) array.
For nn.Dense(features) applied to (N, in_dim) input:
- kernel has shape (in_dim, features) ⇒ kernel.size == in_dim * features.
- bias has shape (features,) ⇒ bias.size == features.
- total = in_dim * features + features = (in_dim + 1) * features.
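As a quick check of the formula, here is the same arithmetic for a small, made-up layer with in_dim = 3 and features = 5, split over 4 shards. The function name dense_param_count is illustrative only.

```python
def dense_param_count(in_dim, features):
    """Parameter count of a dense layer with bias, computed from shapes alone."""
    kernel_size = in_dim * features   # kernel has shape (in_dim, features)
    bias_size = features              # bias has shape (features,)
    return kernel_size + bias_size

total = dense_param_count(in_dim=3, features=5)
print(total)       # 20, i.e. (3 + 1) * 5
print(total / 4)   # 5.0 parameters per shard with 4 shards
```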
Inputs:
- seed: float (cast to int).
- x: 2-D (N, in_dim).
- num_shards: float (cast to int).
- features: float (cast to int).
Output: 1-D (2,) — [total_params, per_shard].