Optax SGD Step
Why this matters
Optax is JAX's standard optimizer library. Instead of making you hand-roll gradient update rules, Optax provides composable, stateless transformations that plug cleanly into any JAX training loop. SGD is the simplest entry point.
The Optax pattern
Every Optax optimizer follows the same three-step protocol:
optimizer = optax.sgd(lr)                  # 1. create the optimizer
opt_state = optimizer.init(params)         # 2. initialise state from params
updates, new_state = optimizer.update(grads, opt_state, params)  # 3. compute updates and the next state...
new_params = optax.apply_updates(params, updates)                #    ...then apply the updates to the params
Optax never mutates arrays; it returns brand-new arrays at every step.
The opt_state carries any running statistics (for SGD without momentum
it is essentially empty).
What SGD does
Plain SGD computes updates = -lr * grads, so:
new_params = params - lr * grads
For params = [1.0, 2.0, 3.0], grads = [0.1, 0.2, 0.3], lr = 0.5:
new_params = [1.0 - 0.05, 2.0 - 0.10, 3.0 - 0.15]
= [0.95, 1.90, 2.85]
Common pitfalls
- Forgetting optax.apply_updates: the optimizer returns updates, not new params directly.
- Calling optimizer.update(grads, opt_state) without params: some transforms (e.g. weight decay) need access to the current parameters.
- Re-using the same opt_state across steps: always use the new_state returned from optimizer.update.
Inputs
- params: 1-D array of current parameter values.
- grads: 1-D array of gradients, same shape as params.
- lr: scalar learning rate.
Output
1-D array: updated parameter values after one SGD step.