medium
primitives
eval_step — Forward + Metrics
Why this matters
train_step and eval_step are siblings, not opposites. Both call the
same model on a batch. The difference is what comes after:
- train_step computes a loss, takes its gradient, and updates the optimizer state.
- eval_step computes the loss (or any metric) and stops there. No gradient. No parameter update. No optimizer state at all.
Eval is a forward pass plus arithmetic. That’s it.
Why it’s a separate function
You COULD just inline the eval logic at the call site. People separate it because:
- Determinism: eval should not include dropout or BatchNorm-style running-stat updates. Wrapping it gives you one place to enforce that (e.g. apply(..., training=False) or mutable=False).
- JIT-ability: like train_step, you’ll typically jax.jit(eval_step) so the inner work runs at XLA speed.
- Reuse: dev/test/holdout loops all call the same function.
The recipe
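import jax.numpy as jnp

def eval_step(state, batch):
    x, y = batch
    # Forward pass only: no gradient, no optimizer update.
    preds = state.apply_fn({"params": state.params}, x).reshape(-1)
    mse = jnp.mean((preds - y) ** 2)   # mean squared error
    mae = jnp.mean(jnp.abs(preds - y)) # mean absolute error
    return mse, mae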
Notice what’s NOT here:
- No value_and_grad.
- No apply_gradients.
- No optimizer reference.
- No state mutation.
eval_step returns metrics; the caller logs them. The caller does NOT
rebind state from the result of eval — eval is read-only.
Common pitfalls
- Returning a scalar when a tuple was expected: get used to returning a small dict or tuple of named metrics; you’ll always want more than just MSE eventually.
- Forgetting to set the model to eval mode if the model has dropout or batchnorm; here, plain Dense has neither so we don’t worry.
- Updating state.step in eval: don’t. Eval should never move the step counter. (apply_gradients is the only thing that does.)
Problem
- Build a TrainState with a tiny nn.Dense(1) and optax.sgd(0.05).
- Run 3 training steps on (x, y) to drift away from initialization.
- Define eval_step(state, batch) → (mse, mae).
- Call it once on (x, y) and return [mse, mae] as a 1-D (2,) array.
lr is fixed at 0.05 so we can compare runs.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D).
- y: 1-D (N,).
Output: 1-D (2,) array [mse, mae].
flax
training
eval-step