
EMA of Parameters

Why this matters

The parameters at the end of training are usually NOT the parameters you want to ship. They’re the noisy snapshot at the last SGD step. What you really want is something more like the average parameters over the last several thousand steps — a temporal smoothing of the optimization trajectory.

The standard tool is an Exponential Moving Average (EMA) of the weights. After every optimizer step, you maintain a separate copy ema of the params and update it with:

new_ema = decay * ema + (1 - decay) * current
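
As a rough illustration of the update rule above, here is a minimal NumPy sketch. The "training loop" is a stand-in that just jitters the parameters around a fixed toy optimum; the names ema_step, target, live_error, and ema_error are made up for this example and do not come from any library.

```python
import numpy as np

def ema_step(ema, current, decay):
    """One EMA update: blend the old average with the live parameters."""
    return decay * ema + (1.0 - decay) * current

rng = np.random.default_rng(0)
target = np.array([1.0, -2.0, 0.5])   # toy "optimum" (an assumption for the demo)
current = target.copy()
ema = current.copy()                  # start the EMA at the live parameters

for _ in range(2000):
    # Stand-in for an optimizer step: the live params jitter around the optimum.
    current = target + 0.1 * rng.standard_normal(target.shape)
    # The EMA is updated after every step and never written back into `current`.
    ema = ema_step(ema, current, decay=0.99)

live_error = np.abs(current - target).max()   # error of the last noisy snapshot
ema_error = np.abs(ema - target).max()        # error of the smoothed copy
print("live:", live_error, "ema:", ema_error)
```

In this toy setting the smoothed copy should sit much closer to the optimum than a typical single snapshot, which is the effect the EMA is meant to provide.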

At inference time, you use ema instead of current. Empirically this gives:

  • More stable predictions (less variance run-to-run).
  • Often a small but real boost in accuracy.
  • Critical for diffusion models (DDPM, EDM): samples are normally drawn from the EMA params, and the raw training params tend to produce visibly noisier, lower-quality samples.

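Self-supervised methods such as BYOL and MoCo use EMA params as the “target network” to bootstrap learning. (SimSiam is the notable exception: it relies on stop-gradient alone, without a momentum encoder.) There, a slowly moving target is what keeps training stable.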

Choosing decay

The decay controls how much “memory” the EMA has. The effective averaging window is roughly 1 / (1 - decay) steps:

decay     effective window
0.9       ~10 steps
0.99      ~100 steps
0.999     ~1000 steps
0.9999    ~10,000 steps

Diffusion models typically use 0.999 or 0.9999. Self-supervised methods (BYOL) use 0.996-0.999, sometimes ramped up over training. Too LOW and the average covers so few steps that you get the noisy training params back. Too HIGH and the EMA lags far behind real progress: on a run much shorter than the effective window, it is still dominated by early, poorly trained params.
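
One way to get a feel for the "effective window" and the lag is to look at how the EMA responds to a sudden jump in the live value. The sketch below uses a scalar toy case, and the helper name fraction_absorbed is made up for this example: the live value jumps from 0 to 1, and we measure how much of that jump the EMA has absorbed after 1 / (1 - decay) steps.

```python
def fraction_absorbed(decay, n_steps):
    """Fraction of a unit jump the EMA has absorbed after n_steps updates."""
    ema = 0.0                                    # EMA still sits at the old value
    for _ in range(n_steps):
        ema = decay * ema + (1.0 - decay) * 1.0  # live value is now 1
    return ema                                   # equals 1 - decay**n_steps

for decay in (0.9, 0.99, 0.999):
    window = round(1.0 / (1.0 - decay))          # the "effective window" from the table
    print(decay, window, fraction_absorbed(decay, window))
```

After one effective window the EMA has covered 1 - decay^N of the jump, which is a bit under two thirds for all three settings. That is why a run much shorter than the window leaves the EMA far behind the live params.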

The math (why it’s “exponential”)

Unrolling the recurrence:

ema_t = decay * ema_{t-1} + (1 - decay) * cur_t
      = decay^t * ema_0 + (1-decay) * sum_{i=0}^{t-1} decay^i * cur_{t-i}

The contribution of cur_{t-i} decays as decay^i — exponentially. Recent params dominate; ancient params fade.
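
The unrolled form can be checked numerically against the step-by-step recurrence. The following is a small sketch on random data; the array names and sizes are arbitrary choices for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
decay, t = 0.9, 20
ema_0 = rng.standard_normal(4)
cur = rng.standard_normal((t + 1, 4))   # cur[1], ..., cur[t] are used; cur[0] is padding

# Step-by-step recurrence: ema_t = decay * ema_{t-1} + (1 - decay) * cur_t
ema = ema_0.copy()
for step in range(1, t + 1):
    ema = decay * ema + (1.0 - decay) * cur[step]

# Unrolled form: weights[i] is the weight placed on cur_{t-i}, for i = 0..t-1
weights = (1.0 - decay) * decay ** np.arange(t)
closed = decay ** t * ema_0 + sum(weights[i] * cur[t - i] for i in range(t))

print(np.allclose(ema, closed))          # the two forms agree
print(weights.sum() + decay ** t)        # all weights, including ema_0's, sum to 1
```

The second check shows that the EMA is a proper weighted average: the geometric weights on the live params plus the leftover weight decay^t on ema_0 add up to one.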

Pitfalls

  • Bias at start: when ema_0 = current_0, the EMA starts out equal to the current params and there is no startup bias. But if you initialize ema_0 = 0 (some impls do), the early EMA is biased toward zero. Adam-style bias correction ema_hat_t = ema_t / (1 - decay^t) is the textbook fix.
  • Don’t EMA the optimizer state. Only EMA the params themselves.
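  • EMA on a copy: ema is a separate buffer; do NOT overwrite current with ema mid-training. The optimizer's state (momentum, Adam moments) was built along the live trajectory, and snapping the weights back to the lagging average discards recent progress.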
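
The zero-initialization bias from the first pitfall is easy to see in a toy case. In this sketch the live params are held constant (an assumption made only so the "right" average is known), and the EMA starts from zeros.

```python
import numpy as np

decay = 0.99
current = np.array([2.0, -1.0])   # held constant so the correct average is known
ema = np.zeros_like(current)      # zero init: the biased case

for t in range(1, 11):
    ema = decay * ema + (1.0 - decay) * current
    # Adam-style correction: divide out the total weight accumulated so far.
    ema_hat = ema / (1.0 - decay ** t)

print("raw ema:      ", ema)       # still far from `current` after 10 steps
print("bias-corrected:", ema_hat)  # recovers `current` for constant inputs
```

With constant inputs the corrected value matches the live params exactly; with changing inputs it is instead a properly normalized weighted average of the params seen so far.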

Problem

Given two flat 1-D parameter arrays — params_a_flat (the EMA) and params_b_flat (the current params) — compute the next EMA value:

new_ema = decay * params_a_flat + (1 - decay) * params_b_flat

Return the result as a 1-D array of the same length.

Inputs:

  • seed: float (cast to int, unused — kept for signature consistency).
  • params_a_flat: 1-D (K,) — current EMA params.
  • params_b_flat: 1-D (K,) — current “live” params.
  • decay: float in (0, 1).

Output: 1-D (K,) — the updated EMA buffer.
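
For orientation, one possible shape of a solution is sketched below in NumPy. The function name ema_update, the float64 conversion, and the shape check are assumptions made for this example; the problem statement only fixes the inputs and the formula.

```python
import numpy as np

def ema_update(seed, params_a_flat, params_b_flat, decay):
    """Return decay * params_a_flat + (1 - decay) * params_b_flat as a 1-D array.

    `ema_update` is a hypothetical name; the argument order follows the
    Inputs list above.
    """
    seed = int(seed)  # unused, kept only for signature consistency
    ema = np.asarray(params_a_flat, dtype=np.float64)      # current EMA params
    current = np.asarray(params_b_flat, dtype=np.float64)  # current live params
    if ema.ndim != 1 or ema.shape != current.shape:
        raise ValueError("expected two 1-D arrays of the same length")
    return decay * ema + (1.0 - decay) * current

new_ema = ema_update(0.0, [1.0, 2.0], [3.0, 4.0], decay=0.9)
print(new_ema)
```

The update is purely elementwise, so the same expression works unchanged on framework arrays that support NumPy-style arithmetic.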

Hints

flax training ema
