
Microbatched Gradient Accumulation via scan

Why this matters

When a training batch is too large for a single forward and backward pass to fit in accelerator memory, the standard technique is gradient accumulation: split the full batch into smaller microbatches, compute the gradient for each microbatch, and sum those gradients before the optimizer step. For a loss that is a sum over examples, the result is mathematically identical to the full-batch gradient (up to floating-point summation order), but peak activation memory scales with the microbatch size rather than the full batch size.
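A plain Python loop makes the equivalence concrete. This is a minimal sketch with a toy summed squared-error loss and arbitrary data chosen for illustration; the per-iteration Python loop here is exactly the overhead that lax.scan removes.

```python
import jax
import jax.numpy as jnp

# Toy loss summed over the batch (same form as the worked example).
def loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

x = jnp.arange(1.0, 9.0)   # 8 illustrative examples
y = jnp.zeros_like(x)
w = 0.5

# Full-batch gradient: one backward pass over all 8 examples.
full_grad = jax.grad(loss)(w, x, y)

# Gradient accumulation: 4 microbatches of 2 examples, gradients summed.
acc = 0.0
for x_mb, y_mb in zip(x.reshape(4, 2), y.reshape(4, 2)):
    acc = acc + jax.grad(loss)(w, x_mb, y_mb)

# Both equal sum(x**2) = 204.0 for this data, since 2 * w = 1.
print(full_grad, acc)
```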

In JAX, lax.scan is the idiomatic tool for this pattern:

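  • scan traces the loop body once and lowers it to a single XLA loop that runs N times, so there is no per-iteration Python overhead and compile time does not grow with N (unlike a Python for-loop unrolled under jit).
  • The carry (the state threaded from one iteration to the next) holds the running gradient sum.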
  • Each step computes one microbatch gradient and adds it to the accumulator.
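The carry/output contract of lax.scan can be seen in isolation with a running sum. This is an illustrative sketch unrelated to gradients: the body receives (carry, x_t) and returns (new_carry, y_t); scan returns the final carry plus all y_t stacked along a new leading axis.

```python
import jax.numpy as jnp
from jax import lax

# Body signature: (carry, x_t) -> (new_carry, y_t).
def body(carry, x_t):
    new_carry = carry + x_t        # the carry is the running sum
    return new_carry, new_carry    # also emit it as the per-step output

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, history = lax.scan(body, jnp.zeros(()), xs)
# final → 10.0, history → [1.0, 3.0, 6.0, 10.0]
```

For gradient accumulation only the final carry matters, which is why the step function in the worked example returns None as its per-step output.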

This pattern is common in JAX training loops (e.g., Haiku, Flax, and Optax examples) wherever computing the full-batch gradient in one pass would run out of memory.

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax

x_full = jnp.array([1.0, 2.0, 3.0, 4.0])
y_full = jnp.array([0.0, 0.0, 0.0, 0.0])
w = 1.0
n_micro = 2

x_micro = x_full.reshape(n_micro, -1)   # shape (2, 2)
y_micro = y_full.reshape(n_micro, -1)   # shape (2, 2)

def micro_loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

grad_fn = jax.grad(micro_loss, argnums=0)

def step(acc, micro):
    x_chunk, y_chunk = micro
    return acc + grad_fn(w, x_chunk, y_chunk), None

final_acc, _ = lax.scan(step, jnp.zeros(()), (x_micro, y_micro))
# per-microbatch grads: 2*w*(1**2 + 2**2) = 10 and 2*w*(3**2 + 4**2) = 50
# final_acc → 60.0   (= gradient of the full-batch loss w.r.t. w)
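To check that claim, the same computation can be wrapped in jax.jit and compared against a single full-batch gradient. This sketch is a variation on the worked example, not a different algorithm: w is passed in as an argument instead of being captured as a global, and step emits each microbatch gradient as its per-step output. The names accumulate, per_micro, and full are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def micro_loss(w, x, y):
    return jnp.sum((w * x - y) ** 2)

grad_fn = jax.grad(micro_loss, argnums=0)

@jax.jit
def accumulate(w, x_micro, y_micro):
    # w is an explicit argument, so it is traced rather than baked in.
    def step(acc, micro):
        x_chunk, y_chunk = micro
        g = grad_fn(w, x_chunk, y_chunk)
        return acc + g, g          # keep each microbatch grad as an output
    return lax.scan(step, jnp.zeros(()), (x_micro, y_micro))

x_full = jnp.array([1.0, 2.0, 3.0, 4.0])
y_full = jnp.zeros(4)
w = 1.0

total, per_micro = accumulate(w, x_full.reshape(2, -1), y_full.reshape(2, -1))
full = grad_fn(w, x_full, y_full)  # one full-batch gradient for comparison
# per_micro → [10.0, 50.0], total → 60.0, full → 60.0
```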

Common pitfalls

  • Reshape before scanning: lax.scan iterates over the leading axis of each array it scans over. Reshape x_full from (N,) to (n_micro, micro_size) so that each scan step receives one microbatch.
  • The accumulator shape must match the gradient shape: for a scalar w the gradient is also a scalar, so initialize with jnp.zeros(()), not jnp.zeros(1). With jnp.zeros(1) the carry has shape (1,), and broadcasting silently returns a length-1 array instead of a scalar.
  • step returns (new_carry, output): when you don’t need the per-step output, return None as the second element.
  • n_microbatches arrives as a float: convert it with int(n_microbatches) before passing it to reshape.
  • The result equals the full-batch gradient: differentiation is linear, so the sum of the per-microbatch gradients equals the gradient of the summed losses.
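The accumulator-shape rule generalizes to models whose parameters are a pytree (for example a dict of arrays): build the zero carry with the same structure, shapes, and dtypes as the parameters, and add gradients leaf by leaf. This is a minimal sketch assuming a hypothetical two-parameter linear model; the names loss, accumulate_grads, and the "b" bias here are illustrative and unrelated to the unused b in the problem below.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical linear model with a dict of parameters (a pytree).
def loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.sum((pred - y) ** 2)

def accumulate_grads(params, x_full, y_full, n_micro):
    n_micro = int(n_micro)                     # reshape needs a Python int
    xs = x_full.reshape(n_micro, -1)
    ys = y_full.reshape(n_micro, -1)
    # Zero carry with the same tree structure, shapes, and dtypes as params.
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)

    def step(acc, micro):
        g = jax.grad(loss)(params, *micro)     # g mirrors params' structure
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    acc, _ = lax.scan(step, zeros, (xs, ys))
    return acc

params = {"w": jnp.array(1.0), "b": jnp.array(0.5)}
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.zeros(4)
grads = accumulate_grads(params, x, y, 2.0)
# grads → {"b": 24.0, "w": 70.0}, matching jax.grad(loss)(params, x, y)
```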

Problem

Implement microbatched_grads(w, b, x_full, y_full, n_microbatches) that:

  1. Splits x_full and y_full into n_microbatches equal chunks.
  2. Uses lax.scan to accumulate the gradient of sum((w*x - y)^2) w.r.t. w across all microbatches.
  3. Returns the accumulated scalar gradient.

Arguments:

  • w: scalar weight.
  • b: scalar; unused, kept only for signature parity.
  • x_full, y_full: 1-D JAX arrays of shape (N,), where N is divisible by n_microbatches.
  • n_microbatches: passed as a float; convert it to int internally.

Returns: scalar, the sum of the per-microbatch gradients w.r.t. w.
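One possible sketch of microbatched_grads, assembled from the worked example and the pitfalls above. It is not the site's reference solution; the nested helper names micro_loss, grad_fn, and step are illustrative, and it assumes JAX's default float32 precision so that the jnp.zeros(()) carry matches the gradient dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

def microbatched_grads(w, b, x_full, y_full, n_microbatches):
    del b                                 # unused; kept for signature parity
    n = int(n_microbatches)               # arrives as a float
    x_micro = x_full.reshape(n, -1)       # (N,) -> (n, N // n)
    y_micro = y_full.reshape(n, -1)

    def micro_loss(w, x, y):
        return jnp.sum((w * x - y) ** 2)

    grad_fn = jax.grad(micro_loss)        # gradient w.r.t. the first arg, w

    def step(acc, micro):
        x_chunk, y_chunk = micro
        return acc + grad_fn(w, x_chunk, y_chunk), None

    total, _ = lax.scan(step, jnp.zeros(()), (x_micro, y_micro))
    return total
```

Because the loss is a sum, the result should not depend on n_microbatches: with the worked example's data, 1, 2, or 4 microbatches all give 60.0.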

Hints

jax scan grad-accumulation
