Orbax Load (Restore via Template)
Why this matters
Loading a Flax checkpoint is the inverse of saving — you have bytes
on disk and you want a params pytree back. The challenge is that
a JAX array is not just bytes: it has a shape, a dtype, and
(with multi-host training) a sharding spec. Plain “load these
bytes” isn’t enough; you need to know how to interpret them.
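To make the “bytes are not enough” point concrete, here is a small NumPy illustration (plain NumPy, not Orbax code). The same 16 bytes decode into different arrays depending on which shape and dtype you assume when reading them back.

```python
import numpy as np

# Four float32 values become 16 raw bytes "on disk".
raw = np.arange(4, dtype=np.float32).tobytes()

# The same bytes, read back under three different assumptions:
as_vector = np.frombuffer(raw, dtype=np.float32)                # shape (4,), the original values
as_matrix = np.frombuffer(raw, dtype=np.float32).reshape(2, 2)  # shape (2, 2), same values, new layout
as_f64 = np.frombuffer(raw, dtype=np.float64)                   # shape (2,), meaningless values
```

Nothing in `raw` says which reading is correct; that information has to come from somewhere else, which is the job of the template described next.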
Orbax solves this with a template (also called the “target” or “abstract target”; per-leaf “restore args” are a related mechanism for overriding each array’s dtype or sharding). You build a placeholder pytree with the same structure, shapes, and dtypes as what you saved, and pass it to restore:
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

ckptr = ocp.PyTreeCheckpointer()
# Build a template with the saved structure, shapes, and dtypes. The zero
# values are placeholders; params_skeleton is any pytree with the saved layout.
template = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, x.dtype),
    params_skeleton,
)
loaded_params = ckptr.restore("/tmp/my_ckpt", item=template)
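The template tells Orbax how to interpret the bytes: which containers to
rebuild and what shape and dtype each leaf should have. Without it, Orbax
falls back to the checkpoint’s own metadata and hands back a raw nested
dict. Keys such as Dense_0/kernel and Dense_3/bias survive, but custom
containers (e.g. a TrainState dataclass) come back as plain dicts, and you
lose per-leaf control over dtype and sharding.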
Why we simulate
Orbax’s restore requires a previous successful save on a real
filesystem and uses async coordination that doesn’t fit in a sandbox.
Instead, we exploit a deterministic property of model.init:
Given the same PRNGKey and the same input shape,
model.init produces the same params, every time.
So params_v2 = model.init(jax.random.PRNGKey(seed), x)["params"]
is byte-for-byte identical to a freshly loaded checkpoint of
params_v1, as long as params_v1 was originally produced from the same key. This is
a valid stand-in for a real Orbax restore: same param
tree, same shapes, same dtypes, same values.
What you do once params are loaded
Same as in training: call model.apply({"params": loaded_params}, x).
The whole point of loading is to run inference (or resume training)
in a fresh process. Apply works exactly as it did with the params
before saving.
Pitfalls in the real Orbax flow
- Wrong template shape: a shape mismatch raises immediately on restore. Helpful: it fails loud.
- Dtype mismatch silently casts: a saved fp32 ckpt restored with a bf16 template will be cast. That is usually fine, but check for precision loss if you didn’t expect it.
- Sharded saves require sharded templates: in multi-host land, the template carries PartitionSpecs; a single-host restore of a sharded ckpt needs explicit consolidation.
- Forgetting the ["params"] unwrap: model.init(...) returns {"params": ...} (and possibly "batch_stats"); the saved tree should be the inner params, not the outer dict.
Problem
Init nn.Dense(features) with jax.random.PRNGKey(seed) and the
given x. Then SIMULATE loading by re-calling model.init with
the same key — this gives you loaded_params byte-identical to
what a real Orbax restore would produce.
Apply the model to x with loaded_params, then return the
flattened output as a 1-D array.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D) input.
- features: float (cast to int), the output dim of the Dense layer.
Output: 1-D array of shape (N * features,), the flattened apply result.