NNX Train Step
Why this matters
A “training step” is the smallest reusable unit of training: forward,
loss, backward, optimizer update. Once you have it, the rest of the
loop is just: for batch in data: train_step(model, optimizer, *batch).
In Linen the canonical step looks like:
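```python
import jax
import jax.numpy as jnp

def train_step(state, x, y):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, x)
        return jnp.mean((preds - y) ** 2)
    # value_and_grad also returns the loss, so it can be reported
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)  # NEW state object
    return state, loss
```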
Notice the state round-trip: you must rebind `state` on every step, because `TrainState` is a frozen pytree and `apply_gradients` returns a new object instead of modifying the old one. The function returns the updated state along with the loss.
In nnx the same step is structurally simpler:
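```python
from flax import nnx

def train_step(model, optimizer, x, y):
    # loss_fn(model, x, y) is defined in the recipe below
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)  # updates the model's params in place
    return loss
```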
There is nothing to thread. `optimizer.update(model, grads)` writes the new parameter values into `model` in place, so the next call sees the post-step params. The function only needs to return the loss for logging.
Why this matters in practice
The Linen pattern leaks one detail into every loop: `state, loss = train_step(state, x, y)`. Forgetting to rebind `state` is a classic bug: training silently fails to learn, because each new state is discarded and every step starts again from the old parameters.
The nnx pattern can’t have that bug. There is no “new state” to drop on the floor.
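The bug does not depend on Flax at all. A dependency-free sketch with a hypothetical `ToyState` frozen dataclass (a stand-in for `TrainState`, not its real implementation) shows the difference between rebinding the returned state and dropping it. The loss is a made-up one-parameter quadratic, `(w - target) ** 2`.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a frozen TrainState holding one scalar parameter."""
    w: float
    step: int = 0

    def apply_gradients(self, grad, lr=0.1):
        # Like TrainState.apply_gradients: return a NEW state, never mutate self.
        return dataclasses.replace(self, w=self.w - lr * grad, step=self.step + 1)

def train_step(state, target):
    loss = (state.w - target) ** 2
    grad = 2.0 * (state.w - target)  # d/dw of (w - target) ** 2
    return state.apply_gradients(grad), loss

# Correct: rebind on every step.
good = ToyState(w=0.0)
for _ in range(10):
    good, loss = train_step(good, target=3.0)

# Bug: the returned state is dropped, so every step starts from w = 0.0 again.
bad = ToyState(w=0.0)
for _ in range(10):
    _, loss = train_step(bad, target=3.0)
```

After the loops, `good` has taken 10 steps toward `w = 3.0`, while `bad` still has `w == 0.0` and `step == 0`.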
The recipe
```python
import jax.numpy as jnp
import optax
from flax import nnx

# in_features, out_features, seed, lr, num_steps, x and y are placeholders you supply
model = nnx.Linear(in_features, out_features, rngs=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

losses = []
for _ in range(num_steps):
    losses.append(float(train_step(model, optimizer, x, y)))
```
For this problem we run the same `(x, y)` batch over and over, which makes the loop full-batch gradient descent on a convex quadratic loss; with a sufficiently small learning rate the loss decreases monotonically.
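How small is "sufficiently small" can be checked numerically. For gradient descent on a quadratic, the loss drops at every step when the learning rate is below 2 / λ_max of the loss Hessian. The sketch below uses plain NumPy gradient descent (not nnx) on a single-output linear model with a bias column, with made-up data and the hypothetical helpers `gd_losses` and `stability_limit`.

```python
import numpy as np

def add_bias_column(X):
    # Append a column of ones so the last weight acts as a bias, like nnx.Linear's bias.
    return np.hstack([X, np.ones((X.shape[0], 1))])

def gd_losses(X, y, lr, num_steps):
    """Full-batch gradient descent on mean((X_aug @ w - y) ** 2), starting from w = 0.

    Records the loss before each update, as train_step does.
    """
    X_aug = add_bias_column(X)
    n = X_aug.shape[0]
    w = np.zeros(X_aug.shape[1])
    losses = []
    for _ in range(num_steps):
        r = X_aug @ w - y                 # residuals, shape (n,)
        losses.append(float(np.mean(r ** 2)))
        grad = (2.0 / n) * (X_aug.T @ r)  # gradient of the mean squared error
        w = w - lr * grad
    return losses

def stability_limit(X):
    # Loss Hessian is H = (2/n) X_aug^T X_aug; GD on a quadratic is monotone for lr < 2 / lambda_max(H).
    X_aug = add_bias_column(X)
    H = (2.0 / X_aug.shape[0]) * (X_aug.T @ X_aug)
    return 2.0 / np.linalg.eigvalsh(H).max()

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
y = X @ rng.normal(size=4) + 0.5  # synthetic targets with a true bias of 0.5

limit = stability_limit(X)
small = gd_losses(X, y, lr=0.25 * limit, num_steps=50)  # below the limit: loss falls every step
large = gd_losses(X, y, lr=1.25 * limit, num_steps=50)  # above it: the top eigen-direction blows up
```

Above the limit, the error component along the Hessian's top eigenvector is multiplied by 1 - 2.5 = -1.5 on every step, so the loss eventually grows without bound.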
Common pitfalls
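- Building a fresh `optimizer` inside the loop. Each new optimizer starts from freshly initialized optax state, so momentum buffers (and any other accumulated statistics) reset on every step. Build it once and reuse it.
- Capturing `model` from an outer scope inside `loss_fn`. Even if it runs, it obscures the differentiation target: `nnx.value_and_grad` differentiates with respect to the first argument by default. Pass the model as the first argument of `loss_fn` so `nnx.value_and_grad` can find its parameters.
- Forgetting `wrt=nnx.Param`. It is required since Flax 0.11.
- Returning `(model, optimizer, loss)` from `train_step`. This is a habit carried over from the Linen pattern. In nnx it is unnecessary: the update happens in place, so the caller's `model` and `optimizer` already refer to the updated objects.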
Problem
Implement `train_step_loss(seed, x, y, lr, num_steps)`:
- Build `model = nnx.Linear(x.shape[-1], y.shape[-1], rngs=nnx.Rngs(int(seed)))`.
- Build `optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param)`.
- Define `loss_fn(model, x, y)` returning `mean((model(x) - y) ** 2)`.
- Define `train_step(model, optimizer, x, y)` that calls `nnx.value_and_grad(loss_fn)`, then `optimizer.update`, and returns the loss.
- Loop `int(num_steps)` times, calling `train_step` on the same `(x, y)`. Record each step's loss.
- Return `jnp.array([loss_step_0, loss_step_final])` as a 1-D array of shape `(2,)`. Because each loss is evaluated before that step's update, `loss_step_0` is the loss at the initial parameters.
Inputs:
- `seed`: float (cast to int).
- `x`: 2-D, shape `(N, D_in)`.
- `y`: 2-D, shape `(N, D_out)`.
- `lr`: float.
- `num_steps`: float (cast to int).

Output: 1-D, shape `(2,)`: `[loss_step_0, loss_step_final]`.