NNX State Split/Merge

Why this matters

nnx.split and nnx.merge are the bridge between object-oriented nnx code and pure JAX. They are the two functions that make nnx.jit, nnx.vmap, Orbax checkpointing, and most other advanced use cases work.

graphdef, state = nnx.split(model)     # disassemble
merged = nnx.merge(graphdef, state)    # reassemble

The contract: round-tripping a model through split/merge produces a structurally and numerically identical model. Calling the original on x and calling the merged model on x return the same array, bit for bit.

Why does this matter? Because pure JAX transforms (jax.jit, jax.vmap, jax.grad) operate on pure functions and pytrees, not on Python objects. state is a pytree whose leaves are the model’s arrays, so you can transform it freely, then merge it back with the (hashable, static) graphdef to get a callable model again.

API: split / merge

from flax import nnx

model = nnx.Linear(4, 8, rngs=nnx.Rngs(0))

graphdef, state = nnx.split(model)
# graphdef: static structural metadata (hashable; cached by jit)
# state:    pytree of nnx.Variable wrappers

merged = nnx.merge(graphdef, state)
# merged is a fresh Module instance; calling it gives the same result

The merged module is a new Python object — merged is not model — but its parameters are the same arrays, so its forward output is identical.

Where this gets used

Inside nnx.jit, the implementation is roughly:

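import jax
from flax import nnx

def nnx_jit_apply(model, x):
    # Separate the static structure from the array state.
    graphdef, state = nnx.split(model)
    @jax.jit
    def pure_fn(state, x):
        # Rebuild the module inside the trace, then call it.
        m = nnx.merge(graphdef, state)
        return m(x)
    return pure_fn(state, x)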

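On each call, the model is split, the traced pure function receives state as input, the model is rebuilt with merge inside the trace, and then it is called. JAX never sees a Python module; it sees a pytree. The sketch is deliberately simplified: it leaves out reusing the compiled function across calls and copying state updates back to the original model.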

Worked example

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
x = jnp.array([1.0, 2.0, 3.0])

orig_out = model(x)
graphdef, state = nnx.split(model)
merged = nnx.merge(graphdef, state)
merged_out = merged(x)

print(jnp.max(jnp.abs(orig_out - merged_out)))   # 0.0

Common pitfalls

  • Treating state as a flat dict. It’s a State pytree that mirrors the module’s nested structure. Use jax.tree_util.tree_leaves(state) to flatten it into a list of arrays.
  • Calling merge with mismatched graphdef and state. They must come from the same model (or a structurally compatible one).
  • Assuming the original and the merged model stay linked. After merged = nnx.merge(...), you have two separate objects. Mutating one doesn’t affect the other. For surgery, that’s a feature; for sharing, you have to wire it up explicitly.
  • Splitting in a hot loop. Split once and reuse (graphdef, state) if you’ll call merge multiple times; re-splitting on every iteration is unnecessary work.

Problem

Write split_merge_round_trip(seed, x, features):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. orig_out = model(x).
  3. graphdef, state = nnx.split(model); merged = nnx.merge(graphdef, state).
  4. merged_out = merged(x).
  5. abs_diff = jnp.max(jnp.abs(orig_out - merged_out)).
  6. Return jnp.array([float(orig_out.sum()), float(merged_out.sum()), float(abs_diff)]).

Expected: orig_out.sum() == merged_out.sum() and abs_diff == 0.0.

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).

Output: length-3 array [s, s, 0.0].

Hints

flax nnx split merge
