NNX Running Average / EMA Step
Why this matters
Exponential moving averages (EMAs) of model parameters show up across deep learning:
- Polyak averaging for SGD-trained models: averaging the iterates often improves test accuracy “for free”, with no architectural change.
- Diffusion models keep an EMA of the network weights for sampling, typically with decay = 0.999 or 0.9999.
- Self-supervised methods (BYOL, MoCo, DINO) drive a target network via an EMA of the online network, often with decay = 0.996 or higher.
- Off-policy RL methods such as DDPG, TD3, and SAC update their target networks with an EMA of the online network ("soft" or Polyak updates) for stable bootstrapping.
The update rule is one line:
ema_t = decay * ema_{t-1} + (1 - decay) * params_t
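Unrolling the recursion makes the weighting explicit. In this short derivation, ema_0 denotes the initial EMA value: each past params_k enters with weight (1 - decay) · decay^(t-k), so weights shrink geometrically with age and, over t steps, sum to 1 - decay^t.

$$
\mathrm{ema}_t = \mathrm{decay}^{t}\,\mathrm{ema}_0 + (1-\mathrm{decay})\sum_{k=1}^{t}\mathrm{decay}^{\,t-k}\,\mathrm{params}_k,
\qquad
(1-\mathrm{decay})\sum_{k=1}^{t}\mathrm{decay}^{\,t-k} = 1-\mathrm{decay}^{t}.
$$

With ema_0 = 0 the weights on params total only 1 - decay^t, which is why a zero-initialised EMA starts out too small.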
Higher decay means a slower update and a smoother trajectory. The EMA forgets old values on a timescale of about 1 / (1 - decay) steps, so decay = 0.9999 corresponds to roughly 10,000 steps.
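A small sketch for turning a decay value into step counts. The two helper names are hypothetical; the formulas are the standard ones: the rule-of-thumb window 1 / (1 - decay), and the half-life ln(0.5) / ln(decay), i.e. the number of steps after which an old value's weight has halved.

```python
import math


def ema_horizon(decay: float) -> float:
    """Rule-of-thumb averaging window in steps: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


def ema_half_life(decay: float) -> float:
    """Steps k at which an old value's weight has halved: decay**k == 0.5."""
    return math.log(0.5) / math.log(decay)


for d in (0.9, 0.99, 0.999, 0.9999):
    print(f"decay={d}: window ~{ema_horizon(d):.0f} steps, "
          f"half-life ~{ema_half_life(d):.0f} steps")
```

The half-life is always shorter than the window (about 0.69 of it for decay close to 1), so "timescale" is a rough notion rather than a hard cutoff.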
This problem is a pure-JAX warm-up to the EMA concept: a linear blend of two arrays. In nnx, the natural framing is a tiny module that holds the EMA in an nnx.Variable called ema_weights and overwrites it in an update method:
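```python
from flax import nnx


class EMAWrapper(nnx.Module):
    def __init__(self, init_params):
        # A plain nnx.Variable (not nnx.Param), so it is not picked up as trainable.
        self.ema_weights = nnx.Variable(init_params)

    def update(self, params, decay):
        # Blend in place: mostly the old EMA, a little of the new params.
        self.ema_weights.value = (
            decay * self.ema_weights.value + (1 - decay) * params
        )
        return self.ema_weights.value
```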
The pure-JAX function in this problem captures the math; later problems package it as a Module method.
Worked example
```python
import jax.numpy as jnp

ema = jnp.zeros((4,))
params = jnp.array([1.0, 2.0, 3.0, 4.0])
decay = 0.9
new_ema = decay * ema + (1 - decay) * params
# = 0.9 * 0 + 0.1 * params
# ≈ [0.1, 0.2, 0.3, 0.4]  (up to float32 rounding)
```
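Apply the step repeatedly with params held fixed and the EMA approaches params from below (here because it starts at zero and every entry of params is positive).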
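A minimal JAX sketch that iterates the worked example with params held fixed. Unrolling the update from ema = 0 gives the closed form (1 - decay**t) * params, and the loop checks every step against it.

```python
import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0, 4.0])
decay = 0.9
ema = jnp.zeros_like(params)

for t in range(1, 51):
    ema = decay * ema + (1 - decay) * params
    # Closed form for a constant target starting from zero: (1 - decay**t) * params
    assert jnp.allclose(ema, (1 - decay**t) * params, rtol=1e-5)

print(ema)  # close to params, but still slightly below every entry
```

After 50 steps the factor 1 - 0.9**50 is about 0.995, so the EMA has nearly caught up but never overshoots.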
Bias correction note
A common practical refinement is bias correction: divide the EMA by 1 - decay**t (where t is the number of updates so far), because starting from ema = 0 makes early estimates too small. Adam applies the same correction to its first- and second-moment estimates. We don't do it here (this problem is a plain EMA), but it's a one-line change.
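A minimal sketch of that one-line change, assuming a hypothetical helper ema_step_bias_corrected and a 1-based step counter t. The raw EMA is still what gets stored; the corrected value is only a read-out, which mirrors how Adam treats its moment estimates.

```python
import jax.numpy as jnp


def ema_step_bias_corrected(ema, params, decay, t):
    """Hypothetical helper: plain EMA update plus an Adam-style corrected read-out.

    t is the 1-based number of updates, including this one.
    """
    ema = decay * ema + (1.0 - decay) * params
    ema_hat = ema / (1.0 - decay**t)  # undo the zero-initialisation bias
    return ema, ema_hat


params = jnp.array([1.0, 2.0, 3.0, 4.0])
ema = jnp.zeros_like(params)
for t in range(1, 4):
    ema, ema_hat = ema_step_bias_corrected(ema, params, 0.9, t)

print(ema)      # about 0.271 * params: the raw EMA is still far too small
print(ema_hat)  # matches params (up to float error) for a constant target
```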
In nnx specifically
Once you put this update inside a Module, the line

```python
self.ema_weights.value = decay * self.ema_weights.value + (1 - decay) * params
```

is the entire training-time update. There is no mutable=[...] flag (as Flax Linen's apply requires for mutable collections) and no returned state tuple, just attribute mutation. (Inside nnx.jit, the Variable update is lifted out of the traced function via the same split/merge machinery from problem 12.)
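For contrast, a minimal pure-JAX sketch (no nnx) of the same bookkeeping: because jax.jit functions are pure, the EMA has to travel in and out as explicit state. train_step and its halving "update" are hypothetical stand-ins for a real optimizer step.

```python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, ema, decay):
    # Hypothetical placeholder for an optimizer update: halve the params.
    new_params = 0.5 * params
    # In pure JAX the EMA is ordinary state, so it must be returned explicitly.
    new_ema = decay * ema + (1.0 - decay) * new_params
    return new_params, new_ema


params = jnp.ones(3)
ema = params  # a common choice: start the EMA at the initial weights
for _ in range(3):
    params, ema = train_step(params, ema, 0.9)
```

The returned (params, ema) tuple is exactly what the nnx Variable mutation spares you from threading by hand.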
Common pitfalls
- Swapping the blend weights. The formula is decay * ema + (1 - decay) * params, not the reverse; higher decay puts more weight on the old value.
- decay = 1.0. The EMA never updates. Avoid it (use 0.999 or 0.9999 instead).
- decay = 0.0. The EMA just equals the current params, so there is no smoothing.
- Casting decay. It normally arrives as a float already; float(decay) makes that explicit.
- Treating the EMA as a parameter. It is not trainable, and gradients shouldn't flow through it. In nnx, wrap it as nnx.Variable, not nnx.Param, so that gradient and optimizer code filtering on nnx.Param leaves it alone.
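In pure JAX there is no Variable type to filter on; the usual guard is jax.lax.stop_gradient. A minimal sketch, assuming a hypothetical BYOL-style consistency_loss that pulls params toward a frozen EMA target:

```python
import jax
import jax.numpy as jnp


def consistency_loss(params, ema):
    # Hypothetical objective: treat the EMA as a fixed target, not a trainable input.
    target = jax.lax.stop_gradient(ema)
    return jnp.sum((params - target) ** 2)


params = jnp.array([1.0, 2.0])
ema = jnp.array([0.5, 0.5])
g_params, g_ema = jax.grad(consistency_loss, argnums=(0, 1))(params, ema)
print(g_params)  # 2 * (params - ema)
print(g_ema)     # all zeros: no gradient reaches the EMA
```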
Problem
Write ema_step(seed, params_arr, ema_arr, decay):
Compute and return decay * ema_arr + (1.0 - decay) * params_arr.
seed is unused (kept in the signature so the harness shape matches the
rest of the track). Cast decay with float(decay).
Inputs:
- seed: int (passed as a float; unused).
- params_arr: 1-D JAX array of current parameter values.
- ema_arr: 1-D JAX array holding the current EMA estimate (same shape as params_arr).
- decay: float in [0, 1), typically 0.99, 0.999, or higher.

Output: 1-D array of the same shape as params_arr.
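A minimal sketch that follows the spec above, written from the problem statement alone (it is not the site's reference solution):

```python
import jax.numpy as jnp


def ema_step(seed, params_arr, ema_arr, decay):
    del seed  # unused; kept so the signature matches the rest of the track
    decay = float(decay)
    # Mostly the old EMA, plus a (1 - decay) share of the current params.
    return decay * ema_arr + (1.0 - decay) * params_arr


new_ema = ema_step(0.0, jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.zeros(4), 0.9)
```

Called with the worked example's inputs, it reproduces new_ema ≈ [0.1, 0.2, 0.3, 0.4].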