Optax Adam Step
Why this matters
Adam (Adaptive Moment Estimation) is a common default optimizer in modern deep learning. It combines momentum (1st moment) with per-parameter adaptive learning rates (2nd moment), which typically makes it more robust to poorly scaled gradients than plain SGD.
The Adam update
Adam maintains two exponential moving averages, one of the gradients and one of their elementwise squares:
m ← β₁ * m + (1 - β₁) * grads # 1st moment (mean)
v ← β₂ * v + (1 - β₂) * grads² # 2nd moment (uncentered variance, elementwise square)
Bias-corrected estimates are then used to scale the update:
m̂ = m / (1 - β₁ᵗ)
v̂ = v / (1 - β₂ᵗ)
params ← params - lr * m̂ / (√v̂ + ε)
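The update rules above can be sketched directly in jax.numpy. This is a minimal illustration, not the library's implementation: the function name adam_step_manual and its argument layout are hypothetical, and t is assumed to be a 1-based step count with m and v initialised to zeros.

```python
import jax.numpy as jnp

# Hypothetical helper for illustration; not part of Optax.
def adam_step_manual(params, grads, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step. Returns (new_params, new_m, new_v); t is the 1-based step count."""
    m = b1 * m + (1.0 - b1) * grads        # 1st moment: moving average of grads
    v = b2 * v + (1.0 - b2) * grads ** 2   # 2nd moment: moving average of grads squared
    m_hat = m / (1.0 - b1 ** t)            # bias-corrected 1st moment
    v_hat = v / (1.0 - b2 ** t)            # bias-corrected 2nd moment
    new_params = params - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, m, v

params = jnp.array([1.0, -2.0, 0.5])
grads = jnp.array([0.3, -4.0, 0.02])
m0 = jnp.zeros_like(params)                # moments start at zero
v0 = jnp.zeros_like(params)
new_params, m1, v1 = adam_step_manual(params, grads, m0, v0, t=1)
```

The moments are returned alongside the parameters because a second step would need them as inputs together with t=2.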
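At step t=1, bias correction gives m̂ = grads and v̂ = grads², so for any β₁ and β₂ (including the defaults 0.9 and 0.999) the update approximates
params - lr * sign(grads): each coordinate moves by about lr against the sign of its gradient, regardless of the gradient's magnitude.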
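A short numeric check of the t=1 behaviour, assuming zero-initialised moments and gradients much larger in magnitude than eps. The gradient values are made up for illustration and span four orders of magnitude.

```python
import jax.numpy as jnp

lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-8
grads = jnp.array([100.0, -0.01, 3.0])   # illustrative values only

# First update from zero moments (t = 1).
m = (1.0 - b1) * grads
v = (1.0 - b2) * grads ** 2
m_hat = m / (1.0 - b1)                   # recovers grads, up to rounding
v_hat = v / (1.0 - b2)                   # recovers grads squared, up to rounding
step = lr * m_hat / (jnp.sqrt(v_hat) + eps)

print(step)                              # every entry is close to lr * sign(grad)
```

The step has nearly the same magnitude for the gradient of 100 and the gradient of -0.01, which is the scale invariance the approximation describes.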
Optax defaults
optax.adam(lr) uses b1=0.9, b2=0.999, eps=1e-8 by default. This
problem fixes lr=1e-2.
Common pitfalls
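- Bias correction matters early: m and v start at zero, so for small t both are biased toward zero (β₁ᵗ and β₂ᵗ are still close to 1). With the default betas √v is shrunk more than m, so skipping the correction makes the first steps too large, about (1 - β₁) / √(1 - β₂) ≈ 3.2 times lr at t=1.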
- Adam step magnitude ≠ SGD step magnitude: the adaptive scaling means the effective update is roughly lr, not lr * grad_magnitude.
- eps prevents division by zero: for sparse gradients (many zeros) this keeps training numerically stable.
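The effect of bias correction and of eps on the first step can be checked numerically. The sketch below assumes zero-initialised moments and the default hyperparameters. The gradient values are illustrative, with one zero entry to exercise the eps term.

```python
import jax.numpy as jnp

lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-8
grads = jnp.array([0.5, -2.0, 0.0])      # last entry is a zero gradient

# Moments after the first update from zero (t = 1).
m = (1.0 - b1) * grads
v = (1.0 - b2) * grads ** 2

corrected = lr * (m / (1.0 - b1)) / (jnp.sqrt(v / (1.0 - b2)) + eps)
uncorrected = lr * m / (jnp.sqrt(v) + eps)

print(corrected)     # nonzero-gradient entries have magnitude close to lr
print(uncorrected)   # those entries are about (1 - b1) / sqrt(1 - b2) times larger
```

For the zero-gradient entry both moments are zero, so the quotient is 0 / eps = 0 rather than 0 / 0, and the parameter is left unchanged.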
Inputs
- params: 1-D array of current parameter values.
- grads: 1-D array of gradients, same shape as params.
- Learning rate is fixed at 1e-2 inside the function.
Output
1-D array — updated parameter values after one Adam step.