easy
primitives
Optax SGD with Momentum
Why this matters
Momentum accelerates SGD by accumulating a velocity vector that smooths out oscillations and speeds convergence along consistent gradient directions. It is a common default for training CNNs and many other architectures.
How momentum works
The update rule maintains a velocity v across steps:
v ← momentum * v + grads
params ← params - lr * v
Because velocity starts at zero, the first step is identical to plain SGD:
v = 0 * 0 + grads = grads
params_new = params - lr * grads
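The update rule can be written out by hand to see how the velocity builds up. The sketch below is only an illustration: the helper name momentum_step and the example values are made up, and it follows the two-line rule above rather than any library internals.

```python
import jax.numpy as jnp


def momentum_step(params, velocity, grads, lr, momentum):
    """One hand-written momentum-SGD step (hypothetical helper)."""
    new_velocity = momentum * velocity + grads  # v <- momentum * v + grads
    new_params = params - lr * new_velocity     # params <- params - lr * v
    return new_params, new_velocity


params = jnp.array([1.0, -2.0])
grads = jnp.array([0.5, 0.5])        # made-up gradient, held constant for both steps
velocity = jnp.zeros_like(params)    # velocity starts at zero

# Step 1: velocity was zero, so this matches plain SGD.
params_1, velocity = momentum_step(params, velocity, grads, lr=0.1, momentum=0.9)

# Step 2: velocity is now 0.9 * grads + grads = 1.9 * grads, so the step is larger.
params_2, velocity = momentum_step(params_1, velocity, grads, lr=0.1, momentum=0.9)
```

With a constant gradient, the second step moves the parameters 1.9 times as far as the first, which is the acceleration along a consistent gradient direction described above.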
Optax usage
Pass momentum as a keyword argument to optax.sgd:
optimizer = optax.sgd(lr, momentum=momentum)
opt_state = optimizer.init(params)
updates, new_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
Common pitfalls
- momentum=0 produces plain SGD; there is no accumulation.
- The velocity lives inside opt_state; it does not persist unless you thread new_state between steps.
- momentum=0.9 is the standard “heavy-ball” value used in most deep learning frameworks.
Inputs
- params: 1-D array of current parameter values.
- grads: 1-D array of gradients, same shape as params.
- lr: scalar learning rate.
- momentum: scalar momentum coefficient (e.g. 0.9).
Output
1-D array — updated parameter values after one momentum-SGD step.
Hints
optax
sgd
momentum