NNX Eval Step
Why this matters
Training and evaluation use the same model. The only difference is that during evaluation you do not update parameters: you just run the forward pass and compute metrics.
In Linen this distinction is awkward because the state object is the
training container, and at eval time you only need its params subset.
People often write a separate eval_step(state, ...) that pulls
state.params out and calls apply manually.
In nnx the model object is the same in both regimes. Eval is just the forward pass plus a metric — no optimizer involvement, no state threading.
The recipe
import jax.numpy as jnp
import optax
from flax import nnx

def loss_fn(model, x, y):
    # MSE between the model's predictions and the targets.
    return jnp.mean((model(x) - y) ** 2)

# Train phase: standard.
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)
for _ in range(num_steps):
    _, grads = nnx.value_and_grad(loss_fn)(model, x_train, y_train)
    optimizer.update(model, grads)

# Eval phase: just call the model. The optimizer plays no role.
pred_eval = model(x_eval)
mse = jnp.mean((pred_eval - y_eval) ** 2)
mae = jnp.mean(jnp.abs(pred_eval - y_eval))
The model object is identical. After training, its .kernel.value
and .bias.value hold whatever the optimizer left there.
What about train/eval mode?
Some layers (Dropout, BatchNorm) behave differently in train vs eval.
nnx handles this via a flag passed at call time (or via model.eval()
/ model.train() in newer APIs). For a plain nnx.Linear there is
no train/eval distinction — the forward is the same — so for this
problem you don’t need a flag. We’ll cover the train/eval flag in a
later problem.
Common pitfalls
- Building the optimizer at eval time. It’s pointless and adds noise to the code. The optimizer is only needed for update.
- Calling nnx.value_and_grad during eval. That builds the backward graph for nothing. Just call model(x_eval) directly.
- Forgetting that the SAME model has been mutated. After training, the model’s params have moved. Calling model(x_eval) uses those updated params; that’s the entire point.
- Using a different model for eval. Don’t. The params live on the training model object; build a new one and you’ve lost the training.
Problem
Implement eval_metrics(seed, x_train, y_train, x_eval, y_eval, lr):
- Build model = nnx.Linear(x_train.shape[-1], y_train.shape[-1], rngs=nnx.Rngs(int(seed))).
- Build optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Train for 3 SGD steps on (x_train, y_train) with MSE loss.
- After training, compute on the eval set:
  - pred_eval = model(x_eval)
  - mse = jnp.mean((pred_eval - y_eval) ** 2)
  - mae = jnp.mean(jnp.abs(pred_eval - y_eval))
- Return jnp.array([float(mse), float(mae)]) as 1-D (2,).
Inputs:
- seed: float (cast to int).
- x_train: 2-D (N_train, D_in).
- y_train: 2-D (N_train, D_out).
- x_eval: 2-D (N_eval, D_in).
- y_eval: 2-D (N_eval, D_out).
- lr: float.
Output: 1-D (2,) — [mse_eval, mae_eval].