
NNX Gradient Accumulation

Why this matters

“Effective batch size” is the batch size that gradient statistics depend on: small batches give noisy gradient estimates, large batches give lower-variance ones. But a large batch may not fit in GPU memory (out-of-memory, OOM).

Gradient accumulation decouples the two: split a large effective batch of size N × M into N micro-batches of size M, each of which fits in memory. Run N forward+backward passes, average the grads, and take ONE optimizer step. When the loss averages over its batch and every micro-batch has size M, the optimizer sees the same gradient (up to floating-point rounding) as it would from a single batch of size N × M.
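The equivalence can be checked numerically with plain JAX. This is a minimal sketch assuming a hypothetical linear model stored in a params dict and a batch-mean MSE loss; it splits an (N × M)-row batch into N micro-batches by reshaping, accumulates their grads, divides by N, and compares against the full-batch gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with an MSE loss that averages over its batch.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

N, M, D_in, D_out = 4, 8, 3, 2  # N micro-batches of M examples each
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (N * M, D_in))
y = jax.random.normal(ky, (N * M, D_out))
params = {"w": jax.random.normal(kw, (D_in, D_out)), "b": jnp.zeros((D_out,))}

# Reference: one big batch of size N * M.
full_grads = grad_fn(params, x, y)

# Accumulation: reshape rows into (N, M, ...) and run N backward passes.
xs = x.reshape(N, M, D_in)
ys = y.reshape(N, M, D_out)
acc = jax.tree_util.tree_map(jnp.zeros_like, params)
for i in range(N):
    g = grad_fn(params, xs[i], ys[i])
    acc = jax.tree_util.tree_map(jnp.add, acc, g)
avg_grads = jax.tree_util.tree_map(lambda t: t / N, acc)

for name in params:
    print(name, jnp.allclose(full_grads[name], avg_grads[name], atol=1e-5))
```

Only one micro-batch's activations are alive at a time, which is where the memory saving comes from.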

Gradient accumulation is standard practice in large-model (e.g. LLM) pre-training, where the target batch size is often far larger than what fits on the available accelerators at once.

The recipe (2 micro-batches, A and B)

_, grads_a = nnx.value_and_grad(loss_fn)(model, x_a, y_a)
_, grads_b = nnx.value_and_grad(loss_fn)(model, x_b, y_b)

avg_grads = jax.tree_util.tree_map(
    lambda a, b: (a + b) / 2.0,
    grads_a, grads_b,
)

optimizer.update(model, avg_grads)   # ONE optimizer step

Why average? If loss_fn already takes a mean over its batch, grad(mean(loss_a)) is 1/M_a * sum(grads_per_example). Adding two such grads and dividing by 2 gives 1/2 * (mean(grads_A) + mean(grads_B)), which equals the global mean when the micro-batches are the same size.
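The same argument written out for micro-batches of possibly different sizes M_a and M_b, using ℓ_i for the per-example loss (a symbol introduced here) and g_a, g_b for the two micro-batch gradients:

```latex
g_a = \frac{1}{M_a}\sum_{i \in a} \nabla \ell_i, \qquad
g_b = \frac{1}{M_b}\sum_{i \in b} \nabla \ell_i

\nabla L_{\text{full}}
  = \frac{1}{M_a + M_b}\sum_{i \in a \cup b} \nabla \ell_i
  = \frac{M_a\, g_a + M_b\, g_b}{M_a + M_b}

M_a = M_b \;\Longrightarrow\; \nabla L_{\text{full}} = \frac{g_a + g_b}{2}
```

With unequal sizes the plain (g_a + g_b) / 2 over-weights the smaller micro-batch; weighting each gradient by its micro-batch size, as in the middle line, recovers the global mean.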

Why a tree_map and not just +?

nnx grads are an nnx.State: a nested, dict-like pytree with one array leaf per Param. Plain + doesn’t traverse pytrees:

grads_a + grads_b   # TypeError: unsupported operand type(s) for +

jax.tree_util.tree_map(f, t1, t2) walks both trees in lockstep and applies f leaf-wise. The result has the same structure as the model’s params, which is the layout optimizer.update (and optax underneath it) expects.
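A small illustration of the difference, using plain nested dicts as stand-ins for nnx grads (the "linear"/"kernel"/"bias" layout is hypothetical, chosen only to resemble a Linear layer):

```python
import jax
import jax.numpy as jnp

# Two gradient pytrees with identical nested structure.
grads_a = {"linear": {"kernel": jnp.ones((3, 2)), "bias": jnp.full((2,), 2.0)}}
grads_b = {"linear": {"kernel": jnp.full((3, 2), 3.0), "bias": jnp.zeros((2,))}}

# Plain + is not defined for dicts, so it raises TypeError.
try:
    grads_a + grads_b
except TypeError as err:
    print("plain + fails:", err)

# tree_map pairs up matching leaves and averages them.
avg_grads = jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b)
print(avg_grads["linear"]["kernel"][0, 0])  # 2.0
print(avg_grads["linear"]["bias"])          # [1. 1.]

# The output keeps the input structure.
same = jax.tree_util.tree_structure(avg_grads) == jax.tree_util.tree_structure(grads_a)
print(same)  # True
```

If the two trees had different structures, tree_map would raise an error instead of silently misaligning leaves.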

In Linen vs nnx

The Linen flow ends with state = state.apply_gradients(grads=avg_grads): TrainState is immutable, so apply_gradients returns a new state that you must rebind. In nnx the last step is optimizer.update(model, avg_grads); there is no rebind, because the model’s params are updated in place. The middle of the recipe has the same shape in both frameworks: two value_and_grad calls (jax.value_and_grad over the params pytree in Linen, nnx.value_and_grad over the module in nnx) followed by the same tree_map.

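Divide AT THE END, or as you go?

You COULD scale each batch’s grads by 1/N first and add. In exact arithmetic the result is the same:

(g_a + g_b) / 2  ==  g_a/2 + g_b/2

In floating point the two forms differ only by rounding, and in fp32 the difference is negligible. Neither form is exact, since the sum itself rounds. Dividing by a power of two such as 2 is exact (barring underflow); for other N, pre-scaling by 1/N adds one rounding per micro-batch instead of a single rounding at the end. The end-divide form keeps the accumulator a plain sum, which generalizes naturally to N micro-batches: add every micro-batch’s grads, then divide once by N. Pre-scaling is also a valid choice; with fp16 accumulators it keeps the running sum smaller and further from overflow.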
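A sketch of both forms generalized to N micro-batches with jax.lax.scan, assuming the same hypothetical linear model and batch-mean MSE as before and micro-batches stacked along a leading axis of shape (N, M, ...). Both results match the full-batch gradient to within fp32 rounding.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical linear model with a batch-mean MSE loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def end_divide_grads(params, xs, ys):
    """Sum micro-batch grads in a scan, then divide once by N."""
    def body(acc, batch):
        g = jax.grad(loss_fn)(params, *batch)
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xs, ys))
    n = xs.shape[0]  # static Python int under jit
    return jax.tree_util.tree_map(lambda t: t / n, total)

@jax.jit
def prescaled_grads(params, xs, ys):
    """Alternative: scale each micro-batch's grads by 1/N before adding."""
    n = xs.shape[0]
    def body(acc, batch):
        g = jax.grad(loss_fn)(params, *batch)
        return jax.tree_util.tree_map(lambda a, b: a + b / n, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xs, ys))
    return total

N, M = 8, 4
kx, ky, kw = jax.random.split(jax.random.PRNGKey(1), 3)
xs = jax.random.normal(kx, (N, M, 3))
ys = jax.random.normal(ky, (N, M, 2))
params = {"w": jax.random.normal(kw, (3, 2)), "b": jnp.zeros((2,))}

end_div = end_divide_grads(params, xs, ys)
pre_div = prescaled_grads(params, xs, ys)
full = jax.grad(loss_fn)(params, xs.reshape(N * M, 3), ys.reshape(N * M, 2))
```

Using lax.scan keeps the loop inside one compiled program, so N can grow without unrolling N copies of the backward pass.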

When NOT to do this

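  • Models with BatchNorm: each micro-batch normalizes with its own (small, noisy) batch statistics, and accumulating gradients doesn’t change that; the running statistics are also updated once per micro-batch rather than once per optimizer step. Prefer LayerNorm or GroupNorm, which compute statistics per example and are unaffected by how the batch is split. (Sync-BN shares statistics across devices within one step, so it does not help with micro-batches that run one after another.)
  • Dropout: usually fine, because masks are drawn independently per example, but give each micro-batch its own RNG key; reusing one key across same-sized micro-batches repeats the same mask pattern.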
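Both caveats can be seen with a few lines of plain JAX. This sketch uses random stand-in activations and a hypothetical hand-written dropout helper (not a library function). The first half shows that averaging per-micro-batch variances underestimates the full-batch variance by exactly the variance of the micro-batch means; the second half shows that reusing a key reproduces the same dropout mask.

```python
import jax
import jax.numpy as jnp

# --- BatchNorm-style statistics -------------------------------------------
x = jax.random.normal(jax.random.PRNGKey(0), (8, 4)) * 3.0 + 1.0
micro = x.reshape(2, 4, 4)  # 2 micro-batches of 4 examples, 4 features

full_var = jnp.var(x, axis=0)                       # stats over all 8 rows
micro_var = jnp.var(micro, axis=1).mean(axis=0)     # mean of per-micro-batch stats
between = jnp.var(micro.mean(axis=1), axis=0)       # spread of micro-batch means
# Law of total variance (equal sizes, ddof=0): full = within + between.
print(jnp.allclose(full_var, micro_var + between, atol=1e-4))  # True

# --- Dropout keys --------------------------------------------------------
def dropout(key, h, rate=0.5):
    # Hypothetical inverted-dropout helper for illustration.
    keep = jax.random.bernoulli(key, 1.0 - rate, h.shape)
    return jnp.where(keep, h / (1.0 - rate), 0.0)

key_a, key_b = jax.random.split(jax.random.PRNGKey(42))
h = jnp.ones((4, 16))
mask_a = dropout(key_a, h) != 0
mask_b = dropout(key_b, h) != 0        # fresh key for micro-batch B
mask_reused = dropout(key_a, h) != 0   # key_a reused by mistake
print(bool((mask_a == mask_reused).all()))  # True: same key, same mask
```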

Problem

Implement grad_accumulation_step(seed, x_a, y_a, x_b, y_b, lr):

  1. Build model = nnx.Linear(x_a.shape[-1], y_a.shape[-1], rngs=...) and optimizer = nnx.Optimizer(model, optax.sgd(lr), wrt=nnx.Param).
  2. Compute _, grads_a = nnx.value_and_grad(loss_fn)(model, x_a, y_a).
  3. Compute _, grads_b = nnx.value_and_grad(loss_fn)(model, x_b, y_b).
  4. Average via jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, grads_a, grads_b).
  5. Single update: optimizer.update(model, avg_grads).
  6. Compute MSE on the concatenated full batch (jnp.concatenate([x_a, x_b], axis=0), similarly for y).
  7. Return jnp.array([float(final_loss)]), a 1-D array of shape (1,).

Both micro-batches have the same M.

Inputs:

  • seed: float (cast to int).
  • x_a, x_b: 2-D (M, D_in) each.
  • y_a, y_b: 2-D (M, D_out) each.
  • lr: float.

Output: a 1-D array of shape (1,) holding the full-batch MSE after the single update step.
