
NNX Pure State Update

Why this matters

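nnx modules are mutable in eager mode (problem 5), but JAX's transformations and libraries work on pytrees of arrays, not on Module objects: jax.jit, jax.vmap, Optax optimizers, and sharding utilities all see state pytrees. (nnx.jit and the other nnx transforms accept Modules, but they perform this split/merge for you under the hood.) The canonical pattern for any state-changing op is: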

  1. graphdef, state = nnx.split(model)
  2. compute new_state from state purely (no side effects)
  3. model = nnx.merge(graphdef, new_state) (or, more commonly, the merged model is built INSIDE a transformed function)

This problem walks through the simplest pure update: multiply every parameter by a scalar scale. It looks trivial, but it has the same shape as an optimizer step (new_params = params - lr * grads for plain SGD), a parameter EMA update (problem 16), or a cast to FP16: each one maps a function over every leaf of the state.

Once you can write this update, you have written the core of the parameter update inside every nnx training step.

API: jax.tree_util.tree_map across state

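import jax
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))
scale = 2.0

graphdef, state = nnx.split(model)

new_state = jax.tree_util.tree_map(
    lambda leaf: leaf * scale,    # applies to every array leaf (kernel and bias)
    state,
)

merged = nnx.merge(graphdef, new_state)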

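tree_map(fn, state) walks the state pytree and replaces each leaf with fn(leaf). The structure is preserved: nnx's Variable wrappers (such as nnx.Param) are registered as pytree nodes, so tree_map reaches the raw arrays inside them and rebuilds the same wrappers around the results. nnx.merge then attaches that state to a fresh module built from graphdef.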

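Notice this update is uniform: every leaf gets multiplied. For selective updates (e.g., only Params), use nnx.state(model, nnx.Param) to pick the slice first, then tree_map, then nnx.update(model, new_subset) to write back. nnx.update writes into model in place, so that variant is not pure; for a pure version, split with filters instead (graphdef, params, rest = nnx.split(model, nnx.Param, ...)), update params, and call nnx.merge(graphdef, new_params, rest).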

Worked example

import jax
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))
print(model.kernel.value)
# e.g. [[ 0.5,  0.3], [ 0.1, -0.2], [...]]   (illustrative; actual values depend on the initializer)

graphdef, state = nnx.split(model)
new_state = jax.tree_util.tree_map(lambda leaf: leaf * 2.0, state)
merged = nnx.merge(graphdef, new_state)

print(merged.kernel.value)
# e.g. [[1.0, 0.6], [0.2, -0.4], [...]]   (twice the original)
print(model.kernel.value)
# unchanged: the update is pure, so the original model is untouched

Pure means PURE: the original model isn’t modified. The merged model is a NEW Python object with scaled params.

Where you’ll see this in real training

In Sprint C3 (nnx.optimizer-basics), the train step is roughly:

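def train_step(model, optimizer, x, y):
    # loss_fn(model, x, y) -> scalar loss is assumed to be defined elsewhere
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # nnx.Optimizer runs the Optax transform and writes the new params into model in place
    # (recent Flax releases spell this optimizer.update(model, grads))
    optimizer.update(grads)
    return loss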

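Internally, nnx.Optimizer follows the same tree-map pattern: it takes the Param state of the model, asks the Optax transformation for updates given the gradients, adds them to the parameters with optax.apply_updates (itself a tree_map), and writes the result back into the model. The optimizer hides the bookkeeping; you ride on top of it.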

Common pitfalls

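  • Mutating state in place. State pytrees are conceptually immutable. Use tree_map and bind the result to a new name (new_state).
  • Forgetting to merge. A scaled state pytree isn't callable; call nnx.merge(graphdef, new_state) to get a Module back.
  • Trying in-place array updates such as leaf[...] *= scale. JAX arrays are immutable, so item assignment raises an error. Return leaf * scale as a new array (use leaf.at[...].set(...) for indexed updates).
  • Mixing up concrete and traced scale values. In eager mode, with scale passed as a Python number or a concrete scalar, lambda l: l * float(scale) is fine. Inside jax.jit or nnx.jit, scale may be a tracer and float() on a tracer raises an error, so multiply by it directly (l * scale).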
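The last two pitfalls can be reproduced in a few lines of plain JAX. This is a sketch; the function names scale_with_float and scale_leaf are hypothetical.

```python
import jax
import jax.numpy as jnp

leaf = jnp.ones(3)

# Pitfall: in-place item assignment. JAX arrays are immutable, so this raises.
try:
    leaf[...] *= 2.0
    inplace_ok = True
except TypeError:
    inplace_ok = False

doubled = leaf * 2.0           # pure alternative: a new array
patched = leaf.at[0].set(5.0)  # pure indexed update, also a new array


# Pitfall: float() on a traced value inside jit.
@jax.jit
def scale_with_float(leaf, scale):
    return leaf * float(scale)  # fails: scale is a tracer here


try:
    scale_with_float(leaf, 3.0)
    float_under_jit_ok = True
except jax.errors.ConcretizationTypeError:
    float_under_jit_ok = False


@jax.jit
def scale_leaf(leaf, scale):
    return leaf * scale  # works: stays a traced array operation


scaled = scale_leaf(leaf, 3.0)
```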

Problem

Write pure_state_update(seed, x, features, scale):

  1. Build model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
  2. Split: graphdef, state = nnx.split(model).
  3. Apply: new_state = jax.tree_util.tree_map(lambda leaf: leaf * float(scale), state). (This multiplies BOTH kernel and bias by scale.)
  4. Merge: merged = nnx.merge(graphdef, new_state).
  5. Return merged(x).

Note: nnx.Linear's default bias initializer is zeros, so multiplying the bias by anything still gives zero. The visible effect is the kernel scaling, so the output is exactly scale * (x @ kernel).

Inputs:

  • seed: int (passed as float).
  • x: 1-D JAX array.
  • features: int (passed as float).
  • scale: float — multiplier for every state leaf.

Output: 1-D array of length features, equal to (x @ (scale * kernel)) + (scale * bias).

