NNX Pure State Update
Why this matters
nnx modules are mutable in eager mode (problem 5), but everything serious runs
pure-functionally under the hood: optimizers, jit, vmap, and distributed sharding
all operate on state pytrees, not on Module objects. The canonical pattern for any
state-changing op is:
- graphdef, state = nnx.split(model)
- compute new_state from state purely (no side effects)
- model = nnx.merge(graphdef, new_state) (or, more commonly, the merged model is built INSIDE a transformed function)
This problem walks through the simplest pure update: multiply every
parameter by a scalar factor scale. It looks trivial, but it has the same
shape as an Optax update step (new_params = params - lr * grads),
a parameter EMA update (problem 16), or an FP16 cast.
Once you can write this update, you’ve written the inner loop of every nnx training step.
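To make the "same shape" claim concrete, here is a sketch on a toy dict pytree rather than an nnx state; the parameter values, lr, and decay are made-up illustrations. Each of the three updates is a single tree_map:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy pytrees (not taken from a real nnx model).
params = {"kernel": jnp.array([[0.5, 0.3], [0.1, -0.2]]), "bias": jnp.zeros((2,))}
grads = {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))}
ema = params
lr, decay = 0.1, 0.9

# SGD step: new_params = params - lr * grads (two trees mapped leaf by leaf).
sgd = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# EMA update: ema = decay * ema + (1 - decay) * params.
new_ema = jax.tree_util.tree_map(lambda e, p: decay * e + (1 - decay) * p, ema, params)

# FP16 cast: same structure, new dtype on every leaf.
half = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params)
```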
API: jax.tree_util.tree_map across state
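```python
import jax
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))  # any nnx Module works here
scale = 2.0
graphdef, state = nnx.split(model)
new_state = jax.tree_util.tree_map(
    lambda leaf: leaf * scale,  # applies to every array leaf
    state,
)
merged = nnx.merge(graphdef, new_state)
```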
tree_map(fn, state) walks the state pytree and replaces each leaf with
fn(leaf). The structure is preserved: the variable wrappers (nnx.Param,
nnx.Variable) are pytree nodes rather than leaves, so fn only ever sees
the raw arrays, and nnx.merge combines the rebuilt state with the graphdef
to produce live Variables again.
Notice this update is uniform: every leaf gets multiplied. For
selective updates (e.g., only Params), use nnx.state(model, nnx.Param)
to pick the slice first, then tree_map, then nnx.update(model, new_subset)
to write back.
Worked example
```python
import jax
from flax import nnx

model = nnx.Linear(3, 2, rngs=nnx.Rngs(0))
print(model.kernel.value)
# e.g. [[ 0.5, 0.3], [ 0.1, -0.2], [...]]  (illustrative values)
graphdef, state = nnx.split(model)
new_state = jax.tree_util.tree_map(lambda l: l * 2.0, state)
merged = nnx.merge(graphdef, new_state)
print(merged.kernel.value)
# [[1.0, 0.6], [0.2, -0.4], [...]]: twice the original
print(model.kernel.value)
# unchanged: the update is pure; the original model is untouched
```
Pure means PURE: the original model isn’t modified. The merged model is a NEW Python object with scaled params.
Where you’ll see this in real training
In Sprint C3 (nnx.optimizer-basics), the train step is roughly:
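```python
from flax import nnx

def train_step(model, optimizer, x, y):
    # loss_fn(model, x, y) -> scalar loss; assumed to be defined elsewhere.
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    # Newer Flax releases spell this call optimizer.update(model, grads).
    optimizer.update(grads)  # applies the Optax update to model params in place
    return loss
```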
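Internally, nnx.Optimizer follows the same tree-map pattern: it pulls the parameter state out of the model, lets Optax turn the gradients into updates (optax.apply_updates then adds them leaf by leaf, which is itself a tree_map), and writes the new parameters back into the model. The optimizer hides the bookkeeping; you ride on top of it.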
Common pitfalls
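- Mutating state in place. State pytrees are conceptually immutable. Use tree_map (or assign the result to a new variable).
- Forgetting to merge. A scaled state pytree isn't callable; rebuild the model with nnx.merge(graphdef, new_state).
- Using += scale (or *= scale) on a leaf. JAX arrays are immutable: indexed forms such as leaf[...] += scale raise an error, and a bare leaf += scale only rebinds a local name (and adds rather than scales). Use multiplication (or addition) and reassign.
- Casting scale to a tracer too early. For eager-mode pure updates with a Python float, just do lambda l: l * float(scale).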
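The immutability pitfall in miniature, using plain jax.numpy with arbitrary values: item assignment raises a TypeError, while multiplication and the .at[...] helpers return new arrays and leave the original alone.

```python
import jax.numpy as jnp

leaf = jnp.array([1.0, 2.0, 3.0])

try:
    leaf[0] *= 2.0  # in-place item assignment is not supported on JAX arrays
    mutated = True
except TypeError:
    mutated = False

scaled = leaf * 2.0                     # pure: returns a NEW array
first_only = leaf.at[0].multiply(2.0)   # pure "indexed update", also a new array
```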
Problem
Write pure_state_update(seed, x, features, scale):
- Build: model = nnx.Linear(in_features=x.shape[-1], out_features=int(features), rngs=nnx.Rngs(int(seed))).
- Split: graphdef, state = nnx.split(model).
- Apply: new_state = jax.tree_util.tree_map(lambda leaf: leaf * float(scale), state). (This multiplies BOTH kernel and bias by scale.)
- Merge: merged = nnx.merge(graphdef, new_state).
- Return: merged(x).
Note: nnx.Linear's default bias is zero, so multiplying it by anything
still gives zero. The visible effect is the kernel scaling, so the output
is simply scale * (x @ kernel).
Inputs:
- seed: int (passed as float).
- x: 1-D JAX array.
- features: int (passed as float).
- scale: float, the multiplier for every state leaf.
Output: a 1-D array of length features, equal to (x @ (scale * kernel)) + (scale * bias).