medium primitives

NNX Graphdef vs State Debug

Why this matters

nnx.split(model) returns two things:

  • graphdef: STATIC structural metadata. Class names, attribute names, the wiring between submodules, and the type of each variable. Hashable, so jit can use it as a cache key and compile the same architecture once.
  • state: DYNAMIC value pytree. Just the JAX arrays, which carry their own shapes and dtypes. This is what optimizers update, checkpoints save, and jax.jit traces.

Knowing which is which is the unlock for everything advanced in nnx: jit caching, distributed sharding, mixed-precision casts, surgery, checkpointing. You manipulate the state with normal JAX tools; you preserve the graphdef as opaque structure.

This problem is a hands-on inspection: split, count leaves, sum them. The numbers themselves don’t matter as much as the workflow: when something looks wrong with your model, nnx.split is the first move to inspect it.

Worked example

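import jax
from flax import nnx

# TwoLayer is the two-Linear module specified in the Problem section.
model = TwoLayer(in_features=3, hidden=4, rngs=nnx.Rngs(0))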

graphdef, state = nnx.split(model)

# state is a pytree (specifically: an `nnx.State` mapping).
# Flatten with jax.tree_util.tree_leaves:
leaves = jax.tree_util.tree_leaves(state)
print(len(leaves))     # 4 — l1.kernel, l1.bias, l2.kernel, l2.bias
for leaf in leaves:
    print(leaf.shape, leaf.dtype)

Each leaf is a JAX array (the param values), unwrapped from its nnx.Param wrapper. jax.tree_util.tree_leaves flattens out the nested-mapping structure.

For a TwoLayer with l1 and l2 (both nnx.Linear), the leaves are:

  1. l1.kernel.value: shape (in_features, hidden)
  2. l1.bias.value: shape (hidden,)
  3. l2.kernel.value: shape (hidden, hidden)
  4. l2.bias.value: shape (hidden,)

Four leaves total. The order follows nnx’s deterministic attribute traversal, so avoid hard-coding it; the count is robust.

Why count leaves?

A common debugging move: “I added a parameter and it isn’t showing up in the optimizer.” Split the model. Count the leaves. Compare to what you expected. If the count is wrong, the param isn’t registered (often: a plain Python list instead of nnx.List, or a raw jnp.array instead of nnx.Param).

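Why sum every leaf? It is more of an integrity check: a quick way to verify that two (graphdef, state) pairs from different sources hold the same parameters (e.g., before and after a checkpoint round-trip). Matching sums are necessary but not sufficient, so compare leaf by leaf when it matters.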

API recap

graphdef, state = nnx.split(model)

leaves = jax.tree_util.tree_leaves(state)
num_leaves = len(leaves)

total = sum(jnp.sum(leaf) for leaf in leaves)
# or equivalently (needs `import operator`):
# jax.tree_util.tree_reduce(operator.add, jax.tree_util.tree_map(jnp.sum, state))

tree_leaves flattens; len(...) counts; the generator expression sums each leaf and Python adds the per-leaf sums together. The result is a JAX scalar; wrap it in float(...) to get a Python float.
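The two ways of totalling the leaves can be compared on a small stand-in. The nested dict below is a hypothetical substitute for an nnx.State (same layout as TwoLayer with in_features=3 and hidden=4, values chosen by hand), used so that the expected numbers are easy to verify.

```python
import operator
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an nnx.State: a nested dict of arrays.
state_like = {
    "l1": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
    "l2": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
}

leaves = jax.tree_util.tree_leaves(state_like)
num_leaves = len(leaves)

# Generator form: sum each leaf, let Python add the scalars.
total_a = sum(jnp.sum(leaf) for leaf in leaves)
# Pytree form: map jnp.sum over the tree, then reduce with addition.
total_b = jax.tree_util.tree_reduce(
    operator.add, jax.tree_util.tree_map(jnp.sum, state_like)
)

total = float(total_a)
print(num_leaves, total, float(total_b))  # 4 28.0 28.0
```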

Compare with the original model

forward_output_sum = float(model(x).sum())
graphdef, state = nnx.split(model)
# ... inspection ...

Splitting AFTER calling the model is fine: __call__ leaves this model unchanged. That would not hold if the forward pass mutated an nnx.Variable, such as a call counter. The forward output sum and the parameter sum are independent observables.

Common pitfalls

  • Treating state as a flat dict. It’s a nested pytree mirroring the module structure. Use tree_leaves to flatten.
  • Forgetting .value to access the array. When you DO walk state by hand (e.g., state['l1']['kernel']), you get an nnx.Param wrapper; .value gets the array. tree_leaves already unwraps for you.
  • Counting state.keys() as leaves. That counts top-level submodule names (2 for a TwoLayer), not the actual parameter arrays (4).
  • Forgetting that biases are leaves. A Linear has TWO leaves (kernel + bias). A Linear with use_bias=False has ONE.

Problem

Write graphdef_vs_state(seed, x, features, hidden):

  1. Define a TwoLayer(nnx.Module) with l1 = nnx.Linear(in_features, hidden) and l2 = nnx.Linear(hidden, hidden). Forward: l2(l1(x)).
  2. Build the model. Compute forward_output_sum = float(model(x).sum()).
  3. Split: graphdef, state = nnx.split(model). Flatten with jax.tree_util.tree_leaves(state).
  4. num_leaves = len(leaves). total_param_sum = float(sum(jnp.sum(leaf) for leaf in leaves)).
  5. Return jnp.array([float(num_leaves), total_param_sum, forward_output_sum]).

For this architecture, num_leaves is always 4: l1.kernel, l1.bias, l2.kernel, l2.bias.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float; unused here, kept for a consistent call signature).
  • hidden: int (passed as float).

Output: length-3 [num_leaves, total_param_sum, forward_output_sum].

