Exponential Moving Average
Implement an exponential moving average (EMA) update for model parameters.
What is EMA and why is it used?
An exponential moving average maintains a smoothed copy of model weights that changes slowly over training. Rather than reading weights directly from the latest checkpoint, you read from the EMA copy, which averages over roughly the last 1/(1 - decay) training steps. This produces more stable evaluation numbers and often better downstream performance.
This technique appears under several names:
- Polyak averaging: the general idea of averaging iterates
- EMA teacher: in BYOL, MoCo, and semi-supervised learning, an EMA copy that receives no gradient updates acts as a target network
- Stochastic weight averaging (SWA): a related method that averages less frequently
- EMA in diffusion models: the published model weights are often the EMA copy, not the raw training weights
The math
After each training step, update the EMA copy with:
ema_new = decay * ema_old + (1 - decay) * current
This is a convex combination: decay controls how much weight to give the
historical EMA versus the freshly-updated weights.
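As a small illustration of the update rule, the sketch below applies it to a single scalar for three steps. The helper name `ema_step` and the value decay = 0.5 are assumptions chosen only to keep the arithmetic easy to follow; real training runs use a decay much closer to 1.

```python
def ema_step(ema_old: float, current: float, decay: float) -> float:
    """One scalar EMA update: a convex combination of the old EMA and the new value."""
    return decay * ema_old + (1.0 - decay) * current


# decay = 0.5 is far smaller than typical; it is used here only for readable numbers.
ema = 0.0
trace = []
for current in [1.0, 1.0, 1.0]:
    ema = ema_step(ema, current, decay=0.5)
    trace.append(ema)

print(trace)  # [0.5, 0.75, 0.875]
```

Each result lies between the previous EMA and the current value, which is what the convex combination guarantees, and the EMA moves toward the current value by a fraction (1 - decay) of the remaining gap at every step.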
Choosing decay
A typical value is decay = 0.999, which gives the EMA roughly a
1000-step memory window (1 / (1 - 0.999) = 1000). A larger decay means
slower updates and a longer memory.
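To make the "memory window" more concrete: unrolling the update shows that the weights seen k steps ago enter the EMA with weight (1 - decay) * decay^k, so everything older than N steps contributes decay^N in total. The sketch below computes both quantities; the function names `effective_window` and `weight_of_step` are hypothetical helpers, not part of the exercise.

```python
def effective_window(decay: float) -> float:
    """Approximate number of steps the EMA averages over: 1 / (1 - decay)."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must be in [0, 1) for a finite window")
    return 1.0 / (1.0 - decay)


def weight_of_step(decay: float, steps_ago: int) -> float:
    """Weight the EMA assigns to the weights seen `steps_ago` updates in the past."""
    return (1.0 - decay) * decay ** steps_ago


for decay in (0.9, 0.99, 0.999):
    window = effective_window(decay)
    # Share of the EMA that still comes from steps older than one window.
    older_share = decay ** round(window)
    print(f"decay={decay}: window ~ {window:.0f} steps, older share ~ {older_share:.3f}")
```

For decay close to 1, the share contributed by steps older than one window is close to 1/e (about 0.37), so the window is best read as a characteristic time scale and not as a hard cutoff.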
Special cases:
- decay = 1.0: EMA never changes (frozen)
- decay = 0.0: EMA always equals the current weights (no smoothing)
Inputs
- ema_weights: shape (d,), running EMA values
- current_weights: shape (d,), current model weights after the latest optimizer step
- decay: float in [0, 1], smoothing factor, typically 0.999
Output
Shape (d,): the updated EMA weights.
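One possible shape for a solution is sketched below. It is not the reference solution for this problem. It assumes plain Python sequences stand in for arrays of shape (d,), and the function name `ema_update` and the input checks are illustrative choices. With an array library the same expression would usually be written without the explicit loop.

```python
from typing import List, Sequence


def ema_update(
    ema_weights: Sequence[float],
    current_weights: Sequence[float],
    decay: float,
) -> List[float]:
    """Return the updated EMA weights without modifying either input."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"decay must be in [0, 1], got {decay}")
    if len(ema_weights) != len(current_weights):
        raise ValueError("ema_weights and current_weights must have the same length")
    # Elementwise convex combination of the old EMA and the current weights.
    return [
        decay * e + (1.0 - decay) * c
        for e, c in zip(ema_weights, current_weights)
    ]


ema = [1.0, 2.0, 3.0]
current = [3.0, 2.0, 1.0]

print(ema_update(ema, current, decay=0.5))  # [2.0, 2.0, 2.0]
print(ema_update(ema, current, decay=1.0))  # [1.0, 2.0, 3.0]  (frozen)
print(ema_update(ema, current, decay=0.0))  # [3.0, 2.0, 1.0]  (no smoothing)
```

The last two calls exercise the special cases: decay = 1.0 returns the old EMA unchanged, and decay = 0.0 returns the current weights.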