
Optax SGD with Momentum

Why this matters

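Momentum accelerates SGD by accumulating a velocity vector. Gradient components that keep the same sign from step to step build up speed, while components that flip sign partly cancel, which damps oscillations and speeds convergence along consistent directions. SGD with momentum is a common, strong baseline for training CNNs and many other architectures.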

How momentum works

The update rule maintains a velocity v across steps:

v ← momentum * v + grads
params ← params - lr * v
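A minimal sketch of the two-line update rule written directly in jax.numpy, without Optax. The helper name momentum_step and the constant gradient are illustrative assumptions, not part of any library. With a constant gradient of 0.5, lr=0.1 and momentum=0.9, the velocity grows 0.5, 0.95, 1.355 over three steps, heading toward grads / (1 - momentum) = 5.

```python
import jax.numpy as jnp

def momentum_step(params, velocity, grads, lr, momentum):
    """Hypothetical helper: one heavy-ball step.

    v <- momentum * v + grads
    params <- params - lr * v
    """
    velocity = momentum * velocity + grads
    params = params - lr * velocity
    return params, velocity

# Constant gradient (an assumption for illustration): the velocity keeps growing.
params = jnp.array([1.0, -2.0])
velocity = jnp.zeros_like(params)  # velocity starts at zero
grads = jnp.array([0.5, 0.5])
for step in range(3):
    params, velocity = momentum_step(params, velocity, grads, lr=0.1, momentum=0.9)
    print(step, velocity)
```

Because the same gradient is added every step, each step moves the parameters further than the last, which is the acceleration described above.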

Because velocity starts at zero, the first step is identical to plain SGD:

v = momentum * 0 + grads = grads
params_new = params - lr * grads

Optax usage

Pass momentum as a keyword argument to optax.sgd:

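import optax

optimizer = optax.sgd(learning_rate=lr, momentum=momentum)  # momentum trace, then scale by -lr
opt_state = optimizer.init(params)  # velocity initialised to zeros
updates, new_state = optimizer.update(grads, opt_state, params)  # updates = -lr * v
new_params = optax.apply_updates(params, updates)  # params + updates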

Common pitfalls

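  • momentum=0 reduces to plain SGD: the velocity is just the current gradient, so nothing accumulates. Leaving momentum unset (optax.sgd defaults to momentum=None) also gives plain SGD.
  • The velocity lives inside opt_state. It only carries over between steps if you pass the returned new_state into the next update call; re-initialising the state every step silently turns the optimizer back into plain SGD.
  • momentum=0.9 is the widely used “heavy-ball” value, but it is not a library default: optax.sgd applies no momentum unless you pass it explicitly.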

Inputs

  • params: 1-D array of current parameter values.
  • grads: 1-D array of gradients, same shape as params.
  • lr: scalar learning rate.
  • momentum: scalar momentum coefficient (e.g. 0.9).

Output

1-D array — updated parameter values after one momentum-SGD step.

Hints

Use optax.sgd with its momentum argument.
