hard
primitives
Gradient Accumulation via MultiSteps
Why this matters
When a mini-batch is too large to fit in GPU memory, gradient accumulation
lets you split it across multiple forward-backward passes and average the
gradients before applying the optimizer. optax.MultiSteps wraps any
inner optimizer to do exactly this: it accumulates gradients over k calls to
update() before making one real parameter update.
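Averaging micro-batch gradients reproduces the full-batch gradient when the loss is a mean over examples and the micro-batches have equal size. A minimal JAX check of that claim, assuming a hypothetical linear model with a mean-squared-error loss and made-up toy data (none of these names come from the problem):

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with a mean-squared-error loss (illustration only).
def loss_fn(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Made-up toy data: 8 examples, 3 features, deterministic for reproducibility.
x = jnp.arange(24.0).reshape(8, 3) / 10.0
y = jnp.linspace(-1.0, 1.0, 8)
w = jnp.array([0.5, -0.25, 0.1])

grad_fn = jax.grad(loss_fn)  # gradient with respect to w

# Full-batch gradient computed in a single pass.
full_grad = grad_fn(w, x, y)

# The same batch split into two equal micro-batches of 4 examples each.
grads_a = grad_fn(w, x[:4], y[:4])
grads_b = grad_fn(w, x[4:], y[4:])
accumulated = (grads_a + grads_b) / 2.0

print(jnp.allclose(full_grad, accumulated, atol=1e-4))  # True
```

If the micro-batches had unequal sizes, a plain mean of their gradients would weight examples unevenly, which is why equal splits are assumed here.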
The recipe
import optax

# Function name is illustrative; inputs and output match the problem spec.
def two_step_accumulate(params, grads_a, grads_b, lr):
    optimizer = optax.MultiSteps(optax.sgd(lr), every_k_schedule=2)
    opt_state = optimizer.init(params)
    # Step 1: accumulates grads_a, returns zero updates
    updates_a, opt_state = optimizer.update(grads_a, opt_state, params)
    params = optax.apply_updates(params, updates_a)  # no-op (updates_a = 0)
    # Step 2: averages (grads_a + grads_b) / 2, applies inner optimizer
    updates_b, _ = optimizer.update(grads_b, opt_state, params)
    return optax.apply_updates(params, updates_b)
Common pitfalls
- Actual parameter updates only happen every k steps; between sync points, update() returns zero updates, so apply_updates is a no-op.
- The averaged gradient (not the sum) is passed to the inner optimizer; this is the default behavior, controlled by use_grad_mean=True.
- every_k_schedule can be a callable for dynamic schedules; here we use a fixed integer.
Inputs
- params: 1-D JAX array, the initial parameters.
- grads_a: 1-D JAX array, gradients from micro-batch 1.
- grads_b: 1-D JAX array, gradients from micro-batch 2.
- lr: scalar, the SGD learning rate applied at the sync step.
Output
1-D array: params after BOTH micro-steps via MultiSteps(every_k_schedule=2).
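Because optax.sgd without momentum scales its input by minus the learning rate, the expected output has a closed form. Assuming the default use_grad_mean=True, and writing $\theta_0$ = params, $g_a$ = grads_a, $g_b$ = grads_b, $\eta$ = lr (notation introduced here for convenience):

```latex
\bar g = \tfrac{1}{2}\,(g_a + g_b), \qquad
\theta_1 = \theta_0 + 0 = \theta_0, \qquad
\theta_2 = \theta_1 - \eta\,\bar g = \theta_0 - \frac{\eta}{2}\,(g_a + g_b)
```

The middle term reflects the zero update emitted at the first mini-step.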
Hints
optax
multi-steps
grad-accumulation