
Optax Adam Step

Why this matters

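Adam (Adaptive Moment Estimation) is a common default optimizer in modern deep learning. It combines momentum (a moving average of gradients, the 1st moment) with per-parameter adaptive step sizes (derived from a moving average of squared gradients, the 2nd moment). This makes it much less sensitive than plain SGD to gradients whose scale varies widely across parameters.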

The Adam update

Adam maintains two exponential moving averages, one of the gradients and one of their elementwise squares:

m ← β₁ * m + (1 - β₁) * grads       # 1st moment (mean of gradients)
v ← β₂ * v + (1 - β₂) * grads²      # 2nd moment (uncentered variance; grads² is elementwise)
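As a minimal sketch in JAX (assuming `jax.numpy`; the helper name `update_moments` is hypothetical, not an Optax API), the two moving-average updates look like this. Starting from m = v = 0, one update leaves m = (1 - β₁)·grads and v = (1 - β₂)·grads², which is why the raw estimates start out biased toward zero.

```python
import jax.numpy as jnp

def update_moments(m, v, grads, b1=0.9, b2=0.999):
    """Fold one gradient into the running moments (hypothetical helper)."""
    m_new = b1 * m + (1.0 - b1) * grads        # EMA of gradients
    v_new = b2 * v + (1.0 - b2) * grads ** 2   # EMA of elementwise squared gradients
    return m_new, v_new

# One update from the zero initial state.
g = jnp.array([2.0, -0.5, 0.0])
m, v = update_moments(jnp.zeros_like(g), jnp.zeros_like(g), g)
print(m)  # ≈ 0.1 * g
print(v)  # ≈ 0.001 * g**2
```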

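Because m and v start at zero, both are biased toward zero in early steps. With t the step count (starting at 1), bias-corrected estimates are used to scale the update: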

m̂ = m / (1 - β₁ᵗ)
v̂ = v / (1 - β₂ᵗ)
params ← params - lr * m̂ / (√v̂ + ε)
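Putting the moment updates, bias correction, and parameter update together gives a hand-written Adam step. This is an illustrative sketch, not Optax's internal code; `adam_step_manual` and its argument order are hypothetical, and `t` is the 1-based index of the step being taken.

```python
import jax.numpy as jnp

def adam_step_manual(params, grads, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One hand-written Adam step (hypothetical helper). Returns (params, m, v)."""
    m = b1 * m + (1.0 - b1) * grads              # 1st moment
    v = b2 * v + (1.0 - b2) * grads ** 2         # 2nd moment (elementwise square)
    m_hat = m / (1.0 - b1 ** t)                  # bias-corrected 1st moment
    v_hat = v / (1.0 - b2 ** t)                  # bias-corrected 2nd moment
    new_params = params - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, m, v

params = jnp.array([1.0, -2.0, 0.5])
grads = jnp.array([0.3, -4.0, 0.0])
m = v = jnp.zeros_like(params)
for t in range(1, 4):  # three consecutive steps with the same gradient
    params, m, v = adam_step_manual(params, grads, m, v, t)
print(params)  # ≈ [0.97, -1.97, 0.5]
```

With a constant gradient, the bias correction exactly cancels the (1 - βᵗ) factor, so m̂ = grads and v̂ = grads² at every t. Each of the three steps therefore moves every nonzero coordinate by about lr = 0.01 against its gradient, and the coordinate with zero gradient stays put.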

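At step t = 1 (starting from m = v = 0), the bias correction gives m̂ = grads and v̂ = grads² for any β₁, β₂, so the update is params - lr * grads / (|grads| + ε) ≈ params - lr * sign(grads): a step of size about lr in each coordinate, against the gradient.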
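The t = 1 behavior can be checked numerically. The sketch below (the helper `first_adam_update` is hypothetical) computes the first update from zero state, compares it with -lr * sign(grads), and shows that scaling the gradient by 1000 leaves the step essentially unchanged.

```python
import jax.numpy as jnp

lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-8

def first_adam_update(grads):
    """Update (the delta added to params) for Adam's first step from zero state."""
    m_hat = ((1 - b1) * grads) / (1 - b1)        # equals grads
    v_hat = ((1 - b2) * grads ** 2) / (1 - b2)   # equals grads ** 2
    return -lr * m_hat / (jnp.sqrt(v_hat) + eps)

g = jnp.array([1e-3, -5.0, 0.0])
print(first_adam_update(g))           # ≈ [-0.01, 0.01, 0.], i.e. -lr * sign(g)
print(first_adam_update(1000.0 * g))  # essentially the same step
```

The approximation only breaks down when |grads| is comparable to ε = 1e-8; a gradient exactly equal to ε, for example, gives half of the sign step.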

Optax defaults

optax.adam(lr) uses b1=0.9, b2=0.999, eps=1e-8 by default. This problem fixes lr=1e-2.

Common pitfalls

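  • Bias correction matters early. For small t, 1 - βᵗ is far from 1, so the raw m and v are biased toward their zero initialization (at t = 1 with the defaults, m is 10× and v is 1000× too small). Dividing by 1 - βᵗ removes this bias; without it, the t = 1 step would be (1 - β₁)/√(1 - β₂) ≈ 3.16 times larger than the corrected one.
  • Adam step magnitude ≠ SGD step magnitude. The adaptive scaling makes each coordinate's update typically on the order of lr, not lr * |grad| as in SGD.
  • eps prevents division by zero. With sparse gradients (many zeros), a coordinate whose gradients have all been zero has m̂ = v̂ = 0, and ε turns its update into 0 / ε = 0 instead of NaN.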
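These pitfalls can be made concrete with the t = 1 state from the update rules. The sketch recomputes the first step with bias correction, without it (for illustration only), and as a plain SGD step with the same learning rate; all names are local to the example.

```python
import jax.numpy as jnp

lr, b1, b2, eps = 1e-2, 0.9, 0.999, 1e-8
g = jnp.array([0.3, -4.0, 0.0])

# State after the first gradient (t = 1), starting from m = v = 0.
m = (1 - b1) * g
v = (1 - b2) * g ** 2

corrected = -lr * (m / (1 - b1)) / (jnp.sqrt(v / (1 - b2)) + eps)
uncorrected = -lr * m / (jnp.sqrt(v) + eps)  # bias correction dropped
sgd = -lr * g                                # plain SGD: step scales with |g|

print(corrected)    # ≈ [-0.01, 0.01, 0.]
print(uncorrected)  # ≈ [-0.0316, 0.0316, 0.], about 3.16x the corrected step
print(sgd)          # ≈ [-0.003, 0.04, 0.]
# The zero-gradient coordinate is 0 / (0 + eps) = 0 in every case, not NaN.
```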

Inputs

  • params: 1-D array of current parameter values.
  • grads: 1-D array of gradients, same shape as params.
  • Learning rate is fixed at 1e-2 inside the function.

Output

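1-D array: updated parameter values after one Adam step. Because the function receives no optimizer state, this is the first step (t = 1) from a freshly initialized state (m = v = 0).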

Hints

optax adam
