medium primitives

NNX MultiMetric

Why this matters

Real evaluations report metrics averaged over many batches, not just one. With 1000 eval samples and a batch size of 32 you have 32 batches (31 full ones plus a final batch of 8); you want one final MSE / accuracy / loss number that averages over all of them.

Hand-rolling this is annoying: keep a running sum, keep a count, divide at the end, reset between epochs. It is easy to get wrong (off-by-one count, forgetting to reset, mixing per-sample and per-batch averages).
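To make the bookkeeping concrete, here is a minimal plain-Python sketch of the hand-rolled version. RunningAverage is a hypothetical helper written for illustration, not part of Flax; it shows the running sum, the count, the final divide, and the reset that a metrics container would otherwise handle for you.

```python
class RunningAverage:
    """Hypothetical hand-rolled running average (illustration only, not a Flax class)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # One call per batch: accumulate the batch-level value.
        self.total += float(value)
        self.count += 1

    def compute(self):
        # Divide only at the end; guard against an empty accumulator.
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count

    def reset(self):
        # Must be called between epochs, or old batches leak into the new average.
        self.total = 0.0
        self.count = 0


avg = RunningAverage()
for batch_loss in [0.5, 0.25, 0.75]:
    avg.update(batch_loss)
epoch_1 = avg.compute()

avg.reset()
avg.update(1.0)
epoch_2 = avg.compute()
```

Skipping the reset() call would make epoch_2 an average over all four updates instead of just the last one, which is exactly the kind of silent error described above.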

nnx.MultiMetric is the canonical container. It bundles N named metrics, each tracking its own running state (a running average in this recipe), and one .compute() call returns a dict with all of them.

The recipe

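import jax.numpy as jnp
from flax import nnx

# One Average per metric; the string is the keyword that update() will look up.
metrics = nnx.MultiMetric(
    mse=nnx.metrics.Average('mse'),
    mae=nnx.metrics.Average('mae'),
)

# `model` and `batches` are assumed to be defined elsewhere.
for x_batch, y_batch in batches:
    pred = model(x_batch)
    mse = jnp.mean((pred - y_batch) ** 2)
    mae = jnp.mean(jnp.abs(pred - y_batch))
    metrics.update(mse=mse, mae=mae)   # One scalar per batch for each metric.

result = metrics.compute()        # {"mse": ..., "mae": ...}
metrics.reset()                   # Zero the buffers between epochs.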

The argname gotcha

nnx.metrics.Average() takes an optional argname: the keyword that update will look up. The default is 'values', so a bare Average() only accepts update(values=...). To use custom names like mse= and mae=, you must pass the argname explicitly:

nnx.metrics.Average('mse')   # update(mse=...)
nnx.metrics.Average('mae')   # update(mae=...)

Forget this and you get TypeError: Expected keyword argument 'values'.
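The lookup can be pictured with a small plain-Python stand-in. AverageStandIn below is hypothetical and is not the Flax implementation; it only mimics the behavior described above, namely that each metric pulls its own keyword out of the update call and raises TypeError when that keyword is missing.

```python
class AverageStandIn:
    """Hypothetical stand-in that mimics the argname lookup (not the Flax source)."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.total = 0.0
        self.count = 0

    def update(self, **kwargs):
        # Each metric looks for exactly one keyword: its own argname.
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        self.total += float(kwargs[self.argname])
        self.count += 1

    def compute(self):
        return self.total / self.count


named = AverageStandIn("mse")
named.update(mse=0.25, mae=0.5)   # The stand-in ignores extra keywords; 'mse' is found.

default = AverageStandIn()        # argname defaults to 'values'
try:
    default.update(mse=0.25)      # No 'values' keyword, so this fails.
    error_message = None
except TypeError as err:
    error_message = str(err)
```

In this sketch the metric named "mse" happily receives both keywords and keeps only its own, while the default-named one rejects the call because nothing was passed as values=.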

Why a Module-style metric?

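MultiMetric is a stateful NNX object in the same style as a Module: its running sums and counts are nnx.Variables stored on the metric object. That means the same “wrapper provides apparent mutation” trick from BatchNorm applies: metrics.update(...) updates the Variable values in place, and under nnx.jit and nnx.scan the metric state is carried through the transform as a pytree of arrays, the same way model parameters are (a whole pytree, not a single leaf).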

Common pitfalls

  • Average() with the default 'values'. Then your update kwarg must be values=.... Pass Average('your_name') to alias it.
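  • Per-sample vs per-batch averaging. Average averages whatever you pass to update. If you pass jnp.mean(losses_in_batch), you are averaging per-batch values. That is fine if all batches have the same size, but biased otherwise, because a small final batch gets the same weight as a full one. Passing the unreduced per-sample values instead weights every sample equally.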
  • Forgetting to reset between epochs. Running averages accumulate across all update calls until reset().
  • Treating result["mse"] as a Python float. It is a JAX array (a scalar one); either keep it as an array or cast it with float(...).
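The per-sample versus per-batch point is easy to check by hand. The sketch below uses plain Python and made-up squared errors (both the numbers and the batch split are assumptions for illustration): one batch of four samples and one leftover batch holding a single sample.

```python
# Made-up per-sample squared errors, split into two batches of unequal size.
batch_a = [1.0, 1.0, 1.0, 1.0]   # 4 samples
batch_b = [6.0]                  # 1 leftover sample
batches = [batch_a, batch_b]

# Per-batch averaging: reduce each batch to its mean, then average the means.
batch_means = [sum(b) / len(b) for b in batches]
per_batch_avg = sum(batch_means) / len(batch_means)

# Per-sample averaging: every sample carries equal weight.
all_samples = [v for b in batches for v in b]
per_sample_avg = sum(all_samples) / len(all_samples)

print(per_batch_avg, per_sample_avg)   # 3.5 2.0
```

The single leftover sample counts as much as the whole four-sample batch in the per-batch number, which is where the bias comes from. With equal-sized batches the two numbers coincide.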

Problem

Implement multimetric_aggregate(seed, x_batches_flat, y_batches_flat, batch_size, lr):

  1. Build model = nnx.Linear(D_in, D_out, rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param). Train the model for 2 SGD steps on the full flat batch with MSE loss (so the model has learned a little before we measure).
  2. Construct metrics = nnx.MultiMetric(mse=nnx.metrics.Average('mse'), mae=nnx.metrics.Average('mae')).
  3. Slice x_batches_flat and y_batches_flat into num_batches = N // batch_size consecutive chunks of batch_size rows each. For each chunk, compute MSE and MAE between the model’s prediction and the matching targets, then call metrics.update(mse=..., mae=...).
  4. result = metrics.compute(). Return jnp.array([float(result["mse"]), float(result["mae"])]).

Inputs:

  • seed: float (cast to int).
  • x_batches_flat: 2-D (N, D_in). Slice into batches.
  • y_batches_flat: 2-D (N, D_out).
  • batch_size: float (cast to int). Divides N exactly.
  • lr: float.

Output: 1-D array of shape (2,): [avg_mse, avg_mae].
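The slicing in step 3 can be sketched independently of the model. batch_slices is a hypothetical helper name, not something the problem defines; it casts batch_size to int as the input description requires and relies on batch_size dividing N exactly. The returned slice objects index the leading axis, so the same idea should carry over from the Python list used here to the rows of a 2-D array.

```python
def batch_slices(n, batch_size):
    """Return one slice per batch over the leading axis (hypothetical helper)."""
    batch_size = int(batch_size)        # batch_size arrives as a float
    num_batches = n // batch_size       # exact, since batch_size divides n
    return [
        slice(i * batch_size, (i + 1) * batch_size)
        for i in range(num_batches)
    ]


# Stand-in "rows" so the sketch runs without any array library.
rows = ["r0", "r1", "r2", "r3", "r4", "r5"]
chunks = [rows[s] for s in batch_slices(len(rows), 2.0)]
```

Applying the same list of slices to both the inputs and the targets keeps each x chunk aligned with its y chunk.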

