Microbatched Gradient Accumulation via scan
Why this matters
When training on large batches that exceed accelerator memory, the standard technique is gradient accumulation: split the full batch into smaller microbatches, compute gradients for each microbatch, then sum them before the optimizer step. The result is mathematically identical to computing one large gradient (modulo floating-point order effects), but the peak memory footprint is proportional to the microbatch size, not the full batch.
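The equivalence follows from linearity of differentiation. As a short sketch, suppose the full batch of N examples is split into K disjoint microbatches B_1, ..., B_K with per-example loss ℓ_i (this notation is introduced here for illustration and is not part of the original text):

```latex
\nabla_w \sum_{i=1}^{N} \ell_i(w)
  = \nabla_w \sum_{k=1}^{K} \sum_{i \in B_k} \ell_i(w)
  = \sum_{k=1}^{K} \nabla_w \sum_{i \in B_k} \ell_i(w)
```

If the loss is a mean rather than a sum and the microbatches have equal size, dividing the accumulated sum of microbatch gradients by K recovers the full-batch gradient.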
In JAX, lax.scan is the idiomatic tool for this pattern:
- scan compiles a loop body once and executes it N times: no Python overhead, JIT-friendly, XLA-compilable.
- The accumulator carries the running gradient sum.
- Each step computes one microbatch gradient and adds it to the accumulator.
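To make the carry/output contract concrete, here is a minimal sketch that compares lax.scan with a plain Python loop on a running sum. The helper names scan_reference and running_sum are hypothetical and chosen only for this illustration; the reference loop mirrors the semantics but is not how scan is implemented.

```python
import jax.numpy as jnp
from jax import lax

def scan_reference(f, init, xs):
    """Pure-Python loop with the same carry/output contract as lax.scan."""
    carry, ys = init, []
    for x in xs:  # iterate over the leading axis
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def running_sum(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (new carry, per-step output)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# Both calls thread the carry through the sequence and stack the outputs.
ref_total, ref_partials = scan_reference(running_sum, jnp.zeros(()), xs)
total, partials = lax.scan(running_sum, jnp.zeros(()), xs)
```

Here total plays the role of the gradient accumulator, and partials is the stacked per-step output that gradient accumulation usually discards by returning None.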
This pattern is ubiquitous in production JAX training loops (e.g., Haiku, Flax, Optax examples) wherever the full-batch gradient would OOM.
Worked mini-example
import jax
import jax.numpy as jnp
from jax import lax

x_full = jnp.array([1.0, 2.0, 3.0, 4.0])
y_full = jnp.array([0.0, 0.0, 0.0, 0.0])
w = 1.0
n_micro = 2

x_micro = x_full.reshape(n_micro, -1)  # shape (2, 2)
y_micro = y_full.reshape(n_micro, -1)  # shape (2, 2)

def micro_loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

grad_fn = jax.grad(micro_loss, argnums=0)

def step(acc, micro):
    x_chunk, y_chunk = micro
    return acc + grad_fn(w, x_chunk, y_chunk), None

final_acc, _ = lax.scan(step, jnp.zeros(()), (x_micro, y_micro))
# final_acc ≈ 60.0 (= grad of full-batch loss w.r.t. w)
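As a sanity check, the accumulated value can be compared against a gradient taken over the whole batch at once. The sketch below reuses the same data and loss; the helper accumulate and the choice of microbatch counts (1, 2, 4) are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x_full = jnp.array([1.0, 2.0, 3.0, 4.0])
y_full = jnp.zeros(4)
w = 1.0

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

# Reference: differentiate the loss over the whole batch in one call.
full_grad = jax.grad(loss)(w, x_full, y_full)

def accumulate(n_micro):
    """Sum per-microbatch gradients with lax.scan (hypothetical helper)."""
    xs = x_full.reshape(n_micro, -1)
    ys = y_full.reshape(n_micro, -1)

    def step(acc, micro):
        x_chunk, y_chunk = micro
        return acc + jax.grad(loss)(w, x_chunk, y_chunk), None

    acc, _ = lax.scan(step, jnp.zeros(()), (xs, ys))
    return acc

# The accumulated gradient should not depend on how the batch is split.
grads = {n: accumulate(n) for n in (1, 2, 4)}
```

By hand, the gradient is 2 * (1 + 4 + 9 + 16) = 60, so full_grad and every entry of grads should agree with that value up to floating-point rounding.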
Common pitfalls
- Reshape before scanning: lax.scan maps over the leading axis of each array. Reshape x_full from (N,) to (n_micro, micro_size) so that scan steps through microbatches.
- Accumulator shape must match the gradient shape: for a scalar w, the gradient is also a scalar, so initialize with jnp.zeros(()), not jnp.zeros(1).
- step returns (new_carry, output): when you don't need the per-step output, return None as the second element.
- n_microbatches arrives as a float: convert it with int(n_microbatches) before using it in reshape.
- Result equals the full-batch gradient: by linearity of differentiation, summing the microbatch gradients gives the same result as differentiating the summed losses.
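The accumulator-shape and float-count pitfalls can be seen in a small sketch. The helper accumulate and its init parameter are hypothetical names used only here. In this particular scalar case, broadcasting keeps the carry shape consistent, so a jnp.zeros(1) initializer does not raise; the result just comes back with shape (1,) instead of ().

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

def accumulate(w, x_full, y_full, n_microbatches, init):
    n = int(n_microbatches)      # reshape needs integer dimensions
    xs = x_full.reshape(n, -1)   # leading axis = microbatch index
    ys = y_full.reshape(n, -1)

    def step(acc, micro):
        x_chunk, y_chunk = micro
        return acc + jax.grad(loss)(w, x_chunk, y_chunk), None

    acc, _ = lax.scan(step, init, (xs, ys))
    return acc

x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.zeros(4)

good = accumulate(1.0, x, y, 2.0, jnp.zeros(()))  # scalar carry, shape ()
odd = accumulate(1.0, x, y, 2.0, jnp.zeros(1))    # carry broadcasts to shape (1,)
```

Both runs hold the same number, but only good matches the scalar shape of the gradient, which matters if a caller or grader compares shapes.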
Problem
Implement microbatched_grads(w, b, x_full, y_full, n_microbatches) that:
- Splits x_full and y_full into n_microbatches equal chunks.
- Uses lax.scan to accumulate the gradient of sum((w*x - y)^2) w.r.t. w across all microbatches.
- Returns the accumulated scalar gradient.

Arguments:
- w: scalar, the weight.
- b: scalar, unused (kept for signature parity).
- x_full, y_full: 1-D jax arrays of shape (N,) where N is divisible by n_microbatches.
- n_microbatches: float (convert to int internally).
Returns: scalar β sum of per-microbatch gradients w.r.t. w.
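One possible shape for such a function, following the worked mini-example above, is sketched below. It is not the reference solution, and it assumes x_full and y_full are floating-point arrays; casting w to the dtype of x_full is a choice made here so that the carry dtype stays fixed across scan steps.

```python
import jax
import jax.numpy as jnp
from jax import lax

def microbatched_grads(w, b, x_full, y_full, n_microbatches):
    # b is unused; it is kept only to match the required signature.
    n_micro = int(n_microbatches)          # arrives as a float
    x_micro = x_full.reshape(n_micro, -1)  # (n_micro, micro_size)
    y_micro = y_full.reshape(n_micro, -1)

    # Assumption: inputs are floating point, so w can share their dtype.
    w = jnp.asarray(w, dtype=x_full.dtype)

    def micro_loss(w_, x, y):
        return jnp.sum((w_ * x - y) ** 2)

    grad_fn = jax.grad(micro_loss)

    def step(acc, micro):
        x_chunk, y_chunk = micro
        return acc + grad_fn(w, x_chunk, y_chunk), None

    # Scalar accumulator with the same shape and dtype as the gradient.
    final_acc, _ = lax.scan(step, jnp.zeros_like(w), (x_micro, y_micro))
    return final_acc
```

For example, microbatched_grads(1.0, 0.0, jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.zeros(4), 2.0) should agree with the 60.0 from the worked mini-example up to rounding.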