NNX State Split/Merge
Why this matters
nnx.split and nnx.merge are the bridge between object-oriented NNX code
and pure JAX. They are the two functions that make nnx.jit, nnx.vmap,
Orbax checkpointing, and most other advanced use cases work.
graphdef, state = nnx.split(model) # disassemble
merged = nnx.merge(graphdef, state) # reassemble
The contract: round-tripping a model through split/merge produces a
structurally and numerically identical model. Calling the original on x
and calling the merged model on x return the same array, bit-for-bit.
Why does this matter? Because pure JAX transforms (jax.jit, jax.vmap,
jax.grad) operate on pure functions and pytrees, not on Python objects.
state is a pytree of arrays; you can transform it freely. Then merge it
back with the (hashable, static) graphdef to get a callable model again.
API: split / merge
model = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
# graphdef: static structural metadata (hashable; cached by jit)
# state: pytree of nnx.Variable wrappers
merged = nnx.merge(graphdef, state)
# merged is a fresh Module instance; calling it gives the same result
The merged module is a new Python object — merged is not model — but
its parameters are the same arrays, so its forward output is identical.
Where this gets used
Inside nnx.jit, the implementation is roughly:
def nnx_jit_apply(model, x):
    graphdef, state = nnx.split(model)

    @jax.jit
    def pure_fn(state, x):
        m = nnx.merge(graphdef, state)
        return m(x)

    return pure_fn(state, x)
Each jit invocation: split → traced pure function takes state as input →
merge inside the trace → call. JAX never sees a Python module; it sees a
pytree.
Worked example
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
x = jnp.array([1.0, 2.0, 3.0])
orig_out = model(x)
graphdef, state = nnx.split(model)
merged = nnx.merge(graphdef, state)
merged_out = merged(x)
print(jnp.max(jnp.abs(orig_out - merged_out))) # 0.0
Common pitfalls
- Treating state as a flat dict. It’s a State pytree with the module’s nested structure mirrored. Use tree_leaves (for example jax.tree_util.tree_leaves) to flatten.
- Calling merge with mismatched graphdef and state. They must come from the same model (or a structurally compatible one).
- Assuming the original model and the merged one are linked. After merged = nnx.merge(...), you have two separate objects. Mutating one doesn’t affect the other. For surgery, that’s a feature; for sharing, you have to wire it up explicitly.
- Splitting in a hot loop. Cache (graphdef, state) if you’ll call merge multiple times; this avoids unnecessary work.
Problem
Write split_merge_round_trip(seed, x, features):
- Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
- orig_out = model(x).
- graphdef, state = nnx.split(model); merged = nnx.merge(graphdef, state).
- merged_out = merged(x).
- abs_diff = jnp.max(jnp.abs(orig_out - merged_out)).
- Return jnp.array([float(orig_out.sum()), float(merged_out.sum()), float(abs_diff)]).
Expected: orig_out.sum() == merged_out.sum() and abs_diff == 0.0.
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-3 array [s, s, 0.0].