
Orbax Load (Restore via Template)

Why this matters

Loading a Flax checkpoint is the inverse of saving — you have bytes on disk and you want a params pytree back. The challenge is that a JAX array is not just bytes: it has a shape, a dtype, and (with multi-host training) a sharding spec. Plain “load these bytes” isn’t enough; you need to know how to interpret them.

Orbax solves this with a template (the “abstract target”, optionally paired with per-leaf “restore args”). You build a pytree with the same structure, shapes, and dtypes as what you saved, but with no real data, and pass it to restore:

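import jax
import orbax.checkpoint as ocp

ckptr = ocp.StandardCheckpointer()

# Build an abstract template: shapes/dtypes only, no real data needed.
# params_skeleton is any pytree with the saved layout (e.g. from an earlier init).
# On multi-host setups, also pass sharding=... to each ShapeDtypeStruct.
template = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
    params_skeleton,
)

# The checkpoint directory must already exist from an earlier save.
loaded_params = ckptr.restore("/tmp/my_ckpt", template)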

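The template tells Orbax what to build for each leaf: its shape, its dtype, and (with multi-host training) its sharding, inside the same container structure you saved. Without one, Orbax can only fall back on the metadata it recorded at save time, so you lose control over dtype and device placement and may get plain nested dicts back instead of your original container types.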

Why we simulate

Orbax’s restore requires a previous successful save on a real filesystem and uses async coordination that doesn’t fit in a sandbox. Instead, we exploit a deterministic property of model.init:

Given the same model definition, the same PRNGKey, and the same input shape and dtype, model.init produces the same params, every time.

So params_v2 = model.init(jax.random.PRNGKey(seed), x)["params"] is byte-for-byte identical to what restoring a checkpoint of params_v1 would give you, provided params_v1 came from the same key and input. For this exercise that makes it a valid stand-in for the result of an Orbax restore (same param tree, same shapes, same dtypes, same values), even though none of the actual I/O is exercised.

What you do once params are loaded

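Same as in training: call model.apply({"params": loaded_params}, x). The whole point of loading is to run inference (or resume training) in a fresh process, and apply works exactly as it did with the params before saving. If the model has other collections, such as batch_stats, restore those too and pass them alongside params.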

Pitfalls in the real Orbax flow

  • Wrong template shape: shape mismatch raises immediately on restore. Helpful — fails loud.
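  • Dtype mismatch silently casts: a saved fp32 checkpoint restored with a bf16 template comes back as bf16 without an error. That can be intentional (e.g. for inference), but it discards precision, so check leaf dtypes after restore if you did not mean to cast.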
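  • Sharded saves require sharded templates: in multi-host land, each template leaf carries a sharding (e.g. a NamedSharding built from a Mesh and a PartitionSpec); restoring a sharded checkpoint on a single host means giving the template the single-host sharding you want instead of relying on the saved layout.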
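  • Forgetting the ["params"] unwrap: model.init(...) returns {"params": ...} (and possibly "batch_stats"). The template must match whichever level you saved; if you saved the inner params tree, build the template from model.init(...)["params"] as well, not from the outer dict.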

Problem

Init nn.Dense(features) with jax.random.PRNGKey(seed) and the given x. Then SIMULATE loading by calling model.init again with the same key: this gives you loaded_params byte-identical to what a real Orbax restore of the first init's params would produce.

Apply the model to x with loaded_params, then return the flattened output as a 1-D array.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D) input.
  • features: float (cast to int) — output dim of the Dense layer.

Output: 1-D (N * features,) — the flattened apply result.

Hints

flax orbax checkpointing
