
NNX Optimizer Basics

Why this matters

In Linen, the canonical training-state container is flax.training.train_state.TrainState. It bundles apply_fn, params, tx (the optax GradientTransformation), and opt_state, and you thread a fresh state through every step via state = state.apply_gradients(grads=grads).

nnx replaces all of that with one object: nnx.Optimizer. It holds the optax transformation, the optax state, and a step counter, and its update method writes the new parameter values into the model in place. There is no TrainState and no apply_gradients returning a new container; the model plus its optimizer are the training state.

The recipe

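import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Illustrative sizes, seed, learning rate and data
in_features, out_features, seed, lr = 3, 2, 0, 0.1
x = jax.random.normal(jax.random.key(1), (8, in_features))
y = jax.random.normal(jax.random.key(2), (8, out_features))

model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)

def loss_fn(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2)

loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
optimizer.update(model, grads)   # mutates the model's params in place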

Three things to notice:

  1. nnx.value_and_grad(loss_fn)(model, x, y) — the model is the differentiation target, not a params pytree. Under the hood nnx walks the model, finds every nnx.Param, and returns a grads pytree shaped like the model’s param state.
  2. optimizer.update(model, grads) — the model is passed because the optimizer needs to know which variables to mutate. Returns nothing; the model’s params now hold the post-step values.
  3. wrt=nnx.Param tells the optimizer which Variable types to track. For training, you almost always want nnx.Param. (If you had auxiliary state you wanted optax to track, you'd pass a broader filter. BatchNorm running stats, for example, are NOT optax-managed: they live on the model and are updated by the forward pass itself.)

Why “in place” works under JAX

JAX arrays are still immutable. The mutation happens on the nnx Variable wrapper: param.value = new_array rebinds the .value attribute to a new array. The wrapper's identity is stable across steps; only what it points to changes. This is the same trick BatchNorm uses for its running statistics.

Common pitfalls

  • Forgetting wrt=nnx.Param. As of Flax 0.11, wrt is a required keyword argument; older code that omitted it fails with a TypeError reporting the missing wrt argument.
  • Calling nnx.value_and_grad(loss_fn)(params, x, y). That’s the Linen pattern. In nnx, pass the model as the first argument.
  • Trying to assign the optimizer’s return value. optimizer.update returns None. The model is mutated in place; do not rebind.
  • Building a fresh optimizer every step. The optimizer holds optax state (momentum buffers, etc.). Build it once, reuse it across steps.

Problem

Implement optimizer_one_step(seed, x, y, lr):

  1. Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(int(seed))).
  2. Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Define loss_fn(model, x, y) -> mean((model(x) - y)**2).
  4. Compute loss_before = loss_fn(model, x, y).
  5. Compute grads via nnx.value_and_grad(loss_fn)(model, x, y).
  6. Call optimizer.update(model, grads).
  7. Compute loss_after = loss_fn(model, x, y).
  8. Return jnp.array([float(loss_before), float(loss_after)]).

For a small enough positive learning rate, the loss should strictly decrease.
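Why "small enough" suffices here: the prediction of a Linear layer is linear in its parameters θ, so the mean-squared loss f is exactly quadratic in θ with a constant Hessian H. Writing g = ∇f(θ), L = λ_max(H), and η for the learning rate, one SGD step gives

```latex
f(\theta - \eta g) = f(\theta) - \eta \lVert g \rVert^2 + \tfrac{\eta^2}{2}\, g^\top H g
\;\le\; f(\theta) - \eta\left(1 - \tfrac{\eta L}{2}\right) \lVert g \rVert^2
```

so the loss strictly decreases whenever 0 < η < 2/L and g ≠ 0.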

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D_in).
  • y: 2-D (N, D_out).
  • lr: float.

Output: 1-D array of shape (2,): [loss_before, loss_after].

Hints

flax nnx optimizer training
