train_step with value_and_grad

Why this matters

The previous problem (flax-train-state) introduced TrainState. This one is the next inevitable step: wrap one optimization update in a function called train_step(state, batch). That’s the unit you’ll eventually jax.jit and call inside a Python for loop over batches.

Mastering this short function is most of what it takes to train models in JAX.

The pattern

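import jax
import jax.numpy as jnp

def train_step(state, batch):
    def loss_fn(params, batch):
        x, y = batch
        # Dense(1) returns (N, 1); flatten to (N,) to match y
        preds = state.apply_fn({"params": params}, x).reshape(-1)
        return jnp.mean((preds - y) ** 2)

    # loss and grads are both evaluated at state.params (pre-update)
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    # optimizer step: returns a new state with updated params, opt_state, step
    state = state.apply_gradients(grads=grads)
    return state, loss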

Note: value_and_grad returns the loss at the params you pass in, i.e. before the update. That is expected: the loss and the gradient are evaluated at the same point. If you want the post-update loss for logging, run another forward pass with the new params.
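To make the pre-update vs post-update distinction concrete, here is a minimal pure-JAX sketch with no Flax. It assumes a hypothetical linear model `predict` with a plain dict of params, and a hand-written SGD step standing in for state.apply_gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model: params is a plain dict, not a Flax TrainState.
def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((predict(params, x) - y) ** 2)

x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}
lr = 0.1

# Loss and grads are both evaluated at the *current* params.
loss_before, grads = jax.value_and_grad(loss_fn)(params, (x, y))

# Plain SGD update, standing in for state.apply_gradients(grads=grads).
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# The post-update loss needs a second forward pass with the new params.
loss_after = loss_fn(new_params, (x, y))
print(loss_before, loss_after)
```

The value returned by value_and_grad always matches loss_fn(params, batch) for the params you passed in, never the updated ones.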

Three primitives glued together:

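  1. jax.value_and_grad(f): returns a function that computes both f(x) and grad f(x) in one call. The forward pass is shared, so it is cheaper than calling f and jax.grad(f) separately.
  2. state.apply_gradients(grads=grads): runs the optimizer (state.tx) on the gradients, applies the resulting updates to the params, and returns a new TrainState with updated params, opt_state, and step.
  3. Return both the new state AND the loss: you need the loss for logging.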

Why value_and_grad and not grad?

jax.grad(f) returns ONLY the gradient. To also get the loss for logging, you'd have to call f(x) again, which is a wasted forward pass. jax.value_and_grad(f) returns (f(x), grad f(x)) from a single forward and backward pass.
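A quick check that the two approaches agree numerically, using a hypothetical scalar function f; only the number of forward passes differs.

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) ** 2)

w = jnp.array([0.1, 0.5, 1.0])

# Two separate calls: jax.grad runs its own forward pass, then f runs again.
value_sep = f(w)
grads_sep = jax.grad(f)(w)

# One call: the forward pass is shared between the value and the gradient.
value, grads = jax.value_and_grad(f)(w)
print(value, grads)
```

Since d/dw sin(w)^2 = sin(2w), both gradient results should match jnp.sin(2 * w).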

Why a closure on state.apply_fn?

loss_fn doesn't take apply_fn as an argument because it reads state.apply_fn from the enclosing scope. That is fine: apply_fn is a plain Python callable (TrainState declares it as a static, non-pytree field), so JAX treats it as a static Python value and never traces through it as a leaf. Captured arrays also work, but they get "baked in" to the trace as constants, so anything that changes between calls should be passed as an argument instead.
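The "baked in" behaviour is easiest to see with jax.jit. In this sketch a hypothetical global array `scale` is captured by a jitted function; rebinding the global afterwards does not trigger a retrace, so the compiled code keeps using the value seen at trace time.

```python
import jax
import jax.numpy as jnp

scale = jnp.array(2.0)

@jax.jit
def scaled(x):
    # `scale` is not an argument: it is read from the enclosing scope at
    # trace time and embedded in the compiled computation as a constant.
    return x * scale

first = scaled(jnp.array(1.0))

scale = jnp.array(10.0)           # rebinding the global does NOT retrace
second = scaled(jnp.array(1.0))   # cache hit: still multiplies by 2.0
print(first, second)
```

This is why captured Python functions such as apply_fn are harmless, while values that should change between steps (params, batches, learning-rate schedules) belong in the function's arguments.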

Common pitfalls

  • Differentiating w.r.t. the wrong arg: jax.value_and_grad(loss_fn) defaults to argnums=0, so the FIRST positional arg is the one you differentiate with respect to. Pass params first.
  • Forgetting .reshape(-1) when Dense(1) returns (N, 1) and y is (N,): silent broadcasting bug; (N, 1) - (N,) → (N, N). Always flatten or expand-dims explicitly.
  • Returning grads instead of loss: an easy slip when adapting tutorial code. By convention the function returns (state, loss), in that order.
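The first two pitfalls can be reproduced in a few lines. The arrays and the loss_swapped function below are hypothetical, chosen so that the broadcasting bug reports a nonzero loss for perfect predictions and the argnums mistake silently returns gradients for the batch instead of the weight.

```python
import jax
import jax.numpy as jnp

# Pitfall: (N, 1) - (N,) broadcasts to (N, N).
preds = jnp.array([[1.0], [2.0], [3.0]])   # shape (3, 1), like Dense(1) output
y = jnp.array([1.0, 2.0, 3.0])             # shape (3,)

bad_residuals = preds - y                  # (3, 3): every pred vs every target
good_residuals = preds.reshape(-1) - y     # (3,): elementwise, as intended
bad_loss = jnp.mean(bad_residuals ** 2)    # nonzero despite perfect predictions
good_loss = jnp.mean(good_residuals ** 2)  # 0.0

# Pitfall: argnums defaults to 0, i.e. the FIRST positional argument.
def loss_swapped(batch, w):                # hypothetical: params placed second
    x, t = batch
    return jnp.mean((x * w - t) ** 2)

batch = (jnp.array([1.0, 2.0]), jnp.array([1.0, 1.0]))
w = jnp.array(2.0)

wrong = jax.grad(loss_swapped)(batch, w)             # grads w.r.t. batch: (dx, dt)
right = jax.grad(loss_swapped, argnums=1)(batch, w)  # grad w.r.t. w: a scalar
print(bad_loss, good_loss, right)
```

Neither mistake raises an error, which is what makes them easy to miss; checking shapes and argument order is the practical defence.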

Problem

Build a TrainState exactly as in pos 59, define loss_fn(params, batch) and train_step(state, batch) following the recipe above, call train_step once on the given (x, y), and return the loss as a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N,).
  • lr: float.

Output: 1-D (1,) array containing [loss].
