
Optax RMSprop Step

Why this matters

RMSprop (Root Mean Square Propagation) was proposed by Geoff Hinton for non-stationary problems. It has long been a popular choice for recurrent networks and reinforcement learning, where the gradient distribution can shift rapidly during training.

How RMSprop works

RMSprop keeps an exponential moving average v of squared gradients for each parameter and divides each gradient by its root:

v ← decay * v + (1 - decay) * grads²
params ← params - lr * grads / √(v + ε)

The denominator √(v + ε) rescales each gradient by the root of its recent mean square (ε is a small constant that avoids division by zero). This normalises the update scale per parameter without accumulating a first moment.
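The two update lines above can be sketched in NumPy as follows. This is a minimal sketch assuming plain 1-D arrays instead of Optax's pytrees; `rmsprop_step` is a hypothetical helper name, not an Optax function.

```python
import numpy as np

def rmsprop_step(params, grads, v, lr, decay=0.9, eps=1e-8):
    """Hypothetical helper: one RMSprop step, returns (new_params, new_v)."""
    # Exponential moving average of squared gradients, per parameter.
    v = decay * v + (1.0 - decay) * grads**2
    # Divide each gradient by the root of its running mean square.
    new_params = params - lr * grads / np.sqrt(v + eps)
    return new_params, v

params = np.array([1.0, -2.0, 0.5])
grads = np.array([0.2, -0.4, 0.0])   # a zero gradient: eps keeps the division finite
v = np.zeros_like(params)            # v starts at zero
new_params, v = rmsprop_step(params, grads, v, lr=0.01)
print(new_params)
```

Threading `v` through the call (rather than hiding it in a global) mirrors how Optax keeps optimizer state separate from the parameters.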

First-step behaviour

With v initialised to zero and decay=0.9, the first step is:

v = 0.1 * grads²
update = lr * grads / √(0.1 * grads² + ε)
       ≈ lr * sign(grads) / √0.1   (for |grads| >> √ε)
       ≈ lr * sign(grads) * 3.16

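This is about 3.16 times larger than Adam's first step, which is ≈ lr * sign(grads). The gap comes from bias correction: on the first step v = 0.1 * grads² underestimates grads² by a factor of 10, and Adam divides its second-moment estimate by (1 - β₂ᵗ) to undo exactly this zero-initialisation bias. RMSprop applies no such correction.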
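A quick NumPy check of the first-step arithmetic. It assumes Adam's common defaults β₁=0.9, β₂=0.999 with ε added outside the square root (the standard Adam form); the gradient values are arbitrary and chosen to span several orders of magnitude, all with |grads| >> √ε.

```python
import numpy as np

lr, eps = 1e-3, 1e-8
g = np.array([0.01, -3.0, 250.0])   # gradients of very different scales

# RMSprop: decay=0.9, v starts at 0, no bias correction.
v = 0.1 * g**2
rms_step = lr * g / np.sqrt(v + eps)

# Adam at t=1: m and s start at 0, then bias-corrected.
b1, b2, t = 0.9, 0.999, 1
m = (1 - b1) * g
s = (1 - b2) * g**2
m_hat = m / (1 - b1**t)             # equals g up to rounding
s_hat = s / (1 - b2**t)             # equals g**2 up to rounding
adam_step = lr * m_hat / (np.sqrt(s_hat) + eps)

print(rms_step / lr)    # each entry is close to ±3.16
print(adam_step / lr)   # each entry is close to ±1.0
```

Both optimizers produce a step whose size is independent of the gradient's scale; only the constant differs, and it is set entirely by whether v is bias-corrected.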

RMSprop vs Adam

Feature                        RMSprop        Adam
1st moment (mean)              No             Yes (β₁)
2nd moment (mean of squares)   Yes (decay)    Yes (β₂)
Bias correction                No             Yes
Typical use                    RL, RNNs       General DL

Optax defaults

optax.rmsprop(lr) uses decay=0.9, eps=1e-8 by default.

Common pitfalls

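  • No bias correction: the first-step update is larger than Adam's (by a factor of 1/√0.1 ≈ 3.16 with decay=0.9).
  • No 1st moment: unlike Adam, there is no momentum term by default (pass momentum=... to optax.rmsprop to add one).
  • Sensitive to decay: a higher decay averages over a longer window (roughly 1/(1 - decay) steps, so about 10 at 0.9 and 100 at 0.99) but adapts more slowly when the gradient distribution shifts.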
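The decay trade-off can be illustrated with a small simulation. This is a sketch under assumed conditions: the gradient magnitude jumps from 1 to 10 after a warm-up, and `steps_to_adapt` (a hypothetical helper) counts how many steps v needs to reach half of the new squared magnitude.

```python
import numpy as np

def steps_to_adapt(decay, old=1.0, new=10.0, warmup=500):
    """Hypothetical helper: steps after a jump in |grad| until v >= new**2 / 2."""
    v = 0.0
    for _ in range(warmup):              # settle v on the old gradient scale
        v = decay * v + (1 - decay) * old**2
    steps = 0
    while v < 0.5 * new**2:              # gradient scale has jumped from old to new
        v = decay * v + (1 - decay) * new**2
        steps += 1
    return steps

for d in (0.9, 0.99, 0.999):
    # Grows roughly like ln(2) / (1 - decay).
    print(d, steps_to_adapt(d))
```

While v lags behind, the denominator √(v + ε) is too small and the effective step size is inflated, which is why a very high decay can misbehave right after a distribution shift.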

Inputs

  • params: 1-D array of current parameter values.
  • grads: 1-D array of gradients, same shape as params.
  • lr: scalar learning rate.

Output

1-D array — updated parameter values after one RMSprop step.

Hints

Use optax.rmsprop.
