NNX Graphdef vs State Debug
Why this matters
nnx.split(model) returns two things:
- graphdef: STATIC structural metadata. Class names, attribute names, the wiring between submodules, the dtype/shape of every parameter. Hashable. Cached by jit so the same architecture compiles once.
- state: DYNAMIC value pytree. Just the JAX arrays. The thing optimizers update, checkpoints save, and jax.jit transforms.
Knowing which is which is the unlock for everything advanced in nnx: jit caching, distributed sharding, mixed-precision casts, surgery, checkpointing. You manipulate the state with normal JAX tools; you preserve the graphdef as opaque structure.
This problem is a hands-on inspection: split, count leaves, sum
them. The numbers themselves don’t matter as much as the workflow:
when something looks wrong with your model, nnx.split is the
first move to inspect it.
Worked example
model = TwoLayer(in_features=3, hidden=4, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
# state is a pytree (specifically: an `nnx.State` mapping).
# Flatten with jax.tree_util.tree_leaves:
leaves = jax.tree_util.tree_leaves(state)
print(len(leaves)) # 4 — l1.kernel, l1.bias, l2.kernel, l2.bias
for leaf in leaves:
    print(leaf.shape, leaf.dtype)
Each leaf is a JAX array (the param values), unwrapped from
its nnx.Param wrapper. jax.tree_util.tree_leaves flattens
out the nested-mapping structure.
For a TwoLayer with l1 and l2 (both nnx.Linear), the
leaves are:
- l1.kernel.value: (in_features, hidden)
- l1.bias.value: (hidden,)
- l2.kernel.value: (hidden, hidden)
- l2.bias.value: (hidden,)
Four leaves total. The order depends on nnx's
deterministic attribute traversal, but the count is robust.
Why count leaves?
A common debugging move: “I added a parameter and it isn’t
showing up in the optimizer.” Split the model. Count the leaves.
Compare to what you expected. If the count is wrong, the param
isn’t registered (often: a plain Python list instead of nnx.List,
or a raw jnp.array instead of nnx.Param).
Sum every leaf? More of an integrity check — useful for verifying
that two (graphdef, state) pairs from different sources have
the same parameters (e.g., before/after a checkpoint round-trip).
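The integrity check can be sketched with plain nested dicts standing in for two state pytrees. The names before, after, and leaf_sum are hypothetical, and the NumPy conversion only simulates a checkpoint round-trip; a real comparison would use the state returned by nnx.split.

```python
import jax
import jax.numpy as jnp
import numpy as np


def leaf_sum(tree):
    """Sum every element of every leaf and return a Python float."""
    leaves = jax.tree_util.tree_leaves(tree)
    return float(sum(jnp.sum(leaf) for leaf in leaves))


# Stand-in for a state pytree (illustrative values, not real initialisations).
before = {
    "l1": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
    "l2": {"kernel": jnp.full((4, 4), 0.5), "bias": jnp.zeros((4,))},
}

# Simulated round-trip: convert each leaf to NumPy and back to JAX.
after = jax.tree_util.tree_map(lambda a: jnp.asarray(np.asarray(a)), before)

same_count = len(jax.tree_util.tree_leaves(before)) == len(
    jax.tree_util.tree_leaves(after)
)
same_sum = leaf_sum(before) == leaf_sum(after)

# Stricter check: compare the trees leaf by leaf.
same_values = all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(
        jax.tree_util.tree_leaves(before), jax.tree_util.tree_leaves(after)
    )
)
```

Equal sums are a necessary condition, not a sufficient one: two different parameter sets can share a sum, so the leaf-by-leaf comparison is the safer test when it is affordable.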
API recap
graphdef, state = nnx.split(model)
leaves = jax.tree_util.tree_leaves(state)
num_leaves = len(leaves)
total = sum(jnp.sum(leaf) for leaf in leaves)
# or equivalently (requires `import operator`; jax.tree_map has been removed from recent JAX releases):
# jax.tree_util.tree_reduce(operator.add, jax.tree_util.tree_map(jnp.sum, state))
tree_leaves flattens; len(...) counts; the generator
expression sums each leaf and Python adds them. The result is a
JAX scalar; wrap it in float(...) to get a Python float.
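To see that the generator form and the tree_reduce form agree, here is a small sketch on a toy pytree. The params dict is a made-up stand-in for a state, chosen so the total is easy to verify by hand.

```python
import operator

import jax
import jax.numpy as jnp

# Toy pytree: leaf sums are 0 + 1 + 2 = 3 and 1 + 1 + 1 + 1 = 4.
params = {"a": jnp.arange(3.0), "b": {"c": jnp.ones((2, 2))}}

# Form 1: flatten, then let Python's sum add the per-leaf sums.
leaves = jax.tree_util.tree_leaves(params)
total_loop = sum(jnp.sum(leaf) for leaf in leaves)

# Form 2: map jnp.sum over the tree, then reduce the tree with addition.
per_leaf = jax.tree_util.tree_map(jnp.sum, params)
total_reduce = jax.tree_util.tree_reduce(operator.add, per_leaf)

print(float(total_loop), float(total_reduce))
```

Both totals are JAX scalars until float(...) converts them.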
Compare with the original model
forward_output_sum = float(model(x).sum())
graphdef, state = nnx.split(model)
# ... inspection ...
Splitting AFTER calling the model is fine — model is unchanged
by __call__ (in this problem; if there were nnx.Variable
counters mutating, this wouldn’t be true). The forward output
sum and the parameter sum are independent observables.
Common pitfalls
- Treating state as a flat dict. It’s a nested pytree mirroring the module structure. Use tree_leaves to flatten.
- Forgetting .value to access the array. When you DO walk state by hand (e.g., state['l1']['kernel']), you get an nnx.Param wrapper; .value gets the array. tree_leaves already unwraps for you.
- Counting state.keys() as leaves. That counts top-level submodule names (2 for a TwoLayer), not the actual parameter arrays (4).
- Forgetting that biases are leaves. A Linear has TWO leaves (kernel + bias). A Linear with use_bias=False has ONE.
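The difference between counting keys and counting leaves is easy to see on a plain nested dict shaped like the state of a TwoLayer. The dict below is an illustrative stand-in, not a real nnx.State, and the values are arbitrary.

```python
import jax
import jax.numpy as jnp

# Same nesting as the state of a TwoLayer: submodule name, then parameter name.
state_like = {
    "l1": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
    "l2": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
}

# Counting top-level keys only sees the two submodules.
num_top_level = len(state_like.keys())

# Counting leaves sees every parameter array.
num_leaves = len(jax.tree_util.tree_leaves(state_like))

# A layer without a bias contributes a single leaf.
no_bias = {"kernel": jnp.ones((3, 4))}
num_no_bias_leaves = len(jax.tree_util.tree_leaves(no_bias))

print(num_top_level, num_leaves, num_no_bias_leaves)
```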
Problem
Write graphdef_vs_state(seed, x, features, hidden):
- Define a TwoLayer(nnx.Module) with l1 = nnx.Linear(in, hidden) and l2 = nnx.Linear(hidden, hidden). Forward: l2(l1(x)).
- Build the model. Compute forward_output_sum = float(model(x).sum()).
- Split: graphdef, state = nnx.split(model). Flatten with jax.tree_util.tree_leaves(state).
- num_leaves = len(leaves). total_param_size = float(sum(jnp.sum(leaf) for leaf in leaves)).
- Return jnp.array([float(num_leaves), total_param_size, forward_output_sum]).
For this architecture, num_leaves is always 4: l1.kernel,
l1.bias, l2.kernel, l2.bias.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float; unused here, kept for ABI consistency).
- hidden: int (passed as float).
Output: length-3 array [num_leaves, total_param_size, forward_output_sum].