NNX Optimizer Basics
Why this matters
In Linen, the canonical training-state container is
flax.training.train_state.TrainState. It bundles apply_fn, params,
tx (the optax GradientTransformation), and opt_state, and you
thread a fresh state through every step via
state = state.apply_gradients(grads=grads).
nnx replaces all of that with one object: nnx.Optimizer. In the API
used here (Flax 0.11 and later), it holds the optax state, and
optimizer.update(model, grads) mutates the model you pass in, in
place. There is no TrainState and no apply_gradients returning a new
container: the model and the optimizer together are the training state.
The recipe
import jax
import jax.numpy as jnp
import optax
from flax import nnx

# in_features, out_features, seed, lr, x, and y are placeholders you supply.
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)

def loss_fn(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2)

loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)  # mutates the model in place
Three things to notice:
- nnx.value_and_grad(loss_fn)(model, x, y): the model is the differentiation target, not a params pytree. Under the hood nnx walks the model, finds every nnx.Param, and returns a grads pytree shaped like the model’s param state.
- optimizer.update(model, grads): the model is passed because the optimizer needs to know which variables to mutate. It returns nothing; the model’s params now hold the post-step values.
- wrt=nnx.Param: tells the optimizer which Variable types to track. For training, you almost always want nnx.Param. (If you had auxiliary state you wanted optax to track, you’d pass a broader filter; but BatchNorm running stats, for example, are NOT optax-managed. They live on the model and the layer updates them itself.)
Why “in place” works under JAX
JAX arrays are still immutable. The mutation happens on the nnx
Variable wrapper — param.value = new_array rebinds the .value
attribute. The wrapper’s identity is stable across steps; only what
it points to changes. This is the same trick BatchNorm uses for
running statistics.
Common pitfalls
- Forgetting wrt=nnx.Param. As of Flax 0.11, wrt is required. Older code that omitted it will raise TypeError: Missing required argument 'wrt'.
- Calling nnx.value_and_grad(loss_fn)(params, x, y). That’s the Linen pattern. In nnx, pass the model as the first argument.
- Trying to assign the optimizer’s return value. optimizer.update returns None. The model is mutated in place; do not rebind.
- Building a fresh optimizer every step. The optimizer holds optax state (momentum buffers, etc.). Build it once, reuse it across steps.
Problem
Implement optimizer_one_step(seed, x, y, lr):
- Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(int(seed))).
- Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Define loss_fn(model, x, y) -> mean((model(x) - y)**2).
- Compute loss_before = loss_fn(model, x, y).
- Compute grads via nnx.value_and_grad(loss_fn)(model, x, y).
- Call optimizer.update(model, grads).
- Compute loss_after = loss_fn(model, x, y).
- Return jnp.array([float(loss_before), float(loss_after)]).
For a sufficiently small positive learning rate the loss must strictly decrease.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D_in).
- y: 2-D (N, D_out).
- lr: float.
Output: 1-D (2,) — [loss_before, loss_after].