Lookahead Optimizer Wrapper

Why this matters

Lookahead (Zhang et al. 2019) is a meta-optimizer that wraps any inner optimizer with a two-tier weight scheme. The paper reports improved convergence and training stability for SGD and Adam on vision and language tasks, and the method is typically used as a drop-in wrapper around an existing optimizer.

How it works

Lookahead maintains two sets of weights:

  • Fast weights: updated every step by the inner optimizer.
  • Slow weights: updated every sync_period steps by interpolating toward the fast weights, slow ← slow + slow_step_size * (fast - slow)

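On non-sync steps, the fast weights move normally and the slow weights stay frozen. On a sync step, the inner update is applied first, then the slow weights are interpolated toward the updated fast weights, and the fast weights are reset to that same interpolated value.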
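The two-tier update can be sketched by hand, without Optax, to make the sync logic explicit. This is a minimal illustration assuming plain SGD as the inner optimizer; the function name lookahead_sgd_step and the step counter are hypothetical and not part of any library.

```python
import jax.numpy as jnp

def lookahead_sgd_step(fast, slow, grads, step, lr, slow_step_size, sync_period):
    """One hand-written Lookahead step around plain SGD (illustrative only).

    `step` is a zero-based step counter, an assumption of this sketch.
    """
    # Inner optimizer: ordinary SGD update on the fast weights.
    fast = fast - lr * grads
    if (step + 1) % sync_period == 0:
        # Sync step: move slow toward the updated fast weights,
        # then reset fast to the interpolated value.
        slow = slow + slow_step_size * (fast - slow)
        fast = slow
    return fast, slow

fast = slow = jnp.array([1.0, 2.0])
grads = jnp.array([0.5, -0.5])
for step in range(2):
    fast, slow = lookahead_sgd_step(
        fast, slow, grads, step, lr=0.1, slow_step_size=0.5, sync_period=2
    )
# After step 0 only fast has moved; after step 1 (the sync step)
# fast and slow are both approximately [0.95, 2.05].
```

With sync_period=2, the first call leaves slow at its initial value, and the second call performs the interpolation and the reset.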

Using it in Optax

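import optax

# Illustrative wrapper name; the arguments are described under Inputs.
def lookahead_step(params, grads, lr, slow_step_size, sync_period):
    # Inner (fast) optimizer: plain SGD.
    fast_optimizer = optax.sgd(lr)
    optimizer = optax.lookahead(fast_optimizer,
                                sync_period=int(sync_period),  # arrives as a float
                                slow_step_size=slow_step_size)
    # Fast and slow weights start out equal.
    lookahead_params = optax.LookaheadParams(fast=params, slow=params)
    opt_state = optimizer.init(lookahead_params)
    # Only one step is taken, so the new optimizer state is discarded.
    updates, _ = optimizer.update(grads, opt_state, lookahead_params)
    new_params = optax.apply_updates(lookahead_params, updates)
    return new_params.fast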

Common pitfalls

  • You must wrap params in LookaheadParams(fast=..., slow=...).
  • The return value is new_params.fast (the fast-weight branch).
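  • With sync_period > 1, a single step from a fresh state is not a sync step, so slow weights do not change and fast weights receive only the plain inner-optimizer update. With sync_period = 1, every step is a sync step.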

Inputs

  • params: 1-D JAX array (initial fast and slow weights, both equal).
  • grads: 1-D JAX array of gradients (same shape).
  • lr: inner SGD learning rate.
  • slow_step_size: interpolation step for slow weights at sync.
  • sync_period: number of steps between slow-weight syncs (passed as a float; cast to int before use).

Output

1-D array: updated fast weights after one optimizer step.

Hints

optax lookahead
