
NNX State Transform

Why this matters

Once you have an nnx model split into (graphdef, state), the state is a JAX pytree whose leaves are arrays, and pytrees compose with everything in jax.tree_util. Need to zero a layer’s kernel? tree_map. Add a perturbation? tree_map. Cast every leaf to bf16? tree_map. Apply a selective edit by name? tree_map_with_path.
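As a rough illustration of the uniform edits mentioned above, the sketch below uses a hand-built nested dict as a stand-in for an nnx state. The names toy_state, layer1, and layer2 are made up for this example; a real state comes from nnx.split. It casts every leaf to bfloat16 and, separately, adds a fixed perturbation to every leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an nnx state: a nested dict of arrays.
toy_state = {
    "layer1": {"kernel": jnp.ones((4, 6)), "bias": jnp.zeros((6,))},
    "layer2": {"kernel": jnp.ones((6, 2)), "bias": jnp.zeros((2,))},
}

# Cast every leaf to bfloat16.
bf16_state = jax.tree_util.tree_map(
    lambda leaf: leaf.astype(jnp.bfloat16), toy_state
)

# Add the same small perturbation to every leaf.
eps = 0.5
perturbed = jax.tree_util.tree_map(lambda leaf: leaf + eps, toy_state)

print(bf16_state["layer1"]["kernel"].dtype)
print(perturbed["layer2"]["bias"])
```

Both calls return new pytrees with the same structure as the input; toy_state itself is left unchanged.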

The pattern: split → transform state with tree-map → merge → call. This is the heart of model surgery (zeroing layers for ablation), fine-tuning (loading a base, then patching a head), and various debugging maneuvers (perturbation analysis).

The split/merge round-trip is purely structural; write a correct transform once and the same code works on any nnx module.

API: jax.tree_util.tree_map_with_path

import jax
import jax.numpy as jnp

def transform(path, leaf):
    # `path` is a tuple of path elements that names this leaf
    # `leaf` is the underlying array
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    if 'kernel' in name:
        return jnp.zeros_like(leaf)
    return leaf

new_state = jax.tree_util.tree_map_with_path(transform, state)

tree_map_with_path walks the pytree and calls transform(path, leaf) for every leaf. The path is a tuple of structured path elements; the one-liner above flattens them into a string for easy substring matching.

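For an nnx.Linear, the leaves live at paths along the lines of ('kernel', '.value') and ('bias', '.value'); the exact form depends on the Flax version. Substring matching on the joined name keeps working as long as 'kernel' still appears somewhere in the path, which makes it robust to minor structural changes.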
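To see what paths look like in practice, a small sketch can list them for a toy pytree. The nested dict/list below is a hypothetical stand-in for a state, not the actual layout nnx produces; jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr are used here only to print the paths.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a state: dict -> list -> dict of arrays.
toy_state = {"layers": [{"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}]}

# tree_flatten_with_path returns (list of (path, leaf) pairs, treedef).
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(toy_state)

# keystr renders a path tuple as a readable string.
path_strings = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
for s in path_strings:
    print(s)
```

Printing the paths for a real state in the same way is a quick check on what the installed version actually produces before writing a name-based match.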

Worked example

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 6, rngs=nnx.Rngs(0))
model.bias.value = jnp.ones((6,))            # set bias to all ones for visibility

graphdef, state = nnx.split(model)
# state: {'kernel': nnx.Param(<4x6>), 'bias': nnx.Param(<6,>)}

def zero_if_kernel(path, leaf):
    name = '.'.join(str(p.key) if hasattr(p, 'key') else str(p) for p in path)
    return jnp.zeros_like(leaf) if 'kernel' in name else leaf

new_state = jax.tree_util.tree_map_with_path(zero_if_kernel, state)
merged = nnx.merge(graphdef, new_state)
print(merged(jnp.array([1.0, 2.0, 3.0, 4.0])))   # [1. 1. 1. 1. 1. 1.]
# bias all-ones survived; kernel is zeros so x @ kernel = 0

The output is just the bias (all ones) regardless of x — kernel was zeroed, so x @ kernel is 0.
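The arithmetic behind this can be checked without any module at all. The sketch below assumes the usual dense-layer formula y = x @ kernel + bias and uses hand-built arrays in place of the model's parameters.

```python
import jax.numpy as jnp

kernel = jnp.zeros((4, 6))   # the zeroed kernel
bias = jnp.ones((6,))        # the bias set to all ones

x = jnp.array([1.0, 2.0, 3.0, 4.0])
x_other = jnp.array([-7.0, 0.0, 3.5, 100.0])

# With a zero kernel, x @ kernel is a zero vector for any x,
# so the result is just the bias.
y = x @ kernel + bias
y_other = x_other @ kernel + bias

print(y)
print(y_other)
```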

Compare with naive tree_map

Plain jax.tree_util.tree_map(f, state) calls f(leaf) with no path information, so it can only transform every leaf the same way (e.g., lambda l: l * 0.0 zeros everything, bias included). Here we want to zero ONLY the kernel, which requires knowing where each leaf lives. That’s why tree_map_with_path exists.
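The difference is easy to see side by side. This sketch uses a flat dict as a hypothetical stand-in for the state and jax.tree_util.keystr to turn the path into a string; the worked example's '.'.join approach would serve equally well.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the state of a single dense layer.
toy_state = {"kernel": jnp.full((4, 6), 2.0), "bias": jnp.ones((6,))}

# Plain tree_map: the function sees only the leaf, so both leaves get zeroed.
all_zero = jax.tree_util.tree_map(lambda leaf: leaf * 0.0, toy_state)

def zero_if_kernel(path, leaf):
    # The path tells us which leaf this is.
    name = jax.tree_util.keystr(path)
    return jnp.zeros_like(leaf) if "kernel" in name else leaf

# tree_map_with_path: only the kernel is zeroed, the bias passes through.
kernel_zero = jax.tree_util.tree_map_with_path(zero_if_kernel, toy_state)

print(all_zero["bias"])
print(kernel_zero["bias"])
```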

Alternative idiom: select a subset of the state by variable type, e.g. nnx.state(model, nnx.Param), and transform only that subset. Useful when you want type-based selection rather than name-based.

Common pitfalls

  • Forgetting _with_path. Plain tree_map doesn’t pass the path; you can’t tell kernel from bias.
  • Expecting tree_map to mutate the model in place. It returns a NEW pytree; you must merge to get an updated model.
  • Path-element shape varies. A SequenceKey (list index) has .idx, a DictKey has .key, and a GetAttrKey has .name. The hasattr(p, 'key') check in the example reads .key when it exists and falls back to str(p) for every other element type.
  • Forgetting to merge. tree_map_with_path(transform, state) returns a state pytree, not a callable model.
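For the path-element pitfall, one option is a helper that handles each key type explicitly. The sketch below is only an illustration: path_to_name and the toy_state layout are hypothetical names introduced here, while DictKey, SequenceKey, and GetAttrKey are the key classes exported by jax.tree_util.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey

def path_to_name(path):
    """Join a key path into a dotted name (hypothetical helper)."""
    parts = []
    for p in path:
        if isinstance(p, DictKey):
            parts.append(str(p.key))      # dict entry
        elif isinstance(p, SequenceKey):
            parts.append(str(p.idx))      # list/tuple index
        elif isinstance(p, GetAttrKey):
            parts.append(str(p.name))     # attribute access
        else:
            parts.append(str(p))          # fallback for other key types
    return ".".join(parts)

# Hypothetical stand-in for a state with two blocks.
toy_state = {
    "blocks": [{"kernel": jnp.ones((2, 2))}, {"kernel": jnp.ones((2, 2))}]
}

def zero_block_1_kernel(path, leaf):
    # Exact-name match: only the second block's kernel is zeroed.
    return jnp.zeros_like(leaf) if path_to_name(path) == "blocks.1.kernel" else leaf

edited = jax.tree_util.tree_map_with_path(zero_block_1_kernel, toy_state)
```

Matching on an exact dotted name like this is stricter than the substring check used earlier, so it can single out one layer when several share the name kernel.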

Problem

Write state_zero_kernel(seed, x, features):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. Set the bias to all ones: model.bias.value = jnp.ones((int(features),)). (This makes the post-zero output observable.)
  3. Split: graphdef, state = nnx.split(model).
  4. Define a zero_if_kernel(path, leaf) that zeros leaves whose path contains 'kernel' and returns all other leaves unchanged.
  5. Apply: new_state = jax.tree_util.tree_map_with_path(zero_if_kernel, state).
  6. Merge and forward: merged = nnx.merge(graphdef, new_state); return merged(x).

Expected output: an all-ones array of length features (the bias is untouched, the kernel was zeroed → x @ 0 + 1 = 1).

Inputs:

  • seed: integer value, passed as a float (cast with int(seed)).
  • x: 1-D JAX array.
  • features: integer value, passed as a float (cast with int(features)).

Output: length-features array of ones.
