
NNX Running Average / EMA Step

Why this matters

Exponential Moving Averages of model parameters are used everywhere:

  • Polyak averaging for SGD-trained models: often improves test accuracy at almost no extra cost, with no architectural change.
  • Diffusion models keep an EMA of the network weights for sampling, typically with decay = 0.999 or 0.9999.
  • Self-supervised methods (BYOL, MoCo, DINO) drive a target network via EMA of the online network, often decay = 0.996+.
  • Off-policy RL methods (DDPG, TD3, SAC) update their target networks via EMA ("soft" or Polyak updates) for stable bootstrapping.

The update rule is one line:

ema_t = decay * ema_{t-1} + (1 - decay) * params_t

Higher decay → slower update, smoother trajectory. The EMA forgets old values on a timescale of roughly 1 / (1 - decay) steps, so decay = 0.9999 corresponds to ~10,000 steps.
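To make the timescale concrete, the sketch below (plain Python, with illustrative names that are not part of the problem) computes the weight a value still carries in the EMA after n updates, which is decay**n. At n = 1 / (1 - decay) that weight has fallen to roughly 1/e.

```python
import math

def residual_weight(decay: float, n_steps: int) -> float:
    """Weight that a value from n_steps ago still has in the EMA."""
    # Each update multiplies the old contribution by decay once.
    return decay ** n_steps

decay = 0.9999
timescale = round(1.0 / (1.0 - decay))  # about 10,000 steps
weight = residual_weight(decay, timescale)

# The residual weight and 1/e should be close to each other.
print(timescale, weight, math.exp(-1.0))
```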

This problem is a pure-JAX warm-up to the EMA concept: the linear blend of two arrays. In nnx, the natural framing is a tiny module that holds an ema_weights nnx.Variable and updates it in an update method:

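from flax import nnx

class EMAWrapper(nnx.Module):
    def __init__(self, init_params):
        # nnx.Variable, not nnx.Param: the EMA is state, not a trainable weight.
        self.ema_weights = nnx.Variable(init_params)

    def update(self, params, decay):
        # Blend the old EMA with the current params and store the result.
        self.ema_weights.value = (
            decay * self.ema_weights.value + (1 - decay) * params
        )
        return self.ema_weights.value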

The pure-JAX function in this problem captures the math; later problems package it as a Module method.

Worked example

import jax.numpy as jnp

ema = jnp.zeros((4,))
params = jnp.array([1.0, 2.0, 3.0, 4.0])
decay = 0.9
new_ema = decay * ema + (1 - decay) * params
# = 0.9 * 0 + 0.1 * params
# ≈ [0.1, 0.2, 0.3, 0.4]  (up to float32 rounding)

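Apply repeatedly with params held fixed and the EMA approaches params geometrically: starting from ema = 0, after t steps it equals (1 - decay**t) * params, so for these positive values it approaches from below.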
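The convergence behaviour can be checked numerically. The sketch below reuses the values from the worked example, holds params fixed (an assumption made for illustration; in training they would change every step), and compares 50 updates against the closed form (1 - decay**t) * params for a zero-initialised EMA.

```python
import jax.numpy as jnp

params = jnp.array([1.0, 2.0, 3.0, 4.0])
decay = 0.9
n_steps = 50

ema = jnp.zeros_like(params)
for _ in range(n_steps):
    # Same one-line blend as in the worked example, applied repeatedly.
    ema = decay * ema + (1 - decay) * params

# Closed form when params is constant and the EMA starts at zero.
expected = (1 - decay ** n_steps) * params

print(ema)
print(expected)
```

The remaining gap to params shrinks by a factor of decay at every step, which is why a decay close to 1 needs many more steps to catch up.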

Bias correction note

A common practical refinement is bias-correcting the EMA by dividing it by 1 - decay**t, where t is the number of updates so far, since starting from ema = 0 underestimates early on. Adam does this for its first- and second-moment estimates. We don’t do it here (this is the plain EMA), but it’s a one-line modification.
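As a rough illustration of that one-line modification, the sketch below divides a zero-initialised EMA by 1 - decay**t. The helper name bias_corrected and the constant params are assumptions made for the example, not part of the problem.

```python
import jax.numpy as jnp

def bias_corrected(ema, decay, t):
    """Divide out the (1 - decay**t) shrinkage of a zero-initialised EMA.

    t is the number of updates applied so far (t >= 1).
    """
    return ema / (1.0 - decay ** t)

params = jnp.array([1.0, 2.0, 3.0, 4.0])
decay = 0.9

ema = jnp.zeros_like(params)
for t in range(1, 4):
    ema = decay * ema + (1 - decay) * params
    corrected = bias_corrected(ema, decay, t)
    # With constant params, the corrected value recovers params at every step.
    print(t, ema, corrected)
```

With changing params the correction no longer recovers any single value exactly; it only removes the systematic shrinkage toward the zero initialisation.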

In nnx specifically

Once you put this update inside a Module:

self.ema_weights.value = decay * self.ema_weights.value + (1 - decay) * params

is the entire training-time update. No mutable=[...] flag (as Flax Linen's apply would need), no return tuple, just attribute mutation. (Inside nnx.jit, the Variable update is lifted via the same split/merge machinery from problem 12.)

Common pitfalls

  • Wrong sign of the blend. The formula is decay * ema + (1 - decay) * params, not the reverse. Higher decay means more weight on the old value.
  • decay = 1.0. The EMA never updates. Avoid (use 0.999, 0.9999).
  • decay = 0.0. EMA equals current params; no smoothing.
  • Casting decay. It arrives as a float — use float(decay) to be explicit.
  • Treating EMA as a parameter. It’s NOT trainable; gradients shouldn’t flow through it. In nnx, wrap as nnx.Variable, not nnx.Param.
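The last pitfall can be illustrated in pure JAX. The sketch below is an illustration under stated assumptions, not part of the problem: it uses jax.lax.stop_gradient (a different mechanism from the nnx.Variable versus nnx.Param distinction) and a made-up squared-error loss between hypothetical online and ema arrays, to show the gradient with respect to the EMA coming out as zero.

```python
import jax
import jax.numpy as jnp

def loss_fn(online, ema):
    # Treat the EMA as a fixed target: no gradient flows into it.
    target = jax.lax.stop_gradient(ema)
    return jnp.sum((online - target) ** 2)

online = jnp.array([1.0, 2.0, 3.0, 4.0])
ema = jnp.array([0.1, 0.2, 0.3, 0.4])

grad_online, grad_ema = jax.grad(loss_fn, argnums=(0, 1))(online, ema)
print(grad_online)  # 2 * (online - ema)
print(grad_ema)     # all zeros
```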

Problem

Write ema_step(seed, params_arr, ema_arr, decay):

Compute and return decay * ema_arr + (1.0 - decay) * params_arr.

seed is unused (kept in the signature so the harness shape matches the rest of the track). Cast decay with float(decay).

Inputs:

  • seed: int (passed as float; unused).
  • params_arr: 1-D JAX array — current parameter values.
  • ema_arr: 1-D JAX array — current EMA estimate (same shape as params_arr).
  • decay: float in [0, 1) — typically 0.99, 0.999, or higher.

Output: 1-D array of same shape as params_arr.
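One possible sketch of a function matching this specification, checked against the worked example. It assumes decay arrives as a concrete Python number (float(decay) would fail on a value traced under jit); the sample inputs are illustrative.

```python
import jax.numpy as jnp

def ema_step(seed, params_arr, ema_arr, decay):
    # seed is accepted only to match the described signature; it is unused.
    decay = float(decay)  # assumes a concrete value, not a traced one
    return decay * ema_arr + (1.0 - decay) * params_arr

params_arr = jnp.array([1.0, 2.0, 3.0, 4.0])
ema_arr = jnp.zeros((4,))

out = ema_step(0.0, params_arr, ema_arr, 0.9)
# decay = 0 means no smoothing: the result equals params_arr.
no_smoothing = ema_step(0.0, params_arr, ema_arr, 0.0)

print(out.shape, out)
print(no_smoothing)
```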

