
NNX Module Print State

Why this matters

nnx modules look like Python objects, but under the hood JAX needs to see them as pytrees (nested containers whose leaves are arrays) so it can jit, vmap, save, and transmit them. nnx.split(model) is the bridge between the object-oriented façade and the functional core.

graphdef, state = nnx.split(model)
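  • graphdef is the static structure: which submodules and variables exist, their types, and static attributes such as in_features and out_features, but no array values. It is hashable, so JAX can treat it as a static argument and reuse a compiled trace for as long as the structure stays the same.
  • state holds the dynamic values: a State pytree of Variable-typed entries (nnx.Param for a Linear), each wrapping one array, mirroring the module's nested attribute structure.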

nnx.merge(graphdef, state) reconstructs a callable model from the pair. Together, split and merge let you apply pure-functional operations to state (jit, vmap, serialization) and merge back whenever you want to call the model.

This problem inspects the state of an nnx.Linear and counts its leaves: a quick sanity check that builds intuition before the deeper split/merge work in problem 12.

API: nnx.split and the state pytree

import jax
from flax import nnx

model = nnx.Linear(4, 6, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# state is an nnx.State pytree, keyed by attribute name
state['kernel']            # the Param wrapper for the kernel
state['kernel'].value      # the underlying (4, 6) array
state['bias'].value        # the underlying (6,) array

# leaves: the underlying arrays themselves
leaves = jax.tree_util.tree_leaves(state)
print(len(leaves))         # 2 (kernel + bias)

For a vanilla nnx.Linear, the leaf count is exactly 2 (kernel and bias). If you set use_bias=False, it’s 1.

Why count leaves?

  • Sanity-checking model size. sum(leaf.size for leaf in jax.tree_util.tree_leaves(state)) is the total parameter count.
  • Confirming that submodules registered correctly. If a parameter you expect is missing from state, that is a bug, usually an array assigned without its nnx.Param(...) wrapper.
  • Filtering before serialization. You can split state by type (problem 13), feed only the trainable subset to an optimizer, etc.

Worked example

import jax
from flax import nnx

rngs = nnx.Rngs(0)
m = nnx.Linear(in_features=4, out_features=6, rngs=rngs)
graphdef, state = nnx.split(m)
n_leaves = len(jax.tree_util.tree_leaves(state))    # 2
print(state['kernel'].value.shape)                  # (4, 6)
print(state['bias'].value.shape)                    # (6,)

Common pitfalls

  • Calling nnx.split(model) and expecting a params dict. It returns a (graphdef, state) tuple, not a dict.
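  • Mistaking leaf count for parameter count. Each leaf is a whole array: a Linear(4, 6) has 2 leaves but 4*6 + 6 = 30 parameters, and a Variable-wrapped scalar is still one leaf. A Linear without bias has 1 leaf, not 0.
  • Using len(state) instead of len(tree_leaves(state)). len(state) counts only top-level keys, so for a module with nested submodules it undercounts; tree_leaves flattens recursively down to the arrays.
  • Calling split multiple times. Each nnx.split call produces fresh graphdef and state objects. Split once, keep the pair together, and cache it if you will merge later.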

Problem

Write module_state_count(seed, x, hidden):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(hidden), rngs=nnx.Rngs(int(seed))).
  2. Split: graphdef, state = nnx.split(model).
  3. Count leaves: leaf_count = len(jax.tree_util.tree_leaves(state)).
  4. Return jnp.array([float(leaf_count), float(int(hidden)), float(x.shape[-1])]).

Expected: leaf_count == 2 (kernel + bias) for a default Linear.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array — only its shape[-1] is used.
  • hidden: int (passed as float).

Output: length-3 array [2.0, hidden, in_features].

Hints

flax nnx split introspection
