
Optax SGD Step

Why this matters

Optax is JAX's standard optimizer library. Rather than hand-rolling gradient update rules, you compose Optax gradient transformations. Each one is a pair of pure functions whose state is passed around explicitly, so they plug cleanly into any JAX training loop. SGD is the simplest entry point.

The Optax pattern

Every Optax optimizer follows the same three-step protocol:

import optax

optimizer = optax.sgd(lr)                                        # 1. create the optimizer
opt_state = optimizer.init(params)                               # 2. initialise state from params
updates, new_state = optimizer.update(grads, opt_state, params)  # 3. turn grads into updates...
new_params = optax.apply_updates(params, updates)                #    ...and add them to params

Optax never mutates arrays: every step returns new parameter arrays and a new optimizer state. The opt_state carries any running statistics, such as momentum buffers. For plain SGD (no momentum, constant learning rate) it is essentially empty.

What SGD does

Plain SGD computes updates = -lr * grads, so:

new_params = params - lr * grads

For params = [1.0, 2.0, 3.0], grads = [0.1, 0.2, 0.3], lr = 0.5:

new_params = [1.0 - 0.05, 2.0 - 0.10, 3.0 - 0.15]
           = [0.95, 1.90, 2.85]
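The hand calculation above can be reproduced elementwise in plain jax.numpy, with no Optax involved, to show both stages: scaling the gradients, then subtracting. The arrays are float32, so printed values may differ from the exact decimals in the last digit.

```python
import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.1, 0.2, 0.3])
lr = 0.5

step = lr * grads           # elementwise, ≈ [0.05, 0.10, 0.15]
new_params = params - step  # elementwise, ≈ [0.95, 1.90, 2.85]

print(step)
print(new_params)
```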

Common pitfalls

  • Forgetting optax.apply_updates: the optimizer returns updates, not new params directly.
  • Calling optimizer.update(grads, opt_state) without params: some transforms (e.g. weight decay via optax.add_decayed_weights) need access to the current parameters.
  • Re-using the same opt_state across steps: always pass the new_state returned by optimizer.update into the next step. With momentum or adaptive optimizers, stale state silently discards the accumulated statistics.

Inputs

  • params: 1-D array of current parameter values.
  • grads: 1-D array of gradients, same shape as params.
  • lr: scalar learning rate.

Output

1-D array: the updated parameter values after one SGD step.

Hints

Use optax.sgd together with optax.apply_updates.
