medium primitives

TrainState — One Step

Why this matters

A real Flax training loop has at least four moving parts: the model’s apply function, the parameters, the optimizer, and the optimizer’s own state (momentum buffers, step counter, etc.). Threading all four through every function by hand is verbose and error-prone.

flax.training.train_state.TrainState is the canonical container that bundles them together. One object goes in, one (updated) object comes out. Almost every Flax example you read in the wild uses it.

What TrainState gives you

TrainState.create(apply_fn=..., params=..., tx=...) takes keyword arguments only and returns a frozen (immutable) container with these fields:

  • apply_fn: the model’s apply (so you can call state.apply_fn({"params": p}, x)).
  • params: the current parameter pytree.
  • tx: the Optax GradientTransformation.
  • opt_state: the Optax internal state (created automatically from tx.init(params)).
  • step: integer, starts at 0; auto-incremented by apply_gradients.

The key method is state.apply_gradients(grads=grads), which:

  1. Calls tx.update(grads, state.opt_state, state.params) to compute parameter updates.
  2. Applies them with optax.apply_updates(state.params, updates).
  3. Replaces state.opt_state with the new opt_state.
  4. Increments state.step by 1.
  5. Returns a new TrainState (it’s immutable; the old one is unchanged).

The recipe

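import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class TinyDense(nn.Module):
    # Assumed definition: a single Dense(1) layer, (N, D) -> (N, 1)
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

# seed, x, y, lr are the problem inputs described below
model = TinyDense()
rng = jax.random.PRNGKey(int(seed))      # seed arrives as a float
params = model.init(rng, x)["params"]    # init uses x only to infer shapes
tx = optax.sgd(lr)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params):
    # MSE between flattened (N,) predictions and (N,) targets
    preds = state.apply_fn({"params": params}, x).reshape(-1)
    return jnp.mean((preds - y) ** 2)

grads = jax.grad(loss_fn)(state.params)      # same pytree structure as state.params
state = state.apply_gradients(grads=grads)   # step is now 1
loss_after = loss_fn(state.params)           # loss under the updated params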

Common pitfalls

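  • apply_gradients takes grads as a keyword-only argument: call state.apply_gradients(grads=grads); state.apply_gradients(grads) raises a TypeError.
  • TrainState is a frozen dataclass, so assigning state.params = ... raises an error. Re-bind the whole object instead: state = state.apply_gradients(...) (or state.replace(params=...) for a manual change).
  • grads must be a pytree with the same structure as state.params; jax.grad(loss_fn)(state.params) gives exactly that. If you differentiate with respect to something else (for example the full {"params": ...} variables dict), the update typically fails with a pytree-structure mismatch.
  • step starts at 0 and each apply_gradients call returns a state with step + 1, so after one call it is 1.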

Problem

Build a TrainState around an nn.Dense(1) model and optax.sgd(lr). Run one training step on (x, y) using MSE loss, evaluate the MSE again with the updated parameters, then return:

[float(state.step), float(loss_after_first_step)]

as a 1-D (2,) array. After one apply_gradients call, state.step == 1.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D) features.
  • y: 1-D (N,) targets.
  • lr: float — SGD learning rate.

Output: 1-D (2,) array [step_after, loss_after].

Hints

flax training train-state
