Warmup-Cosine LR at Step

Why this matters

A constant learning rate is a relic. Nearly every modern training run uses a schedule: a function from step to learning_rate.

The most common shape, used (with variations in warmup length and final value) by GPT-3, LLaMA, and many other recent models, is warmup-cosine:

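  1. Warmup for warmup_steps: a linear ramp from 0 to peak_lr. Keeping early updates small while the parameters and the optimizer's statistics (such as Adam's moment estimates) are still poorly calibrated helps prevent early divergence.
  2. Cosine decay for the remaining total_steps - warmup_steps steps: the rate falls smoothly from peak_lr to 0 (or to end_value).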
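The two phases above can be written out directly. Below is a minimal plain-Python sketch of the formula, assuming a linear ramp from init_value to peak_lr over warmup_steps, then a cosine from peak_lr down to end_value over the remaining total_steps - warmup_steps steps, with total_steps > warmup_steps. The helper name warmup_cosine_reference is hypothetical, not an Optax function.

```python
import math


def warmup_cosine_reference(step, peak_lr, warmup_steps, total_steps,
                            init_value=0.0, end_value=0.0):
    """Plain-Python sketch of linear warmup followed by cosine decay.

    Hypothetical helper for illustration only; assumes total_steps > warmup_steps.
    """
    if step < warmup_steps:
        # Phase 1: linear ramp from init_value to peak_lr.
        frac = step / warmup_steps
        return init_value + frac * (peak_lr - init_value)
    # Phase 2: cosine from peak_lr to end_value over the remaining steps.
    decay_len = total_steps - warmup_steps
    t = min(step - warmup_steps, decay_len)  # clamp so later steps hold end_value
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_len))
    return end_value + (peak_lr - end_value) * cosine


for s in (0, 5, 10, 55, 100, 150):
    print(s, warmup_cosine_reference(s, 1e-3, 10, 100))
```

Clamping t at the decay length is what keeps steps past total_steps at end_value instead of letting the cosine climb back up.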

Visually:

            peak_lr ┤   ╱╲___
                    │  ╱     ╲___
                    │ ╱          ╲___
                  0 ┼╱───────────────╲───► step
                        ↑            ↑
                  warmup_steps  total_steps

In Optax

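Optax’s warmup_cosine_decay_schedule returns a function step → lr:

import optax

peak_lr, warmup_steps, total_steps = 1e-3, 1_000, 10_000  # example values

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,   # total schedule length, warmup included
    end_value=0.0,
)
lr_at_500 = schedule(500)   # call as a function (mid-warmup here)

Important: in Optax, decay_steps is the TOTAL length of the schedule, warmup included, not the length of the cosine phase. If your training is T steps total and you warm up for W, pass decay_steps = T; the cosine then runs for T - W steps, from step W to step T.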

Why are schedules callables?

A schedule isn’t a chunk of state; it’s a pure function, step ↦ lr. That makes it trivially composable: you can build new schedules by combining existing ones (optax.join_schedules, optax.linear_schedule, etc.) without touching any optimizer state.

The schedule plugs directly into optax.adam(schedule) or optax.sgd(schedule) in place of a fixed learning rate: Optax keeps a step counter in the optimizer state and evaluates the schedule at that count on every update.

Worked numbers

With peak_lr = 0.001, warmup_steps = 10, total_steps = 100, and the cosine phase running from step 10 to step 100:

  • step = 0: 0.0 (start of warmup)
  • step = 5: 0.0005 (halfway through warmup)
  • step = 10: 0.001 (peak; warmup just completed)
  • step = 55: ≈ 0.0005 (halfway through cosine)
  • step = 100: 0.0 (end of cosine)

Common pitfalls

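  • decay_steps confused with the cosine length: in warmup_cosine_decay_schedule it is the total schedule length, warmup included. Passing total_steps - warmup_steps makes the cosine finish warmup_steps steps early.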
  • Omitting init_value: it sets where warmup starts and has no default in warmup_cosine_decay_schedule, so it must be passed explicitly (0.0 for a ramp from zero).
  • Calling at step > total_steps: cosine decays to end_value and stays there; not an error, just clamped.
  • Treating the schedule as state: it’s just a function. schedule(7) and schedule(7) always return the same value.

Problem

Build a warmup_cosine_decay_schedule and evaluate it at step:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    end_value=0.0,
)
lr = schedule(step)

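Return it as a 1-D array of shape (1,).

Note that this spec passes decay_steps=total_steps - warmup_steps. Since Optax counts decay_steps from step 0 (warmup included), the cosine phase here ends at step total_steps - warmup_steps rather than total_steps; with the worked-example values, step 55 gives ≈ 0.000402 and the rate reaches 0 at step 90.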

Inputs:

  • step: float (cast to int), the current global step.
  • peak_lr: float, the peak learning rate.
  • warmup_steps: float (cast to int).
  • total_steps: float (cast to int).

Output: 1-D array of shape (1,), i.e. [lr].
