NNX Module Print State
Why this matters
NNX modules look like Python objects, but underneath JAX needs to see them
as pytrees (nested containers whose leaves are arrays) so it can jit, vmap,
save, and transmit them. nnx.split(model) is the bridge between the OO façade
and the functional core.
graphdef, state = nnx.split(model)
- graphdef is the static structure: which submodules exist, what types, what shape, but no array values. It’s hashable; JAX traces against it once and reuses the trace.
- state is the dynamic values: a pytree of nnx.Variable-wrapped arrays, mirroring the module’s nested attribute structure.
nnx.merge(graphdef, state) reconstructs a callable model from the pair.
Together they let you do anything pure-functional on state (jit, vmap,
save) and merge back when you want to call.
This problem inspects the state of a Linear and counts its leaves — a sanity-check intuition before the deeper split/merge work in problem 12.
API: nnx.split and the state pytree
import jax
from flax import nnx
model = nnx.Linear(4, 6, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
# state is a State pytree — keyed by attribute name
state['kernel'] # the Param wrapper for the kernel
state['kernel'].value # the underlying (4, 6) array
state['bias'].value # the underlying (6,) array
# leaves: the underlying arrays themselves
leaves = jax.tree_util.tree_leaves(state)
print(len(leaves)) # 2 — kernel + bias
For a vanilla nnx.Linear, the leaf count is exactly 2 (kernel and
bias). If you set use_bias=False, it’s 1.
Why count leaves?
- Sanity-checking model size. sum(l.size for l in tree_leaves(state)) is the total parameter count.
- Confirming that submodules registered correctly. A submodule with a Param that doesn’t show up in state is a bug, usually a missed nnx.Param(...) wrapper.
- Filtering before serialization. You can split state by type (problem 13), feed only the trainable subset to an optimizer, etc.
Worked example
import jax
from flax import nnx
rngs = nnx.Rngs(0)
m = nnx.Linear(in_features=4, out_features=6, rngs=rngs)
graphdef, state = nnx.split(m)
n_leaves = len(jax.tree_util.tree_leaves(state)) # 2
print(state['kernel'].value.shape) # (4, 6)
print(state['bias'].value.shape) # (6,)
Common pitfalls
- Calling nnx.split(model) and expecting a params dict. It returns a (graphdef, state) tuple, not a dict.
- Mistaking leaves for params. A Variable-wrapped scalar is one leaf. A Linear without bias has 1 leaf, not 0.
- Using len(state) instead of len(tree_leaves(state)). state may contain nested submodules; only tree_leaves flattens recursively.
- Calling split multiple times. nnx.split produces fresh graphdef + state objects each time. Cache the result if you’ll merge later.
Problem
Write module_state_count(seed, x, hidden):
- Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(hidden), rngs=nnx.Rngs(int(seed))).
- Split: graphdef, state = nnx.split(model).
- Count leaves: leaf_count = len(jax.tree_util.tree_leaves(state)).
- Return jnp.array([float(leaf_count), float(int(hidden)), float(x.shape[-1])]).
Expected: leaf_count == 2 (kernel + bias) for a default Linear.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array; only its shape[-1] is used.
- hidden: int (passed as float).
Output: length-3 array [2.0, hidden, in_features].