Orbax Save (Tree-Leaf Count)
Why this matters
Saving a model in Flax is not a one-liner like model.save("path.pt"): Flax Linen modules hold no state. The state lives in the params pytree that you thread through your training loop, so to serialize a Flax model, you serialize that pytree.
The standard tool is Orbax (orbax-checkpoint), the official
JAX-ecosystem checkpointing library. It handles:
- Async writes (don’t block training).
- Multi-host coordination (each host writes the shards of each array that it owns).
- Atomic commits (a partially-written checkpoint is invalid).
- Versioning and metadata.
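A simple way to picture atomic commits is the generic "write to a temporary location, then rename" pattern. The sketch below applies it to a single JSON file using only the standard library. The helper name atomic_write_json is hypothetical. This illustrates the idea only; it is not Orbax's implementation, which commits whole checkpoint directories.

```python
import json
import os
import tempfile


def atomic_write_json(final_path: str, payload: dict) -> None:
    """Readers see either no file / the old file, or the complete new one."""
    directory = os.path.dirname(os.path.abspath(final_path))
    # Temp file in the same directory so the rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
            f.flush()
            os.fsync(f.fileno())  # bytes reach disk before the commit step
        # The "commit": os.replace is atomic when both paths share a filesystem.
        os.replace(tmp_path, final_path)
    except BaseException:
        # A failure before the rename leaves no half-written final file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Usage: commit a tiny metadata file into a throwaway directory.
with tempfile.TemporaryDirectory() as workdir:
    meta_path = os.path.join(workdir, "metadata.json")
    atomic_write_json(meta_path, {"num_leaves": 4})
    with open(meta_path) as f:
        restored = json.load(f)
    leftover = sorted(os.listdir(workdir))
```

After the commit, only the final file remains. No temporary file is left behind.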
The save API
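```python
import orbax.checkpoint as ocp

# params: the pytree from model.init(...)["params"]
ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/my_ckpt", params)  # synchronous: returns once the write is done

# or, as an alternative to the synchronous save:
async_ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
async_ckptr.save("/tmp/my_ckpt", params)  # returns quickly; the write continues in the background
async_ckptr.wait_until_finished()  # block when you need the bytes on disk
```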
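save walks the pytree and writes each leaf's array data. Historically each leaf got its own file or directory; newer Orbax versions can pack many leaves into shared files using the OCDBT format. It also emits metadata that records the tree structure. That structure is as much a part of the checkpoint as the array bytes: without it, restore could not put each array back in its slot.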
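You can see the split between structure and leaves directly with jax.tree_util.tree_flatten. It breaks a pytree into a flat list of leaves plus a treedef describing the nesting. This only illustrates why the structure must be recorded. Orbax writes its own metadata format, not the treedef object. The two-layer params dict is made up for the example.

```python
import jax
import jax.numpy as jnp

# Made-up two-layer params tree, shaped like Flax Dense params.
params = {
    "Dense_0": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "Dense_1": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
}

# Split the tree into its two halves: the flat array list and the structure.
leaves, treedef = jax.tree_util.tree_flatten(params)

# The leaves alone cannot be put back in the right slots; the treedef
# (or an equivalent metadata record) is what makes reconstruction possible.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

num_leaves = len(leaves)
same_structure = jax.tree_util.tree_structure(rebuilt) == treedef
```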
What’s a leaf, exactly?
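Flax’s params (the dict under the "params" key returned by init) for a model that stacks N Dense layers, for example a Linen module whose @nn.compact __call__ creates nn.Dense(D) in a loop, form a nested dict: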
```text
{
  "Dense_0": {"kernel": <array>, "bias": <array>},
  "Dense_1": {"kernel": <array>, "bias": <array>},
  ...
  "Dense_{N-1}": {"kernel": <array>, "bias": <array>},
}
```
Each <array> is a leaf. With N Dense layers, you have 2N leaves
(kernel + bias each). jax.tree_util.tree_leaves(params) returns the
flat list — its length is the leaf count.
Why do you care? Because:
- Orbax tracks each leaf as its own array in the checkpoint. In a per-leaf layout, N = 100 Dense layers (2N = 200 leaves) means 200 separate array entries in your checkpoint directory.
- Sharded saves split each leaf across devices according to that array's own sharding; they do not split the model as a whole.
- tree_leaves is the introspection step you run BEFORE saving to sanity-check what is being serialized. This matters most for nested models, where it is easy to lose track of the parameter count.
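The pre-save sanity check can go a little further than the bare leaf count. The helper below reports the number of leaves, total elements, and total bytes for a params pytree. summarize_params is a hypothetical name, not an Orbax or Flax API. The two-layer float32 tree with D = 16 is an assumed example.

```python
import jax
import jax.numpy as jnp


def summarize_params(params) -> dict:
    """Report what a save would serialize (hypothetical helper)."""
    leaves = jax.tree_util.tree_leaves(params)
    return {
        "num_leaves": len(leaves),
        "num_elements": sum(int(leaf.size) for leaf in leaves),
        "num_bytes": sum(int(leaf.nbytes) for leaf in leaves),
    }


# Assumed example: two Dense-shaped layers with D = 16, float32.
D = 16
params = {
    f"Dense_{i}": {
        "kernel": jnp.zeros((D, D), dtype=jnp.float32),
        "bias": jnp.zeros((D,), dtype=jnp.float32),
    }
    for i in range(2)
}
summary = summarize_params(params)
```

For this tree the summary is 4 leaves and 544 elements (2 × (16·16 + 16)). At 4 bytes per float32 element, that is 2176 bytes.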
Why not run the actual save?
Orbax needs a writable filesystem and coordinates its writes asynchronously, so it won’t run reliably in this sandbox. The introspection step (tree_leaves plus a count) has neither requirement. Orbax flattens the tree internally on every save anyway, and doing it yourself first is how you verify that your params look right.
Problem
Build a tiny Flax model that’s nn.Dense(D) repeated num_layers
times (where D = x.shape[-1]), init it with jax.random.PRNGKey(seed)
and a sample input x, then count the leaves in the resulting
params pytree:
```python
leaves = jax.tree_util.tree_leaves(params)
return jnp.array([float(len(leaves))])
```
Each Dense layer contributes 2 leaves (kernel + bias), so the
count should be 2 * num_layers.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D) sample input, used only for its shape during init (here N is the batch size, not the number of layers).
- num_layers: float (cast to int).
Output: 1-D array of shape (1,) holding [leaf_count], which always equals 2 * num_layers.