train_step with value_and_grad
Why this matters
The previous problem (flax-train-state) introduced TrainState. This
one is the next inevitable step: wrap one optimization update in a
function called train_step(state, batch). That’s the unit you’ll
eventually jax.jit and call inside a Python for loop over batches.
Mastering this 6-line function is mastering training in JAX.
The pattern
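```python
import jax
import jax.numpy as jnp

def train_step(state, batch):
    # Nested so it closes over this step's state.apply_fn.
    def loss_fn(params, batch):
        x, y = batch
        # Dense(1) outputs (N, 1); flatten to (N,) to match y.
        preds = state.apply_fn({"params": params}, x).reshape(-1)
        return jnp.mean((preds - y) ** 2)

    # Loss and grads are both evaluated at the current (pre-update) params.
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss
```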
Note: value_and_grad returns the loss at the input params, i.e. before
the update. That is by design: the loss and the gradient are evaluated
at the same point, and the update only happens afterwards. If you want
the post-update loss for logging, do another forward pass.
Three pieces glued together:
- jax.value_and_grad(f): returns a function that computes both f(x) and grad f(x) in one call, sharing a single forward pass instead of running it twice.
- state.apply_gradients(grads=grads): runs the optimizer update, applies it to the parameters, and increments state.step.
- Return both the new state AND the loss: you need the loss for logs.
Why value_and_grad and not grad?
jax.grad(f) ONLY returns the gradient. To also get the loss for
logging, you'd have to call f(x) again, which is a wasted forward pass.
jax.value_and_grad(f) returns (f(x), grad f(x)) from a single
forward+backward.
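A small pure-JAX check of that claim. The toy function f is an assumption chosen so the numbers are easy to verify by hand: f(w) is the sum of w², so its gradient is 2w.

```python
import jax
import jax.numpy as jnp


def f(w):
    # Toy scalar loss: sum of squares.
    return jnp.sum(w ** 2)


w = jnp.array([1.0, -2.0, 3.0])

# One call: value and gradient together (value == 14.0, grad_w == [2., -4., 6.]).
value, grad_w = jax.value_and_grad(f)(w)

# Two calls: jax.grad alone drops f(w), so logging needs an extra forward pass.
value_again = f(w)
grad_only = jax.grad(f)(w)
```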
Why a closure on state.apply_fn?
loss_fn doesn't take apply_fn as an argument because it's captured
from the enclosing scope. In TrainState, apply_fn is a static
(non-pytree) field, so reading state.apply_fn inside loss_fn gives a
plain Python callable that JAX does not trace as a leaf. (Captured
Python functions are fine; captured arrays are also fine, but they get
"baked in" to the trace as constants.)
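A pure-JAX sketch of the "baked in" behavior, using a hypothetical module-level array scale and a toy function scaled: under jax.jit, a captured array is read once at trace time, so rebinding it later does not change the compiled result.

```python
import jax
import jax.numpy as jnp

scale = jnp.array(2.0)  # hypothetical value captured from the enclosing scope


@jax.jit
def scaled(x):
    # `scale` is not an argument, so jit reads it once at trace time
    # and embeds its value in the compiled computation as a constant.
    return scale * x


first = scaled(jnp.array(1.0))   # traces and compiles with scale = 2.0

scale = jnp.array(3.0)           # rebinding the global does NOT retrace...
second = scaled(jnp.array(1.0))  # ...so the cached computation still uses 2.0
```

Values that should change between calls, such as params, belong in the argument list. jax.value_and_grad likewise differentiates only with respect to its argnums and treats captured values as constants.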
Common pitfalls
- Differentiating w.r.t. the wrong arg: jax.value_and_grad(loss_fn) defaults to argnums=0, so the FIRST positional arg is the one you differentiate against. Pass params first.
- Forgetting .reshape(-1) when Dense(1) returns (N, 1) and y is (N,): this is a silent broadcasting bug, because (N, 1) - (N,) → (N, N). Always flatten or expand-dims explicitly.
- Returning grads instead of loss: an easy slip in tutorials. The function returns (state, loss), in that order, by convention.
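A pure-JAX sketch of the first two pitfalls, using hypothetical toy shapes and values. With the broadcasting bug, even perfect predictions report a nonzero loss; and jax.grad's default argnums=0 differentiates with respect to the first positional argument.

```python
import jax
import jax.numpy as jnp

N = 4
y = jnp.arange(N, dtype=jnp.float32)  # targets, shape (N,)
preds_2d = y.reshape(-1, 1)           # perfect predictions, shape (N, 1) like Dense(1)

# Pitfall: (N, 1) - (N,) broadcasts to (N, N), comparing every pred with every target.
bad_diff = preds_2d - y
bad_mse = jnp.mean(bad_diff ** 2)     # 2.5, although predictions are exact

# Fix: flatten predictions to (N,) first.
good_mse = jnp.mean((preds_2d.reshape(-1) - y) ** 2)  # 0.0


def loss_fn(params, batch):
    # Toy 1-parameter model: preds = params * x.
    x, t = batch
    return jnp.mean((params * x - t) ** 2)


params = jnp.array(2.0)
batch = (jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0]))
# Default argnums=0: gradient w.r.t. params, the first positional arg.
g = jax.grad(loss_fn)(params, batch)  # 5.0
```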
Problem
Build a TrainState exactly as in pos 59, define loss_fn(params, batch)
and train_step(state, batch) per the recipe above, call it once on the
given (x, y), and return the loss as a 1-D (1,) array.
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D).
- y: 1-D, shape (N,).
- lr: float.
Output: 1-D, shape (1,), i.e. [loss].