
Full Training Step (Loss + Grad + Update)

Why this matters

The canonical Optax training step combines four primitives: define a loss function, differentiate it with jax.grad, ask the optimizer for parameter updates, and apply those updates with optax.apply_updates. Once you know the pattern, moving to a larger model mostly means swapping in a different loss function and parameter pytree. The four steps stay the same.

The recipe

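import jax
import jax.numpy as jnp
import optax

def train_step(w, b, x, y):
    # Squared-error loss of the linear model w * x + b; b, x, y are closed over.
    def loss_fn(w):
        return jnp.sum((w * x + b - y) ** 2)

    optimizer = optax.sgd(0.01)            # plain SGD, learning rate 0.01
    opt_state = optimizer.init(w)          # optimizer state for the parameters
    grads = jax.grad(loss_fn)(w)           # gradient at the OLD w
    # Unpack both outputs; new_opt_state is unused here but a loop would carry it.
    updates, new_opt_state = optimizer.update(grads, opt_state, w)
    new_w = optax.apply_updates(w, updates)
    final_loss = loss_fn(new_w)            # loss at the NEW w
    return jnp.array([new_w, final_loss])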

In production code you would return (new_params, new_opt_state, loss). Here, for testability, the step packs the updated weight and the post-update loss into a single 1-D array: [final_w, final_loss].

Common pitfalls

  • The gradient is computed at the old w, not the updated one.
  • The final loss is evaluated at the new w after the update.
  • optax.sgd(lr) is plain SGD: momentum defaults to None (disabled), and you pass momentum=... to enable it.
  • optimizer.update returns (updates, new_opt_state). Always unpack both, and carry new_opt_state into the next step.

Inputs

  • w: scalar, the initial weight.
  • b: scalar, the bias (fixed; not updated).
  • x: 1-D JAX array of input features.
  • y: 1-D JAX array of targets.

Output

1-D array of shape (2,): [final_w, final_loss] after one SGD step with learning rate 0.01.
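To check the expected output by hand, use the closed-form gradient of the loss and work through one step. The inputs here are illustrative (w = 1, b = 0, x = [1, 2], y = [2, 4]) and are not the exercise's test data:

```latex
\begin{aligned}
L(w) &= \sum_i (w x_i + b - y_i)^2, \qquad
\frac{\partial L}{\partial w} = 2 \sum_i x_i \,(w x_i + b - y_i) \\
w' &= w - \eta \,\frac{\partial L}{\partial w}, \qquad \eta = 0.01 \\[4pt]
\text{residuals at } w = 1 &: \ (1 - 2,\ 2 - 4) = (-1,\ -2) \\
\frac{\partial L}{\partial w}\Big|_{w=1} &= 2\,\bigl(1 \cdot (-1) + 2 \cdot (-2)\bigr) = -10 \\
w' &= 1 - 0.01 \cdot (-10) = 1.1 \\
L(w') &= (1.1 - 2)^2 + (2.2 - 4)^2 = 0.81 + 3.24 = 4.05
\end{aligned}
```

For these inputs the step should return approximately [1.1, 4.05], up to float32 rounding.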

Hints

optax training full-step
