TrainState — One Step
Why this matters
A real Flax training loop has at least four moving parts: the model’s
apply function, the parameters, the optimizer, and the optimizer’s
own state (momentum buffers, step counter, etc). Threading all four
through every function by hand is verbose and error-prone.
flax.training.train_state.TrainState is the canonical container that
bundles them together. One object goes in, one (updated) object comes
out. Almost every Flax example you read in the wild uses it.
What TrainState gives you
TrainState.create(apply_fn, params, tx) returns a frozen container with
fields:
- apply_fn: the model’s apply (so you can call state.apply_fn({"params": p}, x)).
- params: the current parameter pytree.
- tx: the Optax GradientTransformation.
- opt_state: the Optax internal state (created automatically from tx.init(params)).
- step: integer, starts at 0; auto-incremented by apply_gradients.
The key method is state.apply_gradients(grads=grads), which:
- Calls tx.update(grads, state.opt_state, state.params) to compute parameter updates.
- Applies them with optax.apply_updates(state.params, updates).
- Replaces state.opt_state with the new opt_state.
- Increments state.step by 1.
- Returns a new TrainState (it’s immutable; the old one is unchanged).
The recipe
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# TinyDense, seed, x, y, and lr are assumed to be defined elsewhere.
model = TinyDense()
rng = jax.random.PRNGKey(seed)
params = model.init(rng, x)["params"]
tx = optax.sgd(lr)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    # Mean squared error between flattened predictions and targets.
    preds = state.apply_fn({"params": params}, x).reshape(-1)
    return jnp.mean((preds - y) ** 2)

grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)  # step is now 1
loss_after = loss_fn(state.params)
Common pitfalls
- state.apply_gradients(grads=...) is keyword-only: apply_gradients(grads) raises.
- state.params is read-only; you can’t mutate it. Re-bind state = state.apply_gradients(...).
- The grad must be a pytree with the same structure as state.params. If you compute grads w.r.t. some other variable, the optimizer will complain.
- step starts at 0 and increments AFTER apply_gradients, so after one call it’s 1.
Problem
Build a TrainState around an nn.Dense(1) model and optax.sgd(lr).
Run one training step on (x, y) using MSE loss, then return:
[float(state.step), float(loss_after_first_step)]
as a 1-D (2,) array. After one apply_gradients call, state.step == 1.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D) features.
- y: 1-D (N,) targets.
- lr: float, the SGD learning rate.
Output: 1-D (2,) array [step_after, loss_after].