hard
primitives
Full Training Step (Loss + Grad + Update)
Why this matters
The canonical Optax training step glues four primitives together: define a loss function, differentiate it, ask the optimizer for parameter updates, and apply those updates. Mastering the pattern once means you can extend it to any architecture.
The recipe
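import jax
import jax.numpy as jnp
import optax

def train_step(w, b, x, y):  # wrapper name is illustrative
    def loss_fn(w):
        # Sum of squared errors; b, x, y are closed over and stay fixed.
        return jnp.sum((w * x + b - y) ** 2)

    optimizer = optax.sgd(0.01)
    opt_state = optimizer.init(w)
    grads = jax.grad(loss_fn)(w)  # gradient at the old w
    updates, _ = optimizer.update(grads, opt_state, w)
    new_w = optax.apply_updates(w, updates)
    final_loss = loss_fn(new_w)  # loss at the updated w
    return jnp.array([new_w, final_loss])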
In production code you would return (new_params, new_opt_state, loss).
Here, for testability, pack [final_w, final_loss] into a 1-D array.
Common pitfalls
- The gradient is computed at the old w, not the updated one.
- The final loss is evaluated at the new w, after the update.
- optax.sgd(lr) uses no momentum by default (momentum=None); momentum is an optional extra argument.
- optimizer.update returns (updates, new_opt_state); always unpack both.
Inputs
- w: scalar, the initial weight.
- b: scalar, the bias (fixed; not updated).
- x: 1-D JAX array, the input features.
- y: 1-D JAX array, the targets.
Output
1-D array of shape (2,): [final_w, final_loss] after one SGD step with learning rate 0.01.