NNX MultiMetric
Why this matters
Real evaluation reports averaged metrics over MANY batches, not just one. With 1000 eval samples and a batch size of 32 you have ~32 batches; you want one final MSE / accuracy / loss number that averages all of them.
Hand-rolling this is annoying: keep a running sum, keep a count, divide at the end, reset between epochs. It is easy to get wrong (off-by-one count, forgetting to reset, mixing per-sample and per-batch averages).
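To make that bookkeeping concrete, here is a minimal hand-rolled sketch in plain Python. The class name RunningAverage is hypothetical and not part of Flax; it only illustrates the sum / count / divide / reset cycle that a metric container takes off your hands.

```python
class RunningAverage:
    """Hypothetical hand-rolled running average (illustration only, not a Flax class)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Accumulate one value, e.g. one batch-level loss.
        self.total += float(value)
        self.count += 1

    def compute(self):
        # Divide only at the end; guard against an empty accumulator.
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count

    def reset(self):
        # Call between epochs, or old values leak into the next average.
        self.total = 0.0
        self.count = 0


avg = RunningAverage()
for batch_loss in [0.5, 0.25, 0.75]:
    avg.update(batch_loss)
epoch_mean = avg.compute()  # mean of the three batch losses
avg.reset()
```

Every metric you track needs its own copy of this state, which is where the mistakes listed above tend to creep in.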
nnx.MultiMetric is the canonical container. It bundles N named
metrics, each tracks its own running average, and one .compute()
call returns a dict with all of them.
The recipe
import jax.numpy as jnp
from flax import nnx

# model and batches are assumed to be defined elsewhere.
metrics = nnx.MultiMetric(
    mse=nnx.metrics.Average('mse'),
    mae=nnx.metrics.Average('mae'),
)
for x_batch, y_batch in batches:
    pred = model(x_batch)
    mse = jnp.mean((pred - y_batch) ** 2)
    mae = jnp.mean(jnp.abs(pred - y_batch))
    metrics.update(mse=mse, mae=mae)
result = metrics.compute()  # {"mse": ..., "mae": ...}
metrics.reset()  # Zero the buffers between epochs.
The argname gotcha
nnx.metrics.Average() takes an optional positional argname —
the keyword name update will look up. The default is 'values',
so Average() only accepts metrics.update(values=...). To use
custom names like mse= and mae=, you must pass the argname
explicitly:
nnx.metrics.Average('mse') # update(mse=...)
nnx.metrics.Average('mae') # update(mae=...)
Forget this and you get TypeError: Expected keyword argument 'values'.
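The lookup can be pictured with a simplified stand-in written in plain Python. AverageSketch is a hypothetical class, not the Flax implementation; it only mirrors the behavior described above (read the keyword named by argname, raise TypeError if it is missing), and the real class may differ in its details.

```python
class AverageSketch:
    """Simplified stand-in for the argname lookup (hypothetical, not Flax code)."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.total = 0.0
        self.count = 0

    def update(self, **kwargs):
        # Only the keyword matching argname is read; other keywords are ignored.
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        self.total += float(kwargs[self.argname])
        self.count += 1

    def compute(self):
        return self.total / self.count


named = AverageSketch("mse")
named.update(mse=0.5, mae=0.25)  # picks out mse, ignores mae
named.update(mse=1.5, mae=0.75)
named_mean = named.compute()

default = AverageSketch()  # argname defaults to 'values'
try:
    default.update(mse=0.5)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Under this reading, each metric selects its own keyword, which would explain why a single update(mse=..., mae=...) call can feed several named metrics at once.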
Why a Module-style metric?
MultiMetric is a Module-style nnx object: its running sums and counts
are nnx.Variables stored on the metric itself. That means the same
“wrapper provides apparent mutation” trick from BatchNorm applies:
metrics.update(...) updates the Variable values in place, and under
nnx.jit and nnx.scan the metric's state is carried through like that
of any other nnx object.
Common pitfalls
- Average() with the default 'values'. Then your update kwarg must be values=... (pass Average('your_name') to alias it).
- Per-sample vs per-batch averaging. Average averages whatever you pass to update. If you pass jnp.mean(losses_in_batch), you’re averaging per-batch values: fine if all batches have the same size, slightly biased otherwise.
- Forgetting to reset between epochs. Running averages accumulate across all update calls until reset().
- Extracting result["mse"] as a Python float without float(...). It’s a JAX array. Either keep it as a jnp.array or cast.
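The per-sample vs per-batch pitfall is easy to check with arithmetic. The sketch below uses made-up loss values and plain Python lists (no Flax involved) to show how the two averages diverge once batch sizes differ.

```python
# Two batches of per-sample losses with different sizes (made-up numbers).
batch_a = [1.0, 1.0, 1.0, 1.0]  # 4 samples, batch mean 1.0
batch_b = [3.0, 5.0]            # 2 samples, batch mean 4.0


def mean(xs):
    return sum(xs) / len(xs)


# Averaging the batch means weights each batch equally.
per_batch_avg = mean([mean(batch_a), mean(batch_b)])  # 2.5

# Averaging over all samples weights each sample equally.
per_sample_avg = mean(batch_a + batch_b)  # 2.0
```

The small batch pulls the per-batch average upward because its two samples count as much as the other batch's four. When every batch has the same size, the two numbers coincide.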
Problem
Implement multimetric_aggregate(seed, x_batches_flat, y_batches_flat, batch_size, lr):
- Build model = nnx.Linear(D_in, D_out, rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param). Train the model for 2 SGD steps on the full flat batch with MSE loss (so the model has learned a little before we measure).
- Construct metrics = nnx.MultiMetric(mse=nnx.metrics.Average('mse'), mae=nnx.metrics.Average('mae')).
- Slice x_batches_flat (and the matching rows of y_batches_flat) into num_batches = N // batch_size chunks of size batch_size. For each chunk, compute MSE and MAE on the model’s prediction and call metrics.update(mse=..., mae=...).
- Compute result = metrics.compute(). Return jnp.array([float(result["mse"]), float(result["mae"])]).
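The slicing step can be sketched on its own, without touching the model or the metrics. slice_batches is a hypothetical helper name, and nested lists stand in for the 2-D arrays; the same slice expression should apply along the first axis of an array, where len(rows) corresponds to N.

```python
def slice_batches(rows, batch_size):
    """Split rows into N // batch_size consecutive chunks of batch_size rows."""
    num_batches = len(rows) // batch_size
    return [rows[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]


# Six rows of two features each, stored as nested lists for illustration.
x_rows = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]
chunks = slice_batches(x_rows, batch_size=2)
```

Applying the same indices to the targets keeps each input chunk aligned with its labels.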
Inputs:
- seed: float (cast to int).
- x_batches_flat: 2-D (N, D_in). Slice into batches.
- y_batches_flat: 2-D (N, D_out).
- batch_size: float (cast to int). Divides N exactly.
- lr: float.
Output: 1-D (2,) — [avg_mse, avg_mae].