NNX State Transform
Why this matters
Once you have an nnx model split into (graphdef, state), the state is a
plain JAX pytree of arrays — and pytrees compose with everything in
jax.tree_util. Need to zero a layer’s kernel? tree_map. Add a
perturbation? tree_map. Cast every leaf to bf16? tree_map. Apply a
selective edit by name? tree_map_with_path.
The pattern: split → transform state with tree-map → merge → call. This is the heart of model surgery (zeroing layers for ablation), fine-tuning (loading a base, then patching a head), and various debugging maneuvers (perturbation analysis).
The split/merge round-trip is purely structural: write a correct transform once, and the same code works on any nnx module.
API: jax.tree_util.tree_map_with_path
import jax
import jax.numpy as jnp

def transform(path, leaf):
    # `path` is a tuple of path elements that names this leaf
    # `leaf` is the underlying array
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    if 'kernel' in name:
        return jnp.zeros_like(leaf)
    return leaf

# `state` is the second value returned by nnx.split(model)
new_state = jax.tree_util.tree_map_with_path(transform, state)
tree_map_with_path walks the pytree and calls transform(path, leaf)
for every leaf. The path is a tuple of structured path elements; the
one-liner above flattens them into a string for easy substring matching.
For an nnx.Linear, the leaves live at paths that print roughly as
('kernel', '.value') and ('bias', '.value'); the exact form depends on
the Flax version. Because the transform only checks whether 'kernel'
appears in the flattened name, it tolerates minor structural changes.
Worked example
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 6, rngs=nnx.Rngs(0))
model.bias.value = jnp.ones((6,))  # set bias to all ones for visibility
graphdef, state = nnx.split(model)
# state (roughly): {'kernel': nnx.Param(<4x6>), 'bias': nnx.Param(<6>)}

def zero_if_kernel(path, leaf):
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    return jnp.zeros_like(leaf) if 'kernel' in name else leaf

new_state = jax.tree_util.tree_map_with_path(zero_if_kernel, state)
merged = nnx.merge(graphdef, new_state)
print(merged(jnp.array([1.0, 2.0, 3.0, 4.0])))  # [1. 1. 1. 1. 1. 1.]
# bias all-ones survived; kernel is zeros so x @ kernel = 0
The output is just the bias (all ones) regardless of x — kernel was
zeroed, so x @ kernel is 0.
Compare with naive tree_map
jax.tree_util.tree_map(fn, state) passes only the leaf to fn, with no
path information, so it can only transform every leaf the same way
(e.g., lambda l: l * 0.0). But we want to zero ONLY the kernel, and for
that we need the path. That’s why tree_map_with_path exists.
Alternative idiom: select part of the state by variable type via
nnx.state(model, type), e.g. params versus non-parameter variables, and
transform only the params. Useful when you want type-based selection
rather than name-based.
Common pitfalls
- Forgetting _with_path. Plain tree_map doesn’t pass the path; you can’t tell kernel from bias.
- Using tree_map to mutate the model in place. It returns a NEW pytree; you must merge to get an updated model.
- Path-element shape varies. A SequenceKey (list index) has .idx, a DictKey has .key. The hasattr(p, 'key') check in the example handles both: elements without .key fall back to str(p).
- Forgetting to merge. tree_map_with_path(fn, state) returns a pytree, not a callable.
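The path-element and in-place pitfalls can be seen on a plain nested pytree, without any nnx module. In this sketch `path_to_name` is a hypothetical helper (not a JAX function) that handles each key type explicitly, and the nested dict/list is a stand-in for a model state.

```python
import jax.numpy as jnp
from jax import tree_util as jtu

def path_to_name(path):
    """Hypothetical helper: join path elements into a dotted name."""
    parts = []
    for p in path:
        if isinstance(p, jtu.DictKey):
            parts.append(str(p.key))    # dict key
        elif isinstance(p, jtu.SequenceKey):
            parts.append(str(p.idx))    # list/tuple index
        elif isinstance(p, jtu.GetAttrKey):
            parts.append(p.name)        # attribute name
        else:
            parts.append(str(p))        # unknown key type: fall back to str
    return '.'.join(parts)

# Stand-in pytree mixing dict keys and a list index.
tree = {'layers': [{'kernel': jnp.ones((2, 2)), 'bias': jnp.ones((2,))}]}

names = [path_to_name(path) for path, _ in jtu.tree_flatten_with_path(tree)[0]]
print(names)  # ['layers.0.bias', 'layers.0.kernel']

zeroed = jtu.tree_map_with_path(
    lambda path, leaf: jnp.zeros_like(leaf) if 'kernel' in path_to_name(path) else leaf,
    tree,
)

# tree_map_with_path returned a NEW pytree; the original is unchanged.
print(tree['layers'][0]['kernel'].sum())    # 4.0
print(zeroed['layers'][0]['kernel'].sum())  # 0.0
```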
Problem
Write state_zero_kernel(seed, x, features):
- Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
- Set the bias to all ones: model.bias.value = jnp.ones((int(features),)). (This makes the post-zero output observable.)
- Split: graphdef, state = nnx.split(model).
- Define a zero_if_kernel(path, leaf) that zeros leaves whose path contains 'kernel', leaves all others unchanged.
- Apply: new_state = jax.tree_util.tree_map_with_path(zero_if_kernel, state).
- Merge and forward: merged = nnx.merge(graphdef, new_state); return merged(x).
Expected output: an all-ones array of length features (the bias is
untouched, the kernel was zeroed → x @ 0 + 1 = 1).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
Output: length-features array of ones.