Sharded Eval Loss
Why this matters
A real eval set is often too large to run as a single batch without exhausting device memory (OOM): think billions of tokens or millions of images. The fix is to chunk the batch into shards, evaluate each shard, and average the per-shard losses.
This is also the conceptual core of pmap-based distributed eval:
each device gets one shard, runs forward, you psum and divide. The
single-host loop version below builds the right intuition.
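As a rough illustration of the distributed version described above, here is a minimal pmap sketch. It assumes a hand-written linear model (`w` is a hypothetical stand-in for real params), one equal-sized shard per local device, and it also runs on a single device, where the mapped axis has length 1.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()   # one shard per local device
per_dev = 4                        # examples per shard (illustrative value)
D = 3                              # feature dimension (illustrative value)

kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n_dev * per_dev, D))
y = jax.random.normal(ky, (n_dev * per_dev,))
w = jax.random.normal(kw, (D, 1))  # hypothetical stand-in for model params


# in_axes=(None, 0, 0): broadcast w to every device, map over the shard axis.
@partial(jax.pmap, axis_name="shards", in_axes=(None, 0, 0))
def device_loss(w, xs, ys):
    preds = (xs @ w).reshape(-1)
    local = jnp.mean((preds - ys) ** 2)
    # psum the per-device means, then divide by the number of shards.
    return jax.lax.psum(local, axis_name="shards") / n_dev


per_device = device_loss(w, x.reshape(n_dev, per_dev, D), y.reshape(n_dev, per_dev))
global_loss = per_device[0]        # every device holds the same reduced value

# Reference: the un-sharded mean over the full batch.
full_loss = jnp.mean(((x @ w).reshape(-1) - y) ** 2)
```

Because every shard has the same size, the psum-and-divide result should match `full_loss` up to floating-point error.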
The pattern
x_shards = jnp.split(x, num_shards, axis=0)
y_shards = jnp.split(y, num_shards, axis=0)
total = 0.0
for xs, ys in zip(x_shards, y_shards):
    preds = model.apply(params, xs).reshape(-1)
    shard_loss = jnp.mean((preds - ys) ** 2)
    total = total + shard_loss
avg = total / num_shards
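To make the pattern concrete, the sketch below wraps it in a function and checks it against the un-sharded loss. The `apply` function and the `params` dict are hypothetical stand-ins for `model.apply` and real Flax params, chosen so the example runs with JAX alone.

```python
import jax
import jax.numpy as jnp


def apply(params, xs):
    # Hypothetical stand-in for model.apply: one linear layer, output (n, 1).
    return xs @ params["kernel"] + params["bias"]


def sharded_mse(params, x, y, num_shards):
    # jnp.split requires x.shape[0] to be divisible by num_shards.
    x_shards = jnp.split(x, num_shards, axis=0)
    y_shards = jnp.split(y, num_shards, axis=0)
    total = 0.0
    for xs, ys in zip(x_shards, y_shards):
        preds = apply(params, xs).reshape(-1)
        total = total + jnp.mean((preds - ys) ** 2)
    return total / num_shards


kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (12, 3))
y = jax.random.normal(ky, (12,))
params = {"kernel": jax.random.normal(kw, (3, 1)), "bias": jnp.zeros((1,))}

sharded = sharded_mse(params, x, y, num_shards=4)
full = jnp.mean((apply(params, x).reshape(-1) - y) ** 2)
```

With equal-sized shards, `sharded` and `full` should agree up to floating-point error.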
Note: dividing by num_shards only gives the right global mean
when shards are equal-sized (which jnp.split enforces — it
raises if the batch doesn’t divide evenly). If you instead used
jnp.array_split (allows unequal shards), you’d weight each
shard’s loss by len(xs) / N to match the un-sharded mean.
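A minimal sketch of the weighted variant from the note, assuming the same hypothetical linear `apply` stand-in for `model.apply`. Here N = 10 does not divide evenly into 4 shards, so `jnp.array_split` produces shards of unequal size.

```python
import jax
import jax.numpy as jnp


def apply(params, xs):
    # Hypothetical stand-in for model.apply: one linear layer, output (n, 1).
    return xs @ params["kernel"] + params["bias"]


def weighted_sharded_mse(params, x, y, num_shards):
    n = x.shape[0]
    x_shards = jnp.array_split(x, num_shards, axis=0)  # unequal shards allowed
    y_shards = jnp.array_split(y, num_shards, axis=0)
    total = 0.0
    for xs, ys in zip(x_shards, y_shards):
        preds = apply(params, xs).reshape(-1)
        shard_loss = jnp.mean((preds - ys) ** 2)
        # Weight each shard's mean by its share of the full batch.
        total = total + shard_loss * (xs.shape[0] / n)
    return total


kx, ky, kw = jax.random.split(jax.random.PRNGKey(1), 3)
x = jax.random.normal(kx, (10, 3))
y = jax.random.normal(ky, (10,))
params = {"kernel": jax.random.normal(kw, (3, 1)), "bias": jnp.zeros((1,))}

weighted = weighted_sharded_mse(params, x, y, num_shards=4)
full = jnp.mean((apply(params, x).reshape(-1) - y) ** 2)
```

The weights sum to 1, so the weighted sum reproduces the un-sharded mean even though the shards differ in size.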
Why split, not vmap?
jax.vmap maps a single function over a leading axis. If you stack the
equal-sized shards into one array with a leading shard axis, it works great here too:
@jax.vmap
def shard_loss(xs, ys):
    preds = model.apply(params, xs).reshape(-1)
    return jnp.mean((preds - ys) ** 2)

losses = shard_loss(jnp.stack(x_shards), jnp.stack(y_shards))
avg = losses.mean()
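The vmap version compiles to ONE XLA program that processes every shard
at once. The loop version runs shard by shard: a JIT’d per-shard function
is compiled once (all shards share a shape) and then called num_shards
times. For large num_shards and JIT’d code, vmap is typically faster.
For OOM avoidance with HUGE per-shard work, the Python loop is better:
vmap over the stacked shards keeps roughly the full batch in memory at
once, while the loop lets each shard finish and free memory before the
next starts.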
Either approach is fine; the loop version is conceptually closer to
the distributed-eval pattern (one shard per device, then psum).
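The two versions can be compared side by side. This is a sketch under assumptions: `apply` is a hypothetical linear stand-in for `model.apply`, and `params` is passed explicitly with `in_axes=(None, 0, 0)` rather than captured from the enclosing scope.

```python
import jax
import jax.numpy as jnp


def apply(params, xs):
    # Hypothetical stand-in for model.apply: one linear layer, output (n, 1).
    return xs @ params["kernel"] + params["bias"]


def shard_loss(params, xs, ys):
    preds = apply(params, xs).reshape(-1)
    return jnp.mean((preds - ys) ** 2)


num_shards = 4
kx, ky, kw = jax.random.split(jax.random.PRNGKey(2), 3)
x = jax.random.normal(kx, (16, 3))
y = jax.random.normal(ky, (16,))
params = {"kernel": jax.random.normal(kw, (3, 1)), "bias": jnp.zeros((1,))}

x_shards = jnp.split(x, num_shards, axis=0)
y_shards = jnp.split(y, num_shards, axis=0)

# Loop version: one shard at a time.
loop_avg = sum(shard_loss(params, xs, ys) for xs, ys in zip(x_shards, y_shards)) / num_shards

# vmap version: share params (None), map over the stacked shard axis (0).
batched_loss = jax.vmap(shard_loss, in_axes=(None, 0, 0))
losses = batched_loss(params, jnp.stack(x_shards), jnp.stack(y_shards))
vmap_avg = losses.mean()
```

Both paths compute the same per-shard means, so `loop_avg` and `vmap_avg` should agree up to floating-point error; they differ in how much of the batch is in flight at once.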
init must NOT see the full batch
Common bug: calling model.init(rng, x) on the full unsharded x
consumes memory proportional to the full batch — the very thing
we’re trying to avoid! Init on a single example or a single shard:
params = model.init(rng, x[:1])  # batch size doesn't affect param shapes; only the feature dim D does
The init is one forward pass purely for variable allocation; the smallest example that’s representative will do.
Common pitfalls
- Indivisible batch + jnp.split: jnp.split raises if the batch doesn’t divide evenly. Trim the batch or use array_split.
- Mean of means != global mean when shards are unequal: weighted sum by len(xs), not unweighted average.
- Re-initializing params per shard: don’t! Init ONCE; apply with the same params to each shard.
- JIT’ing the Python loop: usually you want jax.jit on the INSIDE (per-shard call), not on the loop itself, so each shard finishes and frees memory.
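The last pitfall can be sketched as follows: `jax.jit` wraps only the per-shard loss, and the Python loop stays outside it. The `apply` function is a hypothetical linear stand-in for `model.apply`, and pulling each result to the host with `float()` is one possible way to make sure a shard has finished before the next one starts.

```python
import jax
import jax.numpy as jnp


def apply(params, xs):
    # Hypothetical stand-in for model.apply: one linear layer, output (n, 1).
    return xs @ params["kernel"] + params["bias"]


@jax.jit
def shard_loss(params, xs, ys):
    # jit on the INSIDE: shards share a shape, so this is traced once and reused.
    preds = apply(params, xs).reshape(-1)
    return jnp.mean((preds - ys) ** 2)


def eval_loss(params, x, y, num_shards):
    x_shards = jnp.split(x, num_shards, axis=0)
    y_shards = jnp.split(y, num_shards, axis=0)
    total = 0.0
    for xs, ys in zip(x_shards, y_shards):
        # float() copies the scalar to the host, which waits for this shard.
        total += float(shard_loss(params, xs, ys))
    return total / num_shards


kx, ky, kw = jax.random.split(jax.random.PRNGKey(3), 3)
x = jax.random.normal(kx, (16, 3))
y = jax.random.normal(ky, (16,))
params = {"kernel": jax.random.normal(kw, (3, 1)), "bias": jnp.zeros((1,))}

avg = eval_loss(params, x, y, num_shards=4)
full = float(jnp.mean((apply(params, x).reshape(-1) - y) ** 2))
```

Params are created once and passed unchanged to every shard call, which also covers the "init ONCE" pitfall.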
Problem
- Build a tiny nn.Dense(1) model.
- Init params from seed and a single example slice.
- Split x: (N, D) and y: (N,) into num_shards along axis 0 (N is divisible by num_shards).
- For each shard, compute MSE between Dense(x) and y.
- Return the mean of the per-shard losses as a 1-D (1,) array.
Inputs:
- seed: float (cast to int).
- x: 2-D (N, D).
- y: 1-D (N,).
- num_shards: float (cast to int), divides N exactly.
Output: 1-D (1,) — [average_loss].