medium primitives

Orbax Save (Tree-Leaf Count)

Why this matters

Saving a model in Flax is not a one-liner like model.save("path.pt"): Flax modules hold no state. The state lives in the params pytree that you thread through your training loop, so serializing a Flax model means serializing that pytree.

The standard tool is Orbax (orbax-checkpoint), the official JAX-ecosystem checkpointing library. It handles:

  • Async writes (don’t block training).
  • Multi-host coordination (every replica writes its shard).
  • Atomic commits (a partially-written checkpoint is invalid).
  • Versioning and metadata.

The save API

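import orbax.checkpoint as ocp

# params is the pytree produced by model.init(...) and updated during training.

# Synchronous save: returns once the checkpoint has been written.
ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/my_ckpt", params)

# Alternative (use instead of the synchronous save, not after it):
async_ckptr = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
async_ckptr.save("/tmp/my_ckpt", params)   # returns before the write completes
async_ckptr.wait_until_finished()           # block when you need the bytes on disk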

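save walks the pytree, writes the array data for each leaf (or for each shard of a leaf), and emits JSON metadata that records the tree structure. The structure is as much a part of the checkpoint as the numbers: the array bytes are only the leaves, and the metadata records where each one belongs in the tree.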
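The split between "leaves" and "tree structure" can be seen without touching Orbax at all. The sketch below uses only jax.tree_util on a hand-built, hypothetical two-layer params dict (the shapes are arbitrary); it illustrates the idea of separating structure from array data and is not a description of Orbax's on-disk format.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer parameter tree; shapes are arbitrary.
params = {
    "Dense_0": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "Dense_1": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
}

# Split the pytree into its array leaves and a description of the nesting.
leaves, treedef = jax.tree_util.tree_flatten(params)

# The structure alone is enough to rebuild the nesting around any
# compatible list of leaves (for example, arrays read back from storage).
restored = jax.tree_util.tree_unflatten(treedef, leaves)

num_leaves = len(leaves)
same_structure = jax.tree_util.tree_structure(restored) == treedef
```

Here treedef plays the role the metadata plays in a checkpoint: without it, the flat list of arrays has no names or nesting.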

What’s a leaf, exactly?

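Flax’s params for a model that stacks N separate nn.Dense(D) layers inside an @nn.compact module (layers are auto-named Dense_0, Dense_1, ...) is a nested dict. Note that the Python expression [Dense(D)] * N would repeat a single module instance N times rather than create N distinct layers.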

{
  "Dense_0": {"kernel": <array>, "bias": <array>},
  "Dense_1": {"kernel": <array>, "bias": <array>},
  ...
  "Dense_{N-1}": {"kernel": <array>, "bias": <array>},
}

Each <array> is a leaf. With N Dense layers, you have 2N leaves (kernel + bias each). jax.tree_util.tree_leaves(params) returns the flat list — its length is the leaf count.
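As a quick check of the 2N rule, the following sketch counts leaves on a hand-built stand-in for the dict above. make_dense_params is a hypothetical helper written for this example, not part of Flax; a real model would get its params from model.init.

```python
import jax
import jax.numpy as jnp

def make_dense_params(num_layers, d):
    """Hand-built stand-in for Flax Dense params (hypothetical helper)."""
    return {
        f"Dense_{i}": {"kernel": jnp.zeros((d, d)), "bias": jnp.zeros((d,))}
        for i in range(num_layers)
    }

params = make_dense_params(num_layers=3, d=8)

# tree_leaves flattens the nested dict into a plain list of arrays.
leaves = jax.tree_util.tree_leaves(params)
leaf_count = len(leaves)  # two leaves per layer: kernel and bias
```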

Why do you care? Because:

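  • Orbax writes array data per leaf. A model with N=100 Dense layers has 2N=200 leaves, which means on the order of 200 per-leaf entries in your checkpoint directory (the exact file count depends on the Orbax version and storage format).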
  • Sharded saves split each leaf across devices, not whole models.
  • tree_leaves is the introspection step you do BEFORE saving to sanity-check what’s being serialized (especially for nested models where it’s easy to lose track of param count).
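A pre-save sanity check along these lines might list every leaf with its path, shape, and dtype, plus a total element count. This is a minimal sketch: summarize is a hypothetical helper, the params dict is hand-built for illustration, and it assumes a JAX version that provides jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr.

```python
import math

import jax
import jax.numpy as jnp

def summarize(params):
    """Return (rows, total_elements); rows are (path, shape, dtype) per leaf."""
    path_leaf_pairs, _ = jax.tree_util.tree_flatten_with_path(params)
    rows = [
        (jax.tree_util.keystr(path), tuple(leaf.shape), str(leaf.dtype))
        for path, leaf in path_leaf_pairs
    ]
    total = sum(math.prod(shape) for _, shape, _ in rows)
    return rows, total

# Hypothetical two-layer parameter tree; shapes are arbitrary.
params = {
    "Dense_0": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
    "Dense_1": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
}

rows, total = summarize(params)
for path, shape, dtype in rows:
    print(path, shape, dtype)
print("leaves:", len(rows), "elements:", total)
```

Printing paths alongside shapes makes it easier to spot a missing or duplicated submodule than a bare leaf count does.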

Why not run the actual save?

Orbax requires a writable filesystem and is async-coordinated; it won’t run reliably in this sandbox. The introspection (tree_leaves + count) is the step you can always run before any Orbax call, and it’s what you’d use to verify that your params look right.

Problem

Build a tiny Flax model that’s nn.Dense(D) repeated num_layers times (where D = x.shape[-1]), init it with jax.random.PRNGKey(seed) and a sample input x, then count the leaves in the resulting params pytree:

leaves = jax.tree_util.tree_leaves(params)
return jnp.array([float(len(leaves))])

Each Dense layer contributes 2 leaves (kernel + bias), so the count should be 2 * num_layers.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D) sample input — used only for shape during init.
  • num_layers: float (cast to int).

Output: 1-D array of shape (1,) holding [leaf_count] (always equals 2 * num_layers).

