medium primitives

Optax EMA on Params

Why this matters

The exponential moving average (EMA) of model parameters is ubiquitous in modern deep learning: soft target-network updates in RL (e.g. DDPG), momentum/target encoders in self-supervised learning (MoCo, BYOL), and weight smoothing for evaluation (Polyak averaging). Understanding the one-line update formula is foundational before reaching for any library wrapper.

The recipe

EMA is defined as:

ema_new = decay * ema_old + (1 - decay) * new_value
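Unrolling the recursion makes the role of decay explicit. Writing e_t for the EMA after t updates, x_t for the t-th new value and d for decay (symbols introduced here only for this derivation):

```latex
e_t = d\,e_{t-1} + (1-d)\,x_t
    = d^{t}\,e_0 + (1-d)\sum_{k=0}^{t-1} d^{k}\,x_{t-k}
```

An observation k steps old carries weight (1 - d) d^k. The weights on the x terms sum to 1 - d^t, so together with the d^t left on e_0 every step is a convex combination of the initial estimate and the observations.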

Applied to model params, where params holds the current EMA estimate (see Inputs) and new_params the latest weights:

ema_params = decay * params + (1 - decay) * new_params
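A minimal JAX sketch of this blend for the 1-D inputs listed under Inputs. The function name ema_update is hypothetical, not part of any library, and wrapping it in jax.jit is optional (it only pays off when the step is called many times).

```python
import jax
import jax.numpy as jnp


@jax.jit
def ema_update(params: jax.Array, new_params: jax.Array, decay: float) -> jax.Array:
    """One EMA blend step: the old estimate is weighted by decay, the new value by (1 - decay)."""
    return decay * params + (1.0 - decay) * new_params


# Usage: blend a new observation into the current estimate.
old = jnp.array([1.0, 2.0])
new = jnp.array([3.0, 4.0])
print(ema_update(old, new, 0.9))  # ≈ [1.2, 2.2]
```

Note the two limits: decay = 1.0 returns params unchanged, and decay = 0.0 simply copies new_params.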

A high decay (e.g. 0.99) makes the EMA slow to react; a low decay (e.g. 0.5) makes it react quickly.
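To see the difference in reaction speed, this sketch (helper name hypothetical) starts an EMA at 0 and feeds it a constant 1.0 using jax.lax.scan. The closed form after n steps is 1 - decay**n. A common rule of thumb is that the EMA averages over roughly the last 1 / (1 - decay) updates: about 100 for decay = 0.99, about 2 for decay = 0.5.

```python
import jax
import jax.numpy as jnp


def track_step_input(decay: float, num_steps: int) -> jax.Array:
    """Start the EMA at 0.0 and feed a constant 1.0 for num_steps updates.

    Returns the EMA value after every step (shape: [num_steps]).
    """
    def step(ema, x):
        ema = decay * ema + (1.0 - decay) * x
        return ema, ema  # carry the new EMA forward and also record it

    xs = jnp.ones(num_steps, dtype=jnp.float32)
    _, history = jax.lax.scan(step, jnp.zeros((), dtype=jnp.float32), xs)
    return history


slow = track_step_input(0.99, 10)
fast = track_step_input(0.5, 10)
print(slow[-1])  # ≈ 0.096  (1 - 0.99**10)
print(fast[-1])  # ≈ 0.999  (1 - 0.5**10)
```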

Common pitfalls

  • Swapping the roles of decay and (1 - decay): the old value is multiplied by decay, not the new one.
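  • optax.ema(decay) is a GradientTransformation that keeps its own EMA of the updates (gradients) passed through it, starting from zeros, and by default (debias=True) returns a bias-corrected estimate. It is not a parameter EMA. For params-EMA, apply the formula directly.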

Inputs

  • params: 1-D JAX array, the current (old) EMA estimate.
  • new_params: 1-D JAX array, the new observation to blend in.
  • decay: scalar float, how much weight to give the old estimate.

Output

1-D array: the updated EMA after one blend step.
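The stated inputs are 1-D arrays, but real model parameters are usually pytrees (for example nested dicts of arrays). A sketch, assuming both trees share the same structure, that applies the same blend leaf-by-leaf with jax.tree_util.tree_map; the helper name ema_update_tree and the parameter tree are hypothetical.

```python
import jax
import jax.numpy as jnp


def ema_update_tree(ema_params, new_params, decay: float):
    """Apply the EMA blend leaf-by-leaf to two pytrees with identical structure."""
    return jax.tree_util.tree_map(
        lambda old, new: decay * old + (1.0 - decay) * new, ema_params, new_params
    )


# Hypothetical one-layer parameter tree, used only for illustration.
ema = {"dense": {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}}
live = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.full(2, 2.0)}}
ema = ema_update_tree(ema, live, decay=0.99)
# ema["dense"]["w"] is now ≈ 0.01 everywhere, ema["dense"]["b"] ≈ 0.02 everywhere.
```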

Hints

optax ema
