
eval_step — Forward + Metrics

Why this matters

train_step and eval_step are siblings, not opposites. Both call the same model on a batch. The difference is what comes after:

  • train_step computes a loss, takes its gradient, and updates the optimizer state.
  • eval_step computes the loss (or any metric) and stops there. No gradient. No parameter update. No optimizer state at all.

Eval is a forward pass plus arithmetic. That’s it.

Why it’s a separate function

You COULD just inline the eval logic at the call site. People separate it because:

  1. Determinism: eval should not apply dropout or update BatchNorm-style running statistics. Wrapping it gives you one place to enforce that (e.g. passing a flag such as training=False through apply(...), and leaving mutable=False so no variable collection is updated).
  2. JIT-ability: like train_step, you’ll typically jax.jit(eval_step) so the forward pass and the metric arithmetic are compiled together by XLA.
  3. Reuse: dev/test/holdout loops all call the same function.

The recipe

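import jax.numpy as jnp


def eval_step(state, batch):
    x, y = batch
    # Forward pass only. A Dense(1) head yields (N, 1); flatten to (N,)
    # so that preds - y does not broadcast to (N, N).
    preds = state.apply_fn({"params": state.params}, x).reshape(-1)
    mse = jnp.mean((preds - y) ** 2)    # mean squared error
    mae = jnp.mean(jnp.abs(preds - y))  # mean absolute error
    return mse, mae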

Notice what’s NOT here:

  • No value_and_grad.
  • No apply_gradients.
  • No optimizer reference.
  • No state mutation.

eval_step returns metrics; the caller logs them. The caller does NOT rebind state from the result of eval — eval is read-only.

Common pitfalls

  • Returning a scalar when a tuple was expected: get used to returning a small dict or tuple of named metrics; you’ll always want more than just MSE eventually.
  • Forgetting to run the model in eval mode when it has dropout or BatchNorm. Flax modules have no global eval switch; the mode is whatever flag you pass to apply. Here, plain Dense has neither layer, so there is nothing to set.
  • Updating state.step in eval: don’t. Eval should never move the step counter. (apply_gradients is the only thing that does.)

Problem

  1. Build a TrainState with a tiny nn.Dense(1) and optax.sgd(0.05).
  2. Run 3 training steps on (x, y) to drift away from initialization.
  3. Define eval_step(state, batch) → (mse, mae).
  4. Call it once on (x, y) and return [mse, mae] as a 1-D (2,) array.

lr is fixed at 0.05 so we can compare runs.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N,).

Output: 1-D (2,) array [mse, mae].

