
Exponential Moving Average

Implement an exponential moving average (EMA) update for model parameters.

What is EMA and why is it used?

An exponential moving average maintains a smoothed copy of model weights that changes slowly over training. Rather than reading weights directly from the latest checkpoint, you read from the EMA copy, which averages over roughly the last 1/(1-decay) training steps. This tends to produce more stable evaluation numbers and often better downstream performance.

This technique appears under several names:

  • Polyak averaging: the general idea of averaging iterates
  • EMA teacher: in BYOL, MoCo, and semi-supervised learning, an EMA copy that receives no gradient updates acts as a target network
  • Stochastic weight averaging (SWA): a variant that averages less frequently, typically with equal weights over periodically sampled checkpoints
  • EMA in diffusion models: the published model weights are often the EMA copy, not the raw training weights

The math

After each training step, update the EMA copy with:

ema_new = decay * ema_old + (1 - decay) * current

This is a convex combination: decay is the weight given to the historical EMA, and 1 - decay is the weight given to the freshly updated weights.
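
To see where the "memory" comes from, the recursion can be unrolled: after t updates the EMA is a weighted sum of past weights whose coefficients shrink geometrically with age, plus a decay^t share left on the initial value. The sketch below checks this numerically on scalars. The function names ema_recursive and ema_unrolled are illustrative and not part of the problem statement.

```python
def ema_recursive(values, decay, ema_init):
    """Apply ema = decay * ema + (1 - decay) * value once per value."""
    ema = ema_init
    for value in values:
        ema = decay * ema + (1.0 - decay) * value
    return ema


def ema_unrolled(values, decay, ema_init):
    """Same result written as an explicit weighted sum over past values."""
    t = len(values)
    # The initial EMA keeps a decay**t share after t updates.
    total = (decay ** t) * ema_init
    # The value seen k steps ago carries weight (1 - decay) * decay**k.
    for k, value in enumerate(reversed(values)):
        total += (1.0 - decay) * (decay ** k) * value
    return total


values = [1.0, 2.0, 4.0]
recursive_result = ema_recursive(values, decay=0.5, ema_init=0.0)
unrolled_result = ema_unrolled(values, decay=0.5, ema_init=0.0)
print(recursive_result, unrolled_result)
```

Both forms agree, which is why older weights fade smoothly instead of dropping out at a hard cutoff.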

Choosing decay

A typical value is decay = 0.999, which gives the EMA roughly a 1000-step memory window (1 / (1 - 0.999) ≈ 1000). Larger decay → slower update → longer memory.
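
The "memory window" is a rule of thumb, not a hard cutoff. As a rough sketch, the snippet below computes the window length and the share of total EMA weight that falls on the most recent steps, assuming the update has been running long enough for the initial value to fade. The helper names effective_window and weight_on_recent are hypothetical.

```python
def effective_window(decay):
    """Rule-of-thumb memory length 1 / (1 - decay), in steps."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must be in [0, 1) for a finite window")
    return 1.0 / (1.0 - decay)


def weight_on_recent(n_steps, decay):
    """Total EMA weight carried by the most recent n_steps updates."""
    # Geometric sum: (1 - decay) * (1 + decay + ... + decay**(n_steps - 1)).
    return 1.0 - decay ** n_steps


window = effective_window(0.999)
share = weight_on_recent(1000, 0.999)
print(f"window ~ {window:.0f} steps, weight inside window ~ {share:.2f}")
```

With decay = 0.999, about 63% of the weight sits on the last 1000 steps. The rest is spread over older steps, so the window describes a time scale and not a strict boundary.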

Special cases:

  • decay = 1.0: EMA never changes (frozen)
  • decay = 0.0: EMA always equals the current weights (no smoothing)

Inputs

  • ema_weights: shape (d,), the running EMA values
  • current_weights: shape (d,), the current model weights after the latest optimizer step
  • decay: float in [0, 1], the smoothing factor, typically 0.999

Output

Shape (d,): the updated EMA weights.
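
One possible implementation of this interface, as a minimal sketch: it assumes NumPy arrays for the weights, and the function name ema_update, the float64 conversion, and the input checks are choices made here, not requirements of the problem. It returns a new array and leaves its inputs unchanged.

```python
import numpy as np


def ema_update(ema_weights, current_weights, decay=0.999):
    """Return decay * ema_weights + (1 - decay) * current_weights."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must lie in [0, 1]")
    ema = np.asarray(ema_weights, dtype=np.float64)
    current = np.asarray(current_weights, dtype=np.float64)
    if ema.shape != current.shape:
        raise ValueError("ema_weights and current_weights must share a shape")
    # Convex combination of the old EMA and the fresh weights.
    return decay * ema + (1.0 - decay) * current


ema = np.array([1.0, 2.0, 3.0])
current = np.array([2.0, 4.0, 6.0])

updated = ema_update(ema, current, decay=0.9)   # close to [1.1, 2.2, 3.3]
frozen = ema_update(ema, current, decay=1.0)    # equals ema
copied = ema_update(ema, current, decay=0.0)    # equals current
print(updated, frozen, copied)
```

In a training loop, a function like this would typically be called once after each optimizer step, with the result stored as the new EMA copy.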
