NNX Gradient Accumulation
Why this matters
“Effective batch size” is what the statistics of training depend on. Small batches give noisy gradient estimates, and large batches give lower-variance ones. But a large batch may not fit in GPU memory (it OOMs).
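Gradient accumulation lets you decouple the two. Split a big
“effective batch” of size N × M into N micro-batches of size M,
each of which fits in memory. Run N forward+backward passes, average
the grads, and take ONE optimizer step. Three conditions must hold: the
loss is a per-batch mean, the micro-batches are equal-sized, and no layer
depends on batch statistics. Then the optimizer sees the same gradient (up
to floating-point rounding) as if you'd run a single batch of size N × M.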
Gradient accumulation is standard practice in large-scale LLM pre-training, where accumulation counts in the 8-128 range are common.
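The claim above can be checked numerically with a minimal plain-JAX sketch. It assumes a hypothetical linear model `x @ w` with a batch-mean MSE loss, not an nnx module. It compares the gradient of one batch of size N × M with the average of N micro-batch gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: a linear model y = x @ w with an MSE loss that
# takes a mean over its batch (the assumption the text relies on).
def mse_loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
N, M, D_in, D_out = 4, 8, 3, 2  # N micro-batches of size M
x = jax.random.normal(key_x, (N * M, D_in))
y = jax.random.normal(key_y, (N * M, D_out))
w = jax.random.normal(key_w, (D_in, D_out))

# Gradient of one big batch of size N * M.
full_grad = jax.grad(mse_loss)(w, x, y)

# Gradient accumulation: N passes over equal-sized micro-batches,
# sum the gradients, then divide once by N.
acc = jnp.zeros_like(w)
for i in range(N):
    sl = slice(i * M, (i + 1) * M)
    acc = acc + jax.grad(mse_loss)(w, x[sl], y[sl])
accum_grad = acc / N

# Should be on the order of float32 rounding error.
print(jnp.max(jnp.abs(full_grad - accum_grad)))
```

The two gradients should agree to within float32 rounding. They are not guaranteed to be bit-identical, because the sums are performed in a different order.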
The recipe (2 micro-batches, A and B)
_, grads_a = nnx.value_and_grad(loss_fn)(model, x_a, y_a)
_, grads_b = nnx.value_and_grad(loss_fn)(model, x_b, y_b)
avg_grads = jax.tree_util.tree_map(
lambda a, b: (a + b) / 2.0,
grads_a, grads_b,
)
optimizer.update(model, avg_grads) # ONE optimizer step
Why average? If loss_fn already takes a mean over its batch, then
grad(mean(loss_A)) is 1/M_A * sum(per-example grads over A). Adding two
such grads and dividing by 2 gives 1/2 * (mean(grads_A) + mean(grads_B)).
When the micro-batches are the same size, this equals the mean per-example gradient over the full batch.
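Here is the same argument written out for disjoint micro-batches A and B of sizes M_A and M_B. The notation is introduced only for this derivation: g_i is the gradient of example i's loss, and L_S is the mean loss over a set S. The last line covers unequal sizes, where a size-weighted average replaces the plain one.

```latex
\nabla L_A = \frac{1}{M_A}\sum_{i \in A} g_i, \qquad
\nabla L_B = \frac{1}{M_B}\sum_{i \in B} g_i

\text{If } M_A = M_B = M:\quad
\tfrac{1}{2}\left(\nabla L_A + \nabla L_B\right)
  = \frac{1}{2M}\sum_{i \in A \cup B} g_i
  = \nabla L_{A \cup B}

\text{In general:}\quad
\nabla L_{A \cup B} = \frac{M_A \,\nabla L_A + M_B \,\nabla L_B}{M_A + M_B}
```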
Why a tree_map and not just +?
nnx grads come back as an nnx.State: a nested, dict-like pytree with one array per Param.
Plain + doesn’t traverse pytrees:
grads_a + grads_b  # TypeError: unsupported operand type(s) for +
jax.tree_util.tree_map(f, t1, t2) walks both trees in lockstep and
applies f leaf-wise. The result has the same structure as the model's
params, which is the layout optax expects.
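The behaviour described above can be seen with a small sketch. Plain nested dicts stand in for nnx.State here (an assumption for illustration). The leaf names `kernel` and `bias` mirror a linear layer, and the values are made up.

```python
import jax
import jax.numpy as jnp

# Hypothetical gradient pytrees shaped like a Linear layer's params.
grads_a = {"linear": {"kernel": jnp.array([[1.0, 2.0]]), "bias": jnp.array([0.5])}}
grads_b = {"linear": {"kernel": jnp.array([[3.0, 4.0]]), "bias": jnp.array([1.5])}}

try:
    grads_a + grads_b  # dicts do not define +, so this raises TypeError
except TypeError as err:
    print("TypeError:", err)

# tree_map walks both trees in lockstep and averages leaf by leaf.
avg_grads = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)
print(avg_grads["linear"]["kernel"])  # [[2. 3.]]
print(avg_grads["linear"]["bias"])    # [1.]
```

The output tree has exactly the same structure as each input, so it can be handed to the optimizer wherever a single gradient tree is expected.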
Linen vs. nnx
The Linen flow ends with state = state.apply_gradients(grads=avg_grads),
where state is a frozen TrainState and you must rebind. In nnx the
last step is optimizer.update(model, avg_grads) — no rebind, the
model’s params are mutated in place. The middle of the recipe (the
two value_and_grad calls and the tree_map) is identical between
the two frameworks.
Why divide AT THE END, not as you go
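You COULD scale each micro-batch's grads by 1/N first and add. The
result is algebraically the same:
(g_a + g_b) / 2 == g_a/2 + g_b/2
In floating point, neither form is exact, because the sum itself is rounded. Dividing once at the end adds a single extra rounding. Scaling first adds one rounding per micro-batch, unless N is a power of two, since scaling by a power of two is exact. For N = 2 the two forms give bit-identical results, barring overflow or underflow. Either way the differences are tiny. The end-divide form also generalizes naturally to N micro-batches: keep a running sum and divide by the count at the end, without having to fix N in advance.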
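Both accumulation orders can be sketched for N micro-batches in plain JAX. The helper names `accumulate_end_divide` and `accumulate_scale_first` are hypothetical, and the float32 "gradients" are random stand-ins with a single leaf `w`.

```python
import jax
import jax.numpy as jnp

def accumulate_end_divide(grads):
    """Running sum of gradient pytrees; N is only needed at the very end."""
    total = grads[0]
    for g in grads[1:]:
        total = jax.tree_util.tree_map(jnp.add, total, g)
    n = len(grads)
    return jax.tree_util.tree_map(lambda t: t / n, total)

def accumulate_scale_first(grads):
    """Scale each gradient by 1/N before adding (N must be known up front)."""
    n = len(grads)
    total = jax.tree_util.tree_map(lambda g: g / n, grads[0])
    for g in grads[1:]:
        total = jax.tree_util.tree_map(lambda t, gi: t + gi / n, total, g)
    return total

# Stand-in float32 gradients: N = 3 micro-batches, one (5,) leaf each.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
micro_grads = [{"w": jax.random.normal(k, (5,), dtype=jnp.float32)} for k in keys]

end = accumulate_end_divide(micro_grads)
first = accumulate_scale_first(micro_grads)

# Agree up to float32 rounding; with N = 3 they need not be bit-identical.
print(jnp.max(jnp.abs(end["w"] - first["w"])))
```

In a real training loop the running sum would be built one micro-batch at a time rather than from a prebuilt list. The list here just keeps the comparison short.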
When NOT to do this
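- Models with BatchNorm: each micro-batch computes its own (small, noisy) batch statistics, and accumulating gradients doesn't change that. Use LayerNorm/GroupNorm instead. Sync-BN helps only when the large batch is spread across devices within a single forward pass, since it pools statistics across devices, not across sequential micro-batches.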
- Dropout: usually fine, because dropout acts independently on each sample. But give each micro-batch its own RNG key, because reusing one key repeats the same dropout mask.
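The RNG point above can be illustrated with a framework-agnostic sketch in plain JAX. The `dropout` helper is hypothetical and is not `nnx.Dropout`, which manages its own RNG stream. The sketch splits one key per micro-batch and shows that reusing a single key reproduces the identical mask.

```python
import jax
import jax.numpy as jnp

def dropout(x, key, rate):
    """Hypothetical inverted dropout: zero each element w.p. `rate`, rescale the rest."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

n_micro, rate = 2, 0.5
base_key = jax.random.PRNGKey(42)
# One independent key per micro-batch.
micro_keys = jax.random.split(base_key, n_micro)

x = jnp.ones((4, 8))  # stand-in activations, same shape in each micro-batch
out_a = dropout(x, micro_keys[0], rate)
out_b = dropout(x, micro_keys[1], rate)

# Reusing one key for every micro-batch gives the same mask each time.
same_key_a = dropout(x, base_key, rate)
same_key_b = dropout(x, base_key, rate)
```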
Problem
Implement grad_accumulation_step(seed, x_a, y_a, x_b, y_b, lr):
- Build model = nnx.Linear(x_a.shape[-1], y_a.shape[-1], rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
- Compute _, grads_a = nnx.value_and_grad(loss_fn)(model, x_a, y_a).
- Compute _, grads_b = nnx.value_and_grad(loss_fn)(model, x_b, y_b).
- Average via jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b).
- Single update: optimizer.update(model, avg_grads).
- Compute the MSE on the concatenated full batch (jnp.concatenate([x_a, x_b], axis=0), and similarly for y).
- Return jnp.array([float(final_loss)]) as a 1-D array of shape (1,).
Both micro-batches have the same M.
Inputs:
- seed: float (cast to int).
- x_a, x_b: 2-D, shape (M, D_in) each.
- y_a, y_b: 2-D, shape (M, D_out) each.
- lr: float.
Output: 1-D, shape (1,): [final_loss_on_full_batch_after_step].