Lookahead Optimizer Wrapper
Why this matters
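Lookahead (Zhang et al. 2019) is a meta-optimizer that wraps any inner optimizer with a two-tier weight scheme. The authors report that it improves the learning stability and convergence of both SGD and Adam on image classification and language tasks, with little extra compute or memory, and it can be dropped in as a wrapper around an existing optimizer.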
How it works
Lookahead maintains two sets of weights:
- Fast weights: updated every step by the inner optimizer.
- Slow weights: updated every sync_period steps via interpolation: slow ← slow + slow_step_size * (fast - slow)
On non-sync steps, fast weights move normally and slow weights are frozen. After sync, both fast and slow are reset to the interpolated value.
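To make the two-tier scheme concrete, here is a minimal from-scratch sketch in plain JAX, assuming plain SGD as the inner optimizer and a differentiable scalar loss. It follows the nested-loop form of the algorithm (sync_period fast steps, then one slow update and a reset); it is not Optax's implementation, and the names lookahead_sgd, loss_fn, num_syncs and quadratic_loss are hypothetical.

```python
import jax
import jax.numpy as jnp

def lookahead_sgd(loss_fn, params, lr, slow_step_size, sync_period, num_syncs):
    """From-scratch Lookahead around plain SGD (illustrative, not Optax's code)."""
    grad_fn = jax.grad(loss_fn)
    slow = jnp.asarray(params, dtype=jnp.float32)
    fast = slow
    for _ in range(num_syncs):
        # Inner loop: sync_period ordinary SGD steps on the fast weights only.
        for _ in range(sync_period):
            fast = fast - lr * grad_fn(fast)
        # Sync: slow <- slow + slow_step_size * (fast - slow) ...
        slow = slow + slow_step_size * (fast - slow)
        # ... then the fast weights restart from the new slow weights.
        fast = slow
    return slow

# Hypothetical objective: f(w) = 0.5 * ||w||^2, so grad f(w) = w.
def quadratic_loss(w):
    return 0.5 * jnp.sum(w ** 2)
```

On this quadratic with lr = 0.1, sync_period = 2 and slow_step_size = 0.5, two SGD steps take the fast weights from w to 0.81 w, and the slow weights move halfway there, to 0.905 w.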
Using it in Optax
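```python
import optax

def lookahead_step(params, grads, lr, slow_step_size, sync_period):  # illustrative name
    # Inner ("fast") optimizer: plain SGD.
    fast_optimizer = optax.sgd(lr)
    optimizer = optax.lookahead(fast_optimizer,
                                sync_period=int(sync_period),  # input arrives as a float
                                slow_step_size=slow_step_size)
    # Fast and slow weights start out equal.
    lookahead_params = optax.LookaheadParams(fast=params, slow=params)
    opt_state = optimizer.init(lookahead_params)
    # grads match the fast weights; the wrapper updates the slow branch itself.
    updates, _ = optimizer.update(grads, opt_state, lookahead_params)
    new_params = optax.apply_updates(lookahead_params, updates)
    return new_params.fast
```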
Common pitfalls
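- You must wrap params in LookaheadParams(fast=..., slow=...) before calling init and update; the grads passed to update still have the shape of the fast weights only.
- The return value is new_params.fast (the fast-weight branch), not the whole LookaheadParams pair.
- After a single step with sync_period > 1, the slow weights do not change and the fast weights are just one inner SGD step. With sync_period == 1, the sync already happens on the first step.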
Inputs
- params: 1-D JAX array of initial weights (fast and slow both start equal to it).
- grads: 1-D JAX array of gradients, same shape as params.
- lr: learning rate of the inner SGD optimizer.
- slow_step_size: interpolation step for the slow weights at sync.
- sync_period: number of steps between slow-weight syncs (given as a float, cast to int).
Output
1-D array: the updated fast weights after one optimizer step.
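The one-step output has a closed form. Writing p for params, g for grads, η for lr, α for slow_step_size and k for sync_period (symbols introduced here only for this derivation), the inner SGD step gives p - ηg. Optax performs the sync on the k-th call to update, so only k = 1 syncs on the first step, using the interpolation rule slow ← slow + α(fast - slow) with slow = p:

$$
\text{output} =
\begin{cases}
p - \eta g, & k > 1 \ \text{(no sync yet; slow stays at } p\text{)},\\[4pt]
p + \alpha\big((p - \eta g) - p\big) = p - \alpha\eta g, & k = 1 \ \text{(sync, then fast is reset to the new slow)}.
\end{cases}
$$

For example, with α = 0.5 and k = 1, the step taken is half the length of a plain SGD step.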
Hints
optax
lookahead