medium primitives

Warmup + Cosine Decay

Why this matters

The warmup-then-cosine-decay schedule is a standard choice for training large language models (GPT-3 and LLaMA, for example, use it). Starting the learning rate (LR) at zero and ramping it up linearly avoids large updates in the first steps and so prevents early instability, while cosine decay smoothly lowers the LR to end_value (0.0 by default) by the end of training.

The two phases

optax.warmup_cosine_decay_schedule composes two sub-schedules:

  1. Warmup (steps 0 to warmup_steps): linear ramp from init_value (0.0 in this problem) to peak_value.
  2. Cosine decay (steps warmup_steps to warmup_steps + decay_steps): cosine annealing from peak_value to end_value (default 0.0).

step = 0                           → init_value (0.0 in this problem)
step = warmup_steps                → peak_value
step = warmup_steps + decay_steps  → end_value (0.0 by default)
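
The two phases can also be written out by hand, which makes the boundary values above easy to check. The following is a minimal plain-Python sketch, not the site's reference solution: the function name warmup_cosine_lr is hypothetical, init_value and end_value are fixed at 0.0, and the handling of negative steps and of decay_steps == 0 is an assumption made only to keep the example safe to run.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, decay_steps):
    """LR at `step`: linear warmup from 0.0 to peak_lr, then cosine decay to 0.0."""
    # The problem statement asks for these three arguments to be cast to int.
    # Assumption: negative steps are treated as step 0.
    step = max(int(step), 0)
    warmup_steps = int(warmup_steps)
    decay_steps = int(decay_steps)

    # Phase 1: linear ramp from 0.0 up to peak_lr over warmup_steps steps.
    if step < warmup_steps:
        return peak_lr * step / warmup_steps

    # Assumption: with no decay phase, the decay counts as already finished.
    if decay_steps <= 0:
        return 0.0

    # Phase 2: cosine decay. Progress is clamped to 1.0, so the LR stays at
    # 0.0 once step passes warmup_steps + decay_steps.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Here decay_steps follows this problem's convention (cosine portion only), so the schedule ends at warmup_steps + decay_steps.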

Common pitfalls

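  • In this problem, decay_steps is the length of the cosine portion only, not the total training length: total steps = warmup_steps + decay_steps. The decay_steps argument of optax.warmup_cosine_decay_schedule itself counts the whole schedule, warmup included, so pass warmup_steps + decay_steps when calling optax.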
  • After warmup_steps + decay_steps the LR clamps at end_value.
  • Cast step, warmup_steps, and decay_steps to int.

Inputs

  • step: scalar (cast to int).
  • peak_lr: peak learning rate (reached at end of warmup).
  • warmup_steps: number of linear warmup steps (cast to int).
  • decay_steps: length of cosine decay phase (cast to int).

Output

Scalar: the LR at step.

Hints

optax schedule warmup cosine
