EMA of Parameters
Why this matters
The parameters at the end of training are usually NOT the parameters you want to ship. They’re the noisy snapshot at the last SGD step. What you really want is something more like the average parameters over the last several thousand steps — a temporal smoothing of the optimization trajectory.
The standard tool is an Exponential Moving Average (EMA) of the weights. After every optimizer step, you maintain a separate copy of the params, ema, and update it with:
new_ema = decay * ema + (1 - decay) * current
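Here is a minimal NumPy sketch of that update. It assumes the params live in a dict of float arrays; the helper name ema_update and the dict layout are illustrative and not taken from any particular library. It also shows why the EMA has to start as a copy of the params rather than a reference to them:

```python
import numpy as np


def ema_update(ema, params, decay):
    """Apply new_ema = decay * ema + (1 - decay) * current to every array.

    `ema` and `params` map parameter names to NumPy float arrays.
    The EMA arrays are updated in place; `params` is only read.
    """
    for name, current in params.items():
        ema[name] *= decay
        ema[name] += (1.0 - decay) * current


params = {"w": np.array([1.0, 2.0]), "b": np.array([0.5])}
# Independent copies: a plain `ema = params` would alias the same arrays,
# and the in-place update above would then overwrite the live weights.
ema = {name: value.copy() for name, value in params.items()}

params["w"] += 1.0  # pretend an optimizer step moved the live weights
ema_update(ema, params, decay=0.9)
print(ema["w"])  # ≈ [1.1, 2.1], i.e. 0.9 * [1, 2] + 0.1 * [2, 3]
```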
At inference time, you use ema instead of current. Empirically, this gives:
- More stable predictions (less variance run-to-run).
- Often a small but real boost in accuracy.
- Critical for diffusion models (DDPM, EDM): the EMA params are the ones that produce high-quality samples; the raw training params tend to produce noisy garbage.
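Self-supervised methods such as BYOL and MoCo use the EMA params as the “target network” (MoCo calls it a momentum encoder) to bootstrap learning. There, what matters is training stability: a slowly moving target gives the online network consistent targets to learn from. (SimSiam, by contrast, drops the EMA and relies on stop-gradient alone.)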
Choosing decay
The decay controls how much “memory” the EMA has. The effective averaging window is roughly 1 / (1 - decay) steps:
| decay | effective window |
|---|---|
| 0.9 | ~10 steps |
| 0.99 | ~100 steps |
| 0.999 | ~1000 steps |
| 0.9999 | ~10,000 steps |
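To see why 1 / (1 - decay) is a fair "window" size, the sketch below uses the unrolled weights from the math section further down. Step i back gets weight (1 - decay) * decay^i, so the most recent N updates together carry 1 - decay^N. With N = 1 / (1 - decay), that comes out near 1 - 1/e (about 63 to 65%) for every decay in the table. The function names are illustrative.

```python
import math


def effective_window(decay):
    """Rough number of steps the EMA averages over: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


def recent_weight(decay, n_steps):
    """Total EMA weight on the most recent n_steps updates: 1 - decay**n_steps."""
    return 1.0 - decay ** n_steps


for decay in (0.9, 0.99, 0.999, 0.9999):
    n = round(effective_window(decay))
    print(f"decay={decay}: window ~{n} steps, "
          f"weight on last window = {recent_weight(decay, n):.3f}")
```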
Diffusion models typically use 0.999 or 0.9999. Self-supervised methods use similar values: MoCo uses 0.999, and BYOL starts at 0.996 and ramps the decay up toward 1.0 over training. Too LOW and you get the noisy training params back. Too HIGH and the EMA lags behind real progress; in a short run it can still be dominated by early, worse weights when training ends.
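A ramped decay schedule can be sketched as below. This follows the cosine increase described in the BYOL paper (start at a base decay of 0.996, reach 1.0 at the last step). The function name and the clamping of step past total_steps are illustrative choices, not part of any library.

```python
import math


def cosine_ramped_decay(step, total_steps, base_decay=0.996):
    """BYOL-style schedule: rise from base_decay to 1.0 along a half cosine.

    decay_k = 1 - (1 - base_decay) * (cos(pi * k / K) + 1) / 2
    """
    progress = min(step, total_steps) / total_steps  # clamp at the end of training
    return 1.0 - (1.0 - base_decay) * (math.cos(math.pi * progress) + 1.0) / 2.0


for step in (0, 2500, 5000, 7500, 10000):
    print(step, round(cosine_ramped_decay(step, 10000), 5))
```

Because the decay reaches 1.0 at the final step, the EMA effectively freezes at the end of training instead of continuing to chase the live params.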
The math (why it’s “exponential”)
Unrolling the recurrence:
ema_t = decay * ema_{t-1} + (1 - decay) * cur_t
= decay^t * ema_0 + (1-decay) * sum_{i=0}^{t-1} decay^i * cur_{t-i}
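The contribution of cur_{t-i} decays as decay^i, i.e. exponentially: recent params dominate and ancient params fade. The weights on the cur terms sum to (1 - decay) * (1 + decay + ... + decay^{t-1}) = 1 - decay^t. The remaining decay^t sits on ema_0, which is where the start-up bias below comes from.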
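A quick numerical check, using made-up scalar values from a seeded random generator, that the recurrence and the unrolled sum agree. It also checks that the weights on ema_0 and the cur terms add up to 1:

```python
import numpy as np

rng = np.random.default_rng(0)
decay, t = 0.9, 50
ema_0 = rng.normal()
cur = rng.normal(size=t + 1)  # cur[1..t] are the live params; cur[0] is unused

# Recurrence: ema_t = decay * ema_{t-1} + (1 - decay) * cur_t
ema = ema_0
for step in range(1, t + 1):
    ema = decay * ema + (1 - decay) * cur[step]

# Unrolled: decay^t * ema_0 + (1 - decay) * sum_{i=0}^{t-1} decay^i * cur_{t-i}
i = np.arange(t)
weights = (1 - decay) * decay**i
closed = decay**t * ema_0 + np.sum(weights * cur[t - i])

print(np.isclose(ema, closed))                     # True
print(np.isclose(weights.sum() + decay**t, 1.0))   # True: weights sum to 1
```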
Pitfalls
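- Bias at start: when ema_0 = current_0, the EMA starts out equal to the current params, so there is no zero bias (it is still pulled toward the initial weights for on the order of 1 / (1 - decay) steps). But if you initialize ema_0 = 0 (some implementations do), the early EMA is biased toward zero. Adam-style bias correction, ema_hat = ema / (1 - decay^t), is the textbook fix.
- Don’t EMA the optimizer state. Only EMA the params themselves.
- EMA on a copy: ema is a separate buffer. Do NOT replace current with ema mid-training, or the optimizer will continue from the lagged, averaged weights while its state still reflects the live ones. If you swap the EMA in for evaluation, swap the live params back before the next optimizer step.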
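A small sketch of the zero-init bias and its correction. It assumes the live params sit at a fixed value so the right answer is obvious. Starting from ema_0 = 0, the raw EMA only reaches (1 - decay^t) of that value after t steps, and dividing by (1 - decay^t) recovers it:

```python
import numpy as np

decay = 0.99
target = np.array([3.0, -1.0])  # pretend the live params stay fixed at this value
ema = np.zeros_like(target)     # zero-initialised EMA (the biased case)

for t in range(1, 11):
    ema = decay * ema + (1 - decay) * target
    ema_hat = ema / (1 - decay**t)  # Adam-style bias correction at step t

print(ema)      # still near zero after 10 steps: about 0.096 * target
print(ema_hat)  # ≈ target
```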
Problem
Given two flat 1-D parameter arrays, params_a_flat (the EMA) and params_b_flat (the current params), compute the next EMA value:
new_ema = decay * params_a_flat + (1 - decay) * params_b_flat
Return the result as a 1-D array of the same length.
Inputs:
- seed: float (cast to int; unused, kept for signature consistency).
- params_a_flat: 1-D (K,), the current EMA params.
- params_b_flat: 1-D (K,), the current “live” params.
- decay: float in (0, 1).
Output: 1-D (K,), the updated EMA buffer.