
NNX Eval Step

Why this matters

Training and evaluation use the same model. The only difference is that during evaluation you do NOT update parameters: you run the forward pass and compute metrics.

In Linen this distinction is awkward. The TrainState is the training container (params plus optimizer state), but at eval time you only need its params. People often write a separate eval_step(state, ...) that pulls out state.params and calls apply on them manually.

In nnx the model object is the same in both regimes. Eval is just the forward pass plus a metric — no optimizer involvement, no state threading.

The recipe

import jax.numpy as jnp
import optax
from flax import nnx

def loss_fn(model, x, y):
    # MSE between the model's predictions and the targets.
    return jnp.mean((model(x) - y) ** 2)

# Train phase: standard.
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)

for _ in range(num_steps):
    _, grads = nnx.value_and_grad(loss_fn)(model, x_train, y_train)
    optimizer.update(model, grads)  # writes new values into model's Params

# Eval phase: just call the model. The optimizer plays no role.
pred_eval = model(x_eval)
mse = jnp.mean((pred_eval - y_eval) ** 2)
mae = jnp.mean(jnp.abs(pred_eval - y_eval))

It is the same model object in both phases. After training, model.kernel.value and model.bias.value hold whatever values the optimizer wrote there.

What about train/eval mode?

Some layers (Dropout, BatchNorm) behave differently in train and eval. In nnx this is controlled by a flag, deterministic for Dropout and use_running_average for BatchNorm. You can pass the flag at call time or set it on the whole model with model.eval() / model.train(). A plain nnx.Linear has no train/eval distinction because its forward pass is the same in both, so this problem needs no flag. We'll cover the train/eval flag in a later problem.

Common pitfalls

  • Building the optimizer at eval time. It's pointless and adds noise to the code; the optimizer is only needed for optimizer.update during training.
  • Calling nnx.value_and_grad during eval. That computes gradients you never use. Just call model(x_eval) directly.
  • Forgetting that the SAME model has been mutated. After training, the model’s params have moved. Calling model(x_eval) uses those updated params — that’s the entire point.
  • Using a different model for eval. Don’t. The params live on the training model object; build a new one and you’ve lost the training.

Problem

Implement eval_metrics(seed, x_train, y_train, x_eval, y_eval, lr):

  1. Build model = nnx.Linear(x_train.shape[-1], y_train.shape[-1], rngs=nnx.Rngs(int(seed))).
  2. Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  3. Train for 3 SGD steps on (x_train, y_train) with MSE loss.
  4. After training, compute on the eval set:
    • pred_eval = model(x_eval)
    • mse = jnp.mean((pred_eval - y_eval) ** 2)
    • mae = jnp.mean(jnp.abs(pred_eval - y_eval))
  5. Return jnp.array([float(mse), float(mae)]), a 1-D array of shape (2,).

Inputs:

  • seed: float (cast to int).
  • x_train: 2-D (N_train, D_in).
  • y_train: 2-D (N_train, D_out).
  • x_eval: 2-D (N_eval, D_in).
  • y_eval: 2-D (N_eval, D_out).
  • lr: float.

Output: 1-D array of shape (2,): [mse_eval, mae_eval].

Hints

flax nnx training eval
