NNX GraphDef Introspection

Why this matters

Recall the split:

graphdef, state = nnx.split(model)
  • graphdef is the static part: the module class hierarchy, attribute names, Variable types, and static configuration. It’s hashable and small.
  • state is the dynamic part: a pytree of arrays, each carrying its own shape and dtype.

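This static/dynamic split is what makes nnx.jit fast. jax.jit caches compiled programs by their static arguments plus the abstract values (shapes and dtypes) of the inputs, so the cache key here is effectively the (graphdef, input_avals) pair. If graphdef is stable across many forward calls, the compiled program is cached and reused; only the state tensors change. (Contrast with PyTorch’s torch.compile, which works on the stateful module object itself rather than on a separate static description.)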

For this problem we extract shape information from the state (the arrays carry the shapes). The deeper point is that those shapes are fixed by the model’s static configuration (here in_features and out_features): they’re known at construction time and won’t change between calls if you reuse the same model.

API: reading shapes off the state

rngs = nnx.Rngs(0)
model = nnx.Linear(in_features=4, out_features=8, rngs=rngs)
graphdef, state = nnx.split(model)

state['kernel'].value.shape    # (4, 8) — (in_features, out_features)
state['bias'].value.shape      # (8,)

The state is a State pytree keyed by attribute name. For an nnx.Linear, you have two leaves: kernel (a Param wrapping a 2-D array) and bias (a Param wrapping a 1-D array).

graphdef vs state — division of labor

Lives in graphdef                | Lives in state
---------------------------------|--------------------------------
Module class                     | Array values
Attribute names                  | Array shapes (implicit)
Variable subclass tags           | Array dtypes
Static config (e.g., use_bias)   | nnx.Variable wrapper instances

Going the other way, nnx.merge(graphdef, state) reconstructs a callable model. graphdef carries the “what kind of thing” info; state carries the “what values” info.

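For most user code you don’t need to peek inside graphdef; it’s an opaque object you pass to merge. But knowing it’s there explains why nnx.jit can serve repeated calls from a single compiled program without recompilation, and why you can edit parameter values through the state pytree and still hit the cache of an already-jitted function, as long as shapes and dtypes are unchanged.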

Why “static” matters for jit

@nnx.jit
def train_step(model, x, y):
    ...

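Conceptually, this expands to roughly the following (the real nnx.jit also caches the inner function across calls and propagates state updates back to model):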

def train_step(model, x, y):
    graphdef, state = nnx.split(model)
    @jax.jit
    def pure(state, x, y):
        m = nnx.merge(graphdef, state)
        ...
    return pure(state, x, y)

Every call is keyed on the same graphdef. If you change the model structure (add a layer, change a flag), the graphdef changes; jax.jit sees a new hash and recompiles. If you only change parameter values (keeping shapes and dtypes the same), it reuses the cache, which is what makes nnx training loops fast.

Worked example

rngs = nnx.Rngs(0)
model = nnx.Linear(in_features=4, out_features=8, rngs=rngs)
graphdef, state = nnx.split(model)
print('kernel shape:', state['kernel'].value.shape)   # (4, 8)
print('bias shape:',   state['bias'].value.shape)     # (8,)

The shapes match (in_features, out_features) for the kernel and (out_features,) for the bias.

Common pitfalls

  • Trying to JSON-serialize graphdef directly. It contains Python class references; not portable. For checkpoints, save STATE only and keep model code in source.
  • Treating graphdef as a parameter. It’s not; it’s static. Don’t pass it to optimizers.
  • Reading kernel via state['kernel'] without .value. The state pytree leaves are nnx.Param (or other Variable) wrappers; unwrap with .value to get the array.
  • Mixing up the bias shape. Bias is (out_features,), not (in_features,).

Problem

Write graphdef_call_static(seed, x, hidden):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(hidden), rngs=nnx.Rngs(int(seed))).
  2. Split: graphdef, state = nnx.split(model).
  3. Read shapes:
    • kernel_in = state['kernel'].value.shape[0]
    • kernel_out = state['kernel'].value.shape[1]
    • bias_dim = state['bias'].value.shape[0]
  4. Return jnp.array([float(kernel_in), float(kernel_out), float(bias_dim)]).

Expected: [in_features, hidden, hidden].

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • hidden: int (passed as float).

Output: length-3 array [in_features, hidden, hidden].
