
Gradient Accumulation via MultiSteps

Why this matters

When a mini-batch is too large to fit in GPU memory, gradient accumulation lets you split it into smaller micro-batches. You run a forward-backward pass on each micro-batch and average their gradients before applying the optimizer. optax.MultiSteps wraps any inner optimizer to do exactly this: it accumulates gradients over k calls to update() and then makes one real parameter update.
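The following minimal sketch shows why averaging micro-batch gradients is legitimate. It assumes a hypothetical linear model with a mean-squared-error loss and two micro-batches of equal size. Under those assumptions, the mean of the two micro-batch gradients equals the full-batch gradient. If the micro-batches had unequal sizes, a plain average would weight them incorrectly.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with mean-squared-error loss (illustrative assumption).
def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.grad(loss_fn)

key = jax.random.PRNGKey(0)
kx, ky = jax.random.split(key)
x = jax.random.normal(kx, (8, 3))   # full batch: 8 examples, 3 features
y = jax.random.normal(ky, (8,))
w = jnp.zeros(3)

# Full-batch gradient, as if the whole batch fit in memory.
full_grad = grad_fn(w, x, y)

# Split into two equal-size micro-batches and average their gradients.
g1 = grad_fn(w, x[:4], y[:4])
g2 = grad_fn(w, x[4:], y[4:])
accumulated = (g1 + g2) / 2

print(jnp.allclose(full_grad, accumulated, atol=1e-5))
```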

The recipe

import optax

def multisteps_sgd(params, grads_a, grads_b, lr):   # function name is illustrative
    # Wrap plain SGD so a real update happens only every 2 calls to update().
    optimizer = optax.MultiSteps(optax.sgd(lr), every_k_schedule=2)
    opt_state = optimizer.init(params)

    # Step 1: accumulates grads_a, returns zero updates
    updates_a, opt_state = optimizer.update(grads_a, opt_state, params)
    params = optax.apply_updates(params, updates_a)   # no-op (updates_a == 0)

    # Step 2: averages (grads_a + grads_b) / 2, applies the inner SGD step
    updates_b, _ = optimizer.update(grads_b, opt_state, params)
    return optax.apply_updates(params, updates_b)

Common pitfalls

  • Real parameter updates happen only every k steps. Between sync points, update() returns all-zero updates, so apply_updates is a no-op.
  • By default (use_grad_mean=True), the inner optimizer receives the averaged gradient, not the sum.
  • every_k_schedule can also be a callable for dynamic schedules; here we use a fixed integer.

Inputs

  • params: 1-D JAX array, the initial parameters.
  • grads_a: 1-D JAX array, the gradients from micro-batch 1.
  • grads_b: 1-D JAX array, the gradients from micro-batch 2.
  • lr: scalar, the SGD learning rate applied at the sync step.

Output

1-D array, the params after BOTH micro-steps of MultiSteps(k=2).
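With plain SGD and the default gradient averaging (use_grad_mean=True), the expected output has a simple closed form. Here θ is params, η is lr, and g_a, g_b are the two micro-batch gradients. This is a sketch under those assumptions; an inner optimizer with state, such as momentum, would not reduce this simply.

```latex
\theta_{\text{out}} = \theta - \eta \cdot \frac{g_a + g_b}{2}
```

As a worked example with made-up numbers, take θ = [1, 2], g_a = [0.2, 0.4], g_b = [0.6, 0.0], and η = 0.5. The mean gradient is [0.4, 0.2], so the output is [1 - 0.2, 2 - 0.1] = [0.8, 1.9].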

Hints

optax multi-steps grad-accumulation
