
NNX Train Step

Why this matters

A “training step” is the smallest reusable unit of training: forward, loss, backward, optimizer update. Once you have it, the rest of the loop is just: for batch in data: train_step(model, optimizer, *batch).

In Linen the canonical step looks like:

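import jax
import jax.numpy as jnp

def train_step(state, x, y):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, x)
        return jnp.mean((preds - y) ** 2)
    # value_and_grad also returns the loss, so it can be reported.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)   # returns a NEW state
    return state, loss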

Notice the state round-trip. TrainState is a frozen (immutable) pytree, so apply_gradients returns a new state instead of modifying the old one, and you must rebind it on every step. That is why the function returns the updated state along with the loss.

In nnx the same step is structurally simpler:

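from flax import nnx

# loss_fn(model, x, y) is the MSE loss defined in the recipe below.
def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)   # updates the model's params in place
    return loss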

There’s nothing to thread. optimizer.update(model, grads) writes the new parameter values into the model’s nnx.Param variables in place, so the next call sees the post-step params. The function only needs to return the loss for logging.

Why this matters in practice

The Linen pattern leaks one detail into every loop: state, loss = train_step(state, ...). Forgetting to rebind state is a classic bug: your training silently doesn’t learn, because every new state is discarded and each step starts again from the initial params.

The nnx pattern can’t have that bug. There is no “new state” to drop on the floor.

The recipe

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

losses = []
for _ in range(num_steps):
    losses.append(float(train_step(model, optimizer, x, y)))

For this problem we run the same (x, y) batch over and over. With a small enough learning rate the loss then decreases monotonically; too large a learning rate can make it oscillate or diverge.
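A short justification of "small enough", using notation that is not in the original: θ for all parameters (kernel and bias), η for lr, and β for the smoothness constant. For a linear model with MSE, the loss is a convex quadratic in θ with a constant Hessian H, so its gradient is β-Lipschitz with β = λ_max(H). The standard descent lemma then bounds one gradient step:

```latex
\theta_{t+1} = \theta_t - \eta \nabla \mathcal{L}(\theta_t)
\quad\Longrightarrow\quad
\mathcal{L}(\theta_{t+1}) \le \mathcal{L}(\theta_t)
  - \eta \left(1 - \frac{\beta \eta}{2}\right) \lVert \nabla \mathcal{L}(\theta_t) \rVert^2,
\qquad \beta = \lambda_{\max}(H).
```

Any 0 < η < 2/β makes the loss non-increasing, and strictly decreasing until the gradient vanishes; for η above 2/β the component of θ along the top eigenvector of H can grow, so the loss may blow up.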

Common pitfalls

  • Building a fresh optimizer inside the loop. That re-initializes the optax state every step, so momentum buffers (and any other optimizer state, such as Adam’s moment estimates) are reset. Build once, reuse.
  • Capturing model from an outer scope inside loss_fn. This obscures the differentiation target: nnx.value_and_grad differentiates with respect to its first positional argument by default. Pass the model as the first arg of loss_fn so nnx.value_and_grad can find it.
  • Forgetting wrt=nnx.Param. Since Flax 0.11, nnx.Optimizer requires wrt to specify which variables it updates.
  • Returning (model, optimizer, loss) from train_step. A habit from the Linen pattern. In nnx this is unnecessary: the caller already holds the same model and optimizer objects, and the step mutated them in place.

Problem

Implement train_step_loss(seed, x, y, lr, num_steps):

  1. Build model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(int(seed))).
  2. Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Define loss_fn(model, x, y) -> mean((model(x) - y)**2).
  4. Define train_step(model, optimizer, x, y) that calls nnx.value_and_grad(loss_fn), then optimizer.update(model, grads), and returns the loss.
  5. Loop int(num_steps) times calling train_step on the same (x, y). Record each step’s loss.
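  6. Return jnp.array([loss_step_0, loss_step_final]), a 1-D array of shape (2,). loss_step_0 is the loss returned by the first call (computed before any update) and loss_step_final the loss returned by the last call.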

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D_in).
  • y: 2-D (N, D_out).
  • lr: float.
  • num_steps: float (cast to int).

Output: 1-D array of shape (2,): [loss_step_0, loss_step_final].
