NNX GraphDef Introspection
Why this matters
Recall the split:
graphdef, state = nnx.split(model)
- graphdef is the static part: the module class hierarchy, attribute names, Variable types, and static configuration. It's hashable and small.
- state is the dynamic part: a pytree of arrays.
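This split is what makes nnx.jit fast. jax.jit compiles a new program
for each unique combination of static arguments and input avals
(abstract shapes and dtypes). For an NNX model the graphdef plays the
static role, so the cache key is effectively (graphdef, input_avals).
If graphdef is stable across many forward calls, the compiled program
is cached and reused; only the state tensors change. (Compare with
PyTorch's torch.compile, which has no explicit static/dynamic split and
instead decides when to recompile by checking guards on the model
object and its inputs.)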
For this problem we read shapes from the state, since the arrays carry them. The deeper point is that those shapes are fixed by the static configuration the graphdef records (in_features and out_features for nnx.Linear). They are known at trace time and won't change between calls if you reuse the same model.
API: reading shapes off the state
from flax import nnx
rngs = nnx.Rngs(0)
model = nnx.Linear(in_features=4, out_features=8, rngs=rngs)
graphdef, state = nnx.split(model)
state['kernel'].value.shape # (4, 8) — (in_features, out_features)
state['bias'].value.shape # (8,)
The state is a State pytree keyed by attribute name. For an
nnx.Linear, you have two leaves: kernel (a Param wrapping a 2-D
array) and bias (a Param wrapping a 1-D array).
graphdef vs state — division of labor
| Lives in graphdef | Lives in state |
|---|---|
| Module class | Array values |
| Attribute names | Array shapes (implicit) |
| Variable subclass tags | Array dtypes |
| Static config (e.g., use_bias) | nnx.Variable wrapper instances |
To restate: nnx.merge(graphdef, state) reconstructs a callable model.
graphdef carries the “what kind of thing” info; state carries the
“what values” info.
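For most user code you don't need to peek inside graphdef; it's a
blob you pass to merge. But knowing it's there explains why nnx.jit
can reuse one compiled program across repeated calls without
recompiling. It also explains why editing parameter values through the
state pytree (for example, swapping in new arrays of the same shape
and dtype) works on modules whose step functions are already jitted.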
Why “static” matters for jit
@nnx.jit
def train_step(model, x, y):
    ...
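Conceptually, this behaves roughly like the following. The real nnx.jit does extra bookkeeping, such as writing state updates back into the original objects.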
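import functools
import jax
from flax import nnx

@functools.partial(jax.jit, static_argnums=0)   # graphdef is a hashable static arg
def pure(graphdef, state, x, y):
    m = nnx.merge(graphdef, state)              # rebuild the module inside the trace
    ...                                         # forward pass, loss, parameter update
    return nnx.split(m)[1]                      # updated state

def train_step(model, x, y):
    graphdef, state = nnx.split(model)
    nnx.update(model, pure(graphdef, state, x, y))  # write new values back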
Every call traces against the SAME graphdef. If you change the model
structure (add a layer, change a flag), the graphdef changes, so
jax.jit sees a different static key and recompiles. If you only change
parameter VALUES, keeping their shapes and dtypes, it reuses the
cached program. That's what makes nnx training loops fast.
Worked example
from flax import nnx
rngs = nnx.Rngs(0)
model = nnx.Linear(in_features=4, out_features=8, rngs=rngs)
graphdef, state = nnx.split(model)
print('kernel shape:', state['kernel'].value.shape) # (4, 8)
print('bias shape:', state['bias'].value.shape) # (8,)
The shapes match (in_features, out_features) for the kernel and
(out_features,) for the bias.
Common pitfalls
- Trying to JSON-serialize graphdef directly. It contains Python class references, so it is not portable. For checkpoints, save STATE only and keep the model code in source.
- Treating graphdef as a parameter. It isn't one; it's static. Don't pass it to optimizers.
- Reading the kernel via state['kernel'] without .value. The state pytree leaves are Variable wrappers tagged as nnx.Param (or another Variable type); unwrap with .value to get the array.
- Mixing up the bias shape. The bias is (out_features,), not (in_features,).
Problem
Write graphdef_call_static(seed, x, hidden):
- Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(hidden), rngs=nnx.Rngs(int(seed))).
- Split: graphdef, state = nnx.split(model).
- Read shapes:
  - kernel_in = state['kernel'].value.shape[0]
  - kernel_out = state['kernel'].value.shape[1]
  - bias_dim = state['bias'].value.shape[0]
- Return jnp.array([float(kernel_in), float(kernel_out), float(bias_dim)]).
Expected: [in_features, hidden, hidden].
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- hidden: int (passed as float).
Output: length-3 array [in_features, hidden, hidden].