
Sharded Eval Loss

Why this matters

A real eval set is often big enough that running it as a single batch OOMs the device: billions of tokens, millions of images. The fix is to split the batch into shards, evaluate each shard separately, and average the per-shard losses.

This is also the conceptual core of pmap-based distributed eval: each device gets one shard and runs the forward pass, then you psum the per-device losses and divide by the device count. The single-host loop version below builds the same intuition.

The pattern

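import jax.numpy as jnp

# Assumes model, params, x (N, D), y (N,) and num_shards are defined,
# with N divisible by num_shards.
x_shards = jnp.split(x, num_shards, axis=0)
y_shards = jnp.split(y, num_shards, axis=0)

total = 0.0
for xs, ys in zip(x_shards, y_shards):
    preds = model.apply(params, xs).reshape(-1)   # (shard_size, 1) -> (shard_size,)
    shard_loss = jnp.mean((preds - ys) ** 2)      # per-shard MSE
    total = total + shard_loss
avg = total / num_shards                          # valid because shards are equal-sized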

Note: dividing by num_shards only gives the right global mean when shards are equal-sized (which jnp.split enforces — it raises if the batch doesn’t divide evenly). If you instead used jnp.array_split (allows unequal shards), you’d weight each shard’s loss by len(xs) / N to match the un-sharded mean.
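To make the weighting concrete, here is a small check with hand-picked, illustrative numbers (not from the source). It skips the model and works directly on predictions, using jnp.array_split to produce unequal shards of sizes 4, 3 and 3.

```python
import jax.numpy as jnp

# Illustrative data: predictions 0..9 against all-zero targets.
preds = jnp.arange(10.0)
y = jnp.zeros(10)
N = preds.shape[0]

global_mse = jnp.mean((preds - y) ** 2)  # 28.5

# array_split allows unequal shards: sizes 4, 3, 3 here.
p_shards = jnp.array_split(preds, 3)
y_shards = jnp.array_split(y, 3)
shard_mses = [jnp.mean((p - t) ** 2) for p, t in zip(p_shards, y_shards)]

# Unweighted mean of means: an example in a 3-element shard
# counts for more than one in the 4-element shard.
unweighted = sum(shard_mses) / len(shard_mses)  # about 31.28

# Weighting each shard's MSE by len(xs) / N recovers the global mean.
weighted = sum(len(p) / N * m for p, m in zip(p_shards, shard_mses))  # 28.5
```

With equal-sized shards every weight is 1 / num_shards, which is why the plain average in the loop above is already correct.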

Why split, not vmap?

jax.vmap maps one function over a leading axis of its inputs, so if you stack the equal-sized shards into a (num_shards, shard_size, ...) array, it works here too:

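import jax
import jax.numpy as jnp

@jax.vmap
def shard_loss(xs, ys):
    # model and params are closed over, so only xs and ys are mapped
    preds = model.apply(params, xs).reshape(-1)
    return jnp.mean((preds - ys) ** 2)

losses = shard_loss(jnp.stack(x_shards), jnp.stack(y_shards))  # shape (num_shards,)
avg = losses.mean()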

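Wrapped in jax.jit, the vmap version becomes one XLA program that evaluates every shard at once, which saves per-shard dispatch overhead and is typically faster when there are many small shards. The catch is memory: one batched computation over all shards needs roughly as much memory as the un-sharded batch, so vmap does not help with OOM. The Python loop dispatches one computation per shard (with a jitted per-shard function, equal-shaped shards reuse a single compilation rather than compiling N times), so each shard's intermediates can be freed before the next starts. When the goal is OOM avoidance with HUGE per-shard work, use the loop.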

Both approaches compute the same average; the loop version is conceptually closer to the distributed-eval pattern (one shard per device, then psum).

init must NOT see the full batch

Common bug: calling model.init(rng, x) on the full unsharded x consumes memory proportional to the full batch — the very thing we’re trying to avoid! Init on a single example or a single shard:

params = model.init(rng, x[:1])      # batch size doesn't matter; feature dim D must match

The init is one forward pass purely for variable allocation, so the smallest representative input will do: a single example with the same trailing shape and dtype as the real data.

Common pitfalls

  • Indivisible batch + jnp.split: jnp.split raises if the batch doesn’t divide evenly. Trim the batch or use array_split.
  • Mean of means != global mean when shards are unequal: weight each shard's loss by len(xs) / N and sum, instead of taking an unweighted average.
  • Re-initializing params per shard: don’t! Init ONCE; apply with the same params to each shard.
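  • JIT’ing the Python loop: usually you want jax.jit on the INSIDE (the per-shard call), not on the loop itself. Jitting the whole loop unrolls it into one large XLA program, so you lose the guarantee that each shard finishes and frees its memory before the next one starts.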

Problem

  1. Build a tiny nn.Dense(1) model.
  2. Init params from seed and a single example slice.
  3. Split x: (N, D) and y: (N,) into num_shards along axis 0 (N is divisible by num_shards).
  4. For each shard, compute the MSE between the model's predictions (flattened to 1-D) and that shard's y.
  5. Return the mean of the per-shard losses as a 1-D (1,) array.

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N,).
  • num_shards: float (cast to int) — divides N exactly.

Output: a 1-D array of shape (1,) containing [average_loss].

