Gradient Accumulation Step
Why this matters
“Effective batch size” is what statistics actually care about — small batches give noisy gradients, large batches give better gradient estimates and (usually) better convergence. But large batches OOM the GPU.
Gradient accumulation lets you decouple the two: split a big
“effective batch” of size N × M into N micro-batches of size M,
each fitting in memory. Run N forward+backward passes, sum the
grads, divide by N, take ONE optimizer step. The optimizer sees the
same expected gradient as if you’d run a single batch of size N × M.
This is how every modern LLM is pre-trained on hardware that can’t fit the natural batch in memory — accumulation steps in the 8-128 range are common.
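The accumulation loop described above can be sketched for a general N. This is a minimal illustration, assuming a toy linear model with a mean-squared-error loss; `loss_fn`, `accumulated_grads`, and the shapes are hypothetical names chosen for the example, not part of the problem. It compares the accumulated gradient against the gradient of one batch of size N × M.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Toy linear model; the loss takes a mean over the batch axis.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x_micro, y_micro):
    # x_micro: (N, M, D), y_micro: (N, M).
    # Sum grads over the N micro-batches, then divide once at the end.
    n = x_micro.shape[0]
    total = jax.tree_util.tree_map(jnp.zeros_like, params)
    for i in range(n):
        g = jax.grad(loss_fn)(params, x_micro[i], y_micro[i])
        total = jax.tree_util.tree_map(lambda t, gi: t + gi, total, g)
    return jax.tree_util.tree_map(lambda t: t / n, total)


key = jax.random.PRNGKey(0)
kx, ky, kw = jax.random.split(key, 3)
N, M, D = 4, 8, 3
x_micro = jax.random.normal(kx, (N, M, D))
y_micro = jax.random.normal(ky, (N, M))
params = {"w": jax.random.normal(kw, (D,)), "b": jnp.zeros(())}

g_acc = accumulated_grads(params, x_micro, y_micro)

# Reference: one forward+backward pass over the full batch of size N * M.
g_full = jax.grad(loss_fn)(
    params, x_micro.reshape(N * M, D), y_micro.reshape(N * M)
)
```

With equal-sized micro-batches and a mean-reduced loss, `g_acc` and `g_full` should agree up to floating-point rounding.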
The recipe (2 micro-batches, A and B)
grads_a = jax.grad(loss_fn)(params, x_a, y_a)
grads_b = jax.grad(loss_fn)(params, x_b, y_b)
grads = jax.tree_util.tree_map(
    lambda a, b: (a + b) / 2.0,
    grads_a,
    grads_b,
)
state = state.apply_gradients(grads=grads)
Why average and not sum? If loss_fn already takes a mean over its
batch, then grad(mean(loss_A)) is 1/M_A * sum(grads_per_example).
Adding two such grads and dividing by 2 gives
1/2 * (mean(grads_A) + mean(grads_B)), which equals the global mean
when M_A = M_B. So average is correct here. (If micro-batches have
different sizes, weight by size.)
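To illustrate the size-weighted case, here is a small sketch assuming a toy linear model whose parameters are a single array `w` (so plain arithmetic is enough; for a pytree the same formula would go inside a `tree_map`). The sizes `M_A = 6` and `M_B = 2` are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Mean-reduced MSE, so each micro-batch grad is already a per-example mean.
    return jnp.mean((x @ w - y) ** 2)


key = jax.random.PRNGKey(1)
k1, k2, k3 = jax.random.split(key, 3)
D, M_A, M_B = 3, 6, 2
x = jax.random.normal(k1, (M_A + M_B, D))
y = jax.random.normal(k2, (M_A + M_B,))
w = jax.random.normal(k3, (D,))

g_a = jax.grad(loss_fn)(w, x[:M_A], y[:M_A])
g_b = jax.grad(loss_fn)(w, x[M_A:], y[M_A:])

# Weight each micro-batch mean by its size to recover the global mean.
g_weighted = (M_A * g_a + M_B * g_b) / (M_A + M_B)

# The unweighted average over-counts the smaller micro-batch.
g_plain = (g_a + g_b) / 2.0

# Reference gradient on the concatenated batch.
g_full = jax.grad(loss_fn)(w, x, y)
```

Here `g_weighted` should match `g_full` up to rounding, while `g_plain` generally will not.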
Why divide AT THE END, not as you go
You COULD scale each batch’s grads by 1/N first and just add. The
end result is the same:
(g_a + g_b) / 2 == g_a/2 + g_b/2
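For N = 2 the two forms are bit-identical in floating point, because dividing by a power of two is exact (barring underflow). For a general N they can differ by rounding: scaling each micro-batch’s grads adds one rounding per micro-batch, while dividing at the end adds only one. The difference is negligible in fp32; it matters more in low-precision formats such as fp16, where pre-scaled grads can underflow. The end-divide form also generalizes naturally to grads computed from N micro-batches with a single line.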
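As a sketch of the N-micro-batch form, the example below averages a list of gradient pytrees with one `tree_map` call and compares it with scaling each term as you go. The gradient values are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical grads from three micro-batches (same pytree structure).
grads_list = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(1.5)},
    {"w": jnp.array([5.0, 9.0]), "b": jnp.array(-2.0)},
]
n = len(grads_list)

# End-divide: sum the aligned leaves, divide once.
end_divide = jax.tree_util.tree_map(lambda *gs: sum(gs) / n, *grads_list)

# Per-step divide: scale each micro-batch's grads by 1/N before adding.
per_step = jax.tree_util.tree_map(jnp.zeros_like, grads_list[0])
for g in grads_list:
    per_step = jax.tree_util.tree_map(
        lambda acc, gi: acc + gi / n, per_step, g
    )
```

Both results agree up to rounding; the end-divide version needs no running accumulator when all N gradient pytrees are held at once, though a running sum is the more memory-friendly choice in a real training loop.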
Why a tree_map and not just +?
grads is a pytree (a nested dict of arrays — one entry per param).
Plain + doesn’t traverse pytrees:
grads_a + grads_b  # TypeError: unsupported operand type(s) for +: 'dict' and 'dict'
jax.tree_util.tree_map(f, t1, t2) walks both trees in lockstep and
applies f leaf-wise. Same result as +-ing aligned arrays, but
safe for arbitrary nesting.
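A small sketch of the difference, assuming gradients stored as a nested dict (the `dense`/`kernel`/`bias` layout is illustrative and may not match the actual parameter names in your model):

```python
import jax
import jax.numpy as jnp

grads_a = {"dense": {"kernel": jnp.array([[1.0], [2.0]]), "bias": jnp.array([0.0])}}
grads_b = {"dense": {"kernel": jnp.array([[3.0], [6.0]]), "bias": jnp.array([2.0])}}

# Plain + on two dicts raises TypeError.
try:
    grads_a + grads_b
    plus_failed = False
except TypeError:
    plus_failed = True

# tree_map visits matching leaves of both trees and keeps the structure.
grads = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)
```

The result has the same nesting as the inputs, with each leaf replaced by the leaf-wise average.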
When NOT to do this
- Models with BatchNorm: each micro-batch has its own (small, noisy) batch statistics; accumulating doesn’t fix that. Use nn.GroupNorm or nn.LayerNorm instead, or “synchronized” BN.
- Models with dropout: dropout in PyTorch and JAX is per-sample, so it’s fine, but be careful about RNG splitting if you use jax.random.PRNGKey: each micro-batch should get its own key.
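The RNG caveat for dropout can be sketched as follows. This assumes a hand-rolled `dropout_mask` helper (hypothetical, standing in for whatever dropout layer the model uses) and shows one way to derive a separate key per micro-batch from a single per-step key.

```python
import jax
import jax.numpy as jnp


def dropout_mask(key, shape, rate=0.5):
    # True where a unit is kept; kept with probability 1 - rate.
    return jax.random.bernoulli(key, 1.0 - rate, shape)


key = jax.random.PRNGKey(0)
n_micro = 2

# Split off a key for this optimizer step, then one key per micro-batch.
step_key, key = jax.random.split(key)
micro_keys = jax.random.split(step_key, n_micro)

masks = [dropout_mask(k, (4, 8)) for k in micro_keys]

# Reusing a key reproduces the same mask, which is what to avoid
# across micro-batches.
repeat = dropout_mask(micro_keys[0], (4, 8))
```

Passing the same key to every micro-batch would apply an identical dropout pattern to each, so the micro-batches would no longer behave like independent slices of one large batch.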
Problem
- Build the same TrainState as in pos 59 (tiny Dense(1) + sgd(lr)).
- Compute grads_a = jax.grad(loss_fn)(params, x_batches[0], y_batches[0]).
- Compute grads_b similarly on micro-batch 1.
- Average via jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b).
- state = state.apply_gradients(grads=grads).
- Compute the post-step MSE on the concatenated full batch and return as 1-D (1,).
x_batches arrives shaped (2, M, D) and y_batches shaped (2, M).
Both micro-batches have the same M.
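The shape handling for the final step can be sketched without the model itself. The example below assumes made-up input values and uses a matrix product with a column of ones as a stand-in for a Dense(1) prediction; the real loss_fn from pos 59 is not shown in this problem, so treat the squeeze as a precaution rather than a requirement.

```python
import jax.numpy as jnp

M, D = 3, 2
x_batches = jnp.arange(2 * M * D, dtype=jnp.float32).reshape(2, M, D)
y_batches = jnp.arange(2 * M, dtype=jnp.float32).reshape(2, M)

# Concatenate the two micro-batches along the batch axis.
x_full = x_batches.reshape(-1, D)   # (2 * M, D)
y_full = y_batches.reshape(-1)      # (2 * M,)

# Equivalent explicit concatenation.
x_cat = jnp.concatenate([x_batches[0], x_batches[1]], axis=0)

# Stand-in for a Dense(1) output, which has a trailing axis of size 1.
pred = x_full @ jnp.ones((D, 1))    # (2 * M, 1)

# Squeeze before subtracting: (2M, 1) - (2M,) would broadcast to (2M, 2M).
mse = jnp.mean((pred.squeeze(-1) - y_full) ** 2)

# Return as a 1-D array of shape (1,).
out = mse.reshape(1)
```

Because both micro-batches share the same M, the reshape and the explicit concatenation give the same full batch.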
Inputs:
- seed: float (cast to int).
- x_batches: 3-D (2, M, D).
- y_batches: 2-D (2, M).
- lr: float.
Output: 1-D (1,) — [final_loss_on_full_batch_after_step].