hard primitives

Gradient Accumulation Step

Why this matters

The effective batch size determines how noisy the gradient estimate is: small batches give noisy gradients, while large batches give lower-variance estimates and (usually) more stable convergence. But a large batch may not fit in GPU memory, and the step then fails with an out-of-memory (OOM) error.

Gradient accumulation decouples the two: split an effective batch of size N × M into N micro-batches of size M, each small enough to fit in memory. Run N forward+backward passes, sum the grads, divide by N, and take ONE optimizer step. When the loss is a mean over examples, the optimizer sees the same gradient (up to floating-point rounding) as if you’d run a single batch of size N × M.

This is a standard way to pre-train large language models on hardware that can’t fit the target batch in memory; accumulation steps in the 8-128 range are common.
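The equivalence claimed above can be checked numerically. The sketch below is illustrative only: the linear model, the MSE loss_fn, and the sizes N, M, D are assumptions chosen for the demonstration, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical linear model with a mean-squared-error loss.
    # The mean over the batch is what makes "divide by N" the right scaling.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
kx, ky, kw = jax.random.split(key, 3)
N, M, D = 4, 8, 3  # 4 micro-batches of 8 examples each (assumed sizes)
x = jax.random.normal(kx, (N * M, D))
y = jax.random.normal(ky, (N * M,))
params = {"w": jax.random.normal(kw, (D,)), "b": jnp.zeros(())}

# Reference: one gradient over the full effective batch.
full_grads = jax.grad(loss_fn)(params, x, y)

# Accumulation: sum per-micro-batch gradients, then divide once by N.
acc = jax.tree_util.tree_map(jnp.zeros_like, params)
for xb, yb in zip(x.reshape(N, M, D), y.reshape(N, M)):
    g = jax.grad(loss_fn)(params, xb, yb)
    acc = jax.tree_util.tree_map(jnp.add, acc, g)
acc_grads = jax.tree_util.tree_map(lambda g: g / N, acc)

# Largest absolute difference between the two gradient pytrees.
max_diff = max(
    float(jnp.max(jnp.abs(a - f)))
    for a, f in zip(
        jax.tree_util.tree_leaves(acc_grads),
        jax.tree_util.tree_leaves(full_grads),
    )
)
print(max_diff)  # expected to be at the level of float32 rounding
```

Only one micro-batch's activations are live at a time, which is where the memory saving comes from; the accumulator itself costs one extra copy of the gradients.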

The recipe (2 micro-batches, A and B)

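import jax

# One gradient pytree per micro-batch, both evaluated at the same params.
grads_a = jax.grad(loss_fn)(params, x_a, y_a)
grads_b = jax.grad(loss_fn)(params, x_b, y_b)

# Average leaf-wise so the result matches the full-batch mean gradient.
grads = jax.tree_util.tree_map(
    lambda a, b: (a + b) / 2.0,
    grads_a,
    grads_b,
)

# A single optimizer step on the averaged gradients.
state = state.apply_gradients(grads=grads)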

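Why average and not sum? If loss_fn already takes a mean over its batch, then the gradient for micro-batch A is (1/M_A) * sum(grads_per_example). Adding two such grads and dividing by 2 gives 1/2 * (mean(grads_A) + mean(grads_B)), which equals the mean over all M_A + M_B examples when M_A = M_B. So averaging is correct here. Summing without dividing would double the gradient, which under plain SGD acts like doubling the learning rate. (If micro-batches have different sizes, weight each gradient by its micro-batch size.)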
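The size-weighting remark can be illustrated with an uneven split. This is a minimal sketch under assumed values: the 7/3 split, the linear loss_fn, and the parameter values are made up for the demonstration.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical linear model, mean-squared error over the batch.
    return jnp.mean((x @ params["w"] - y) ** 2)

kx, ky = jax.random.split(jax.random.PRNGKey(2))
x = jax.random.normal(kx, (10, 3))
y = jax.random.normal(ky, (10,))
params = {"w": jnp.array([0.5, -1.0, 2.0])}

# Unequal split: micro-batch A has 7 examples, micro-batch B has 3.
m_a, m_b = 7, 3
grads_a = jax.grad(loss_fn)(params, x[:m_a], y[:m_a])
grads_b = jax.grad(loss_fn)(params, x[m_a:], y[m_a:])

# Size-weighted average: each micro-batch counts in proportion to its size.
weighted = jax.tree_util.tree_map(
    lambda a, b: (m_a * a + m_b * b) / (m_a + m_b), grads_a, grads_b
)
# Plain average: only correct when the micro-batches are the same size.
naive = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)

full = jax.grad(loss_fn)(params, x, y)
err_weighted = float(jnp.max(jnp.abs(weighted["w"] - full["w"])))
err_naive = float(jnp.max(jnp.abs(naive["w"] - full["w"])))
print(err_weighted, err_naive)
```

The weighted form should agree with the full-batch gradient up to rounding, while the plain average over-weights the smaller micro-batch.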

Why divide AT THE END, not as you go

You COULD scale each batch’s grads by 1/N first and just add. The end result is the same:

(g_a + g_b) / 2  ==  g_a/2 + g_b/2

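In exact arithmetic the two forms are identical; in floating point they can differ only by rounding. Dividing by a power of two such as 2 is exact (barring underflow), so in this two-batch case both forms should give the same result. For other values of N, the per-step form rounds once per micro-batch division, while the end-divide form rounds the division only once. In low precision such as fp16, scaling first also pushes small gradients toward underflow, whereas summing first risks overflow. The end-divide form also generalizes naturally to grads computed from N micro-batches with a single line.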
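As a sketch of the N-micro-batch generalization, tree_map accepts any number of trees, so the end-divide form fits in one line. The data, the loss_fn, and N = 3 below are assumptions for illustration; the per-step form is computed alongside for comparison.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical linear model, mean-squared error over the batch.
    return jnp.mean((x @ params["w"] - y) ** 2)

N, M, D = 3, 4, 2  # N is deliberately not a power of two
kx, ky = jax.random.split(jax.random.PRNGKey(1))
xs = jax.random.normal(kx, (N, M, D))
ys = jax.random.normal(ky, (N, M))
params = {"w": jnp.ones((D,))}

# One gradient pytree per micro-batch.
grads_list = [jax.grad(loss_fn)(params, xs[i], ys[i]) for i in range(N)]

# End-divide: sum all N trees leaf-wise, then divide once.
end_divide = jax.tree_util.tree_map(
    lambda *gs: sum(gs) / len(gs), *grads_list
)

# Per-step divide: scale every leaf by 1/N first, then sum.
per_step = jax.tree_util.tree_map(
    lambda *gs: sum(g / len(gs) for g in gs), *grads_list
)

gap = float(jnp.max(jnp.abs(end_divide["w"] - per_step["w"])))
print(gap)  # zero or a few float32 ulps; the forms differ only by rounding
```

Holding all N gradient trees in a list is fine for a demonstration; for large models a running sum (one accumulator tree) keeps memory flat.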

Why a tree_map and not just +?

grads is a pytree (a nested dict of arrays — one entry per param). Plain + doesn’t traverse pytrees:

grads_a + grads_b   # TypeError: unsupported operand type(s) for +: 'dict' and 'dict'

jax.tree_util.tree_map(f, t1, t2) walks both trees in lockstep and applies f leaf-wise. Same result as +-ing aligned arrays, but safe for arbitrary nesting.
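A small example makes the lockstep traversal concrete. The nested dicts below are hand-written stand-ins for real gradient pytrees; the key names are assumptions chosen to resemble a Dense layer.

```python
import jax
import jax.numpy as jnp

# Two hand-written "gradient" pytrees with identical structure.
grads_a = {"dense": {"kernel": jnp.array([[1.0], [3.0]]), "bias": jnp.array([0.5])}}
grads_b = {"dense": {"kernel": jnp.array([[3.0], [5.0]]), "bias": jnp.array([1.5])}}

# Plain + is not defined for dicts, so it fails before reaching any array.
try:
    grads_a + grads_b
    plain_add_failed = False
except TypeError:
    plain_add_failed = True

# tree_map pairs up matching leaves and applies the function to each pair.
avg = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)

print(plain_add_failed)        # True
print(avg["dense"]["kernel"])  # [[2.], [4.]]
print(avg["dense"]["bias"])    # [1.]
```

The output keeps the same nesting as the inputs, which is why it can be passed straight to apply_gradients.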

When NOT to do this

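  • Models with BatchNorm: each micro-batch computes its own (small, noisy) batch statistics, so the accumulated step does not match a true large-batch step; accumulating doesn’t fix that. Prefer nn.GroupNorm or nn.LayerNorm, which don’t depend on the batch size. (“Synchronized” BN pools statistics across devices, not across sequential micro-batches, so it helps only when the micro-batches run in parallel.)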
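  • Models with dropout: dropout masks are drawn independently per element, so accumulation is fine in both PyTorch and JAX. In JAX, though, RNG keys are explicit: split the key with jax.random.split so each micro-batch gets its own subkey instead of reusing the same jax.random.PRNGKey.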
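The RNG caveat for dropout might look like the following in practice. This is a sketch under assumptions: dropout here is a hand-rolled helper rather than a library layer, and the shapes and the linear loss_fn are made up for illustration.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate=0.5):
    # Hand-rolled inverted dropout, for illustration only.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def loss_fn(params, x, y, key):
    # Hypothetical model: dropout on the inputs, then a linear map.
    h = dropout(key, x)
    return jnp.mean((h @ params["w"] - y) ** 2)

root = jax.random.PRNGKey(0)
data_key, drop_key = jax.random.split(root)
kx, ky = jax.random.split(data_key)
x_batches = jax.random.normal(kx, (2, 16, 4))
y_batches = jax.random.normal(ky, (2, 16))
params = {"w": jnp.ones((4,))}

# One independent subkey per micro-batch, so the masks are not shared.
key_a, key_b = jax.random.split(drop_key)
grads_a = jax.grad(loss_fn)(params, x_batches[0], y_batches[0], key_a)
grads_b = jax.grad(loss_fn)(params, x_batches[1], y_batches[1], key_b)
grads = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)

# Sanity check: the two subkeys produce different dropout masks.
mask_a = jax.random.bernoulli(key_a, 0.5, (16, 4))
mask_b = jax.random.bernoulli(key_b, 0.5, (16, 4))
masks_differ = bool(jnp.any(mask_a != mask_b))
print(masks_differ)
```

Passing the same key to both micro-batches would apply an identical mask to each, which is not what a single large batch with dropout would do.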

Problem

  1. Build the same TrainState as in pos 59 (tiny Dense(1) + sgd(lr)).
  2. Compute grads_a = jax.grad(loss_fn)(params, x_batches[0], y_batches[0]).
  3. Compute grads_b similarly on micro-batch 1.
  4. Average via jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b).
  5. state = state.apply_gradients(grads=grads).
  6. Compute the post-step MSE on the concatenated full batch and return it as a 1-D array of shape (1,).

x_batches arrives shaped (2, M, D) and y_batches shaped (2, M). Both micro-batches have the same M.

Inputs:

  • seed: float (cast to int).
  • x_batches: 3-D (2, M, D).
  • y_batches: 2-D (2, M).
  • lr: float.

Output: 1-D array of shape (1,): [final_loss_on_full_batch_after_step].

Hints

flax training grad-accumulation
