Warmup-Cosine LR at Step
Why this matters
A constant learning rate is a relic. Nearly every modern training run uses
a schedule: a function from step to learning rate.
One of the most common shapes, used with variations by GPT-3, ViT, LLaMA, and many other models since 2018, is warmup-cosine:
- Warmup for warmup_steps (linear ramp from 0 to peak_lr). Stops large early gradients from blowing up the optimizer.
- Cosine decay for the rest (total_steps - warmup_steps). From peak_lr smoothly down to 0 (or end_value).
Visually:
peak |    /\__
     |   /    \___
     |  /         \___
     | /              \___
   0 +----+---------------+-->
          |               |
    warmup_steps     total_steps
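To make the shape concrete, here is a minimal hand-rolled sketch of the two phases in plain Python. It is not the Optax implementation; the function name warmup_cosine_lr and the clamping past total_steps are assumptions made for illustration, and it assumes 0 < warmup_steps < total_steps.

```python
import math


def warmup_cosine_lr(step, peak_lr, warmup_steps, total_steps, end_value=0.0):
    """Illustrative warmup-cosine schedule (hypothetical helper, not Optax code)."""
    if step < warmup_steps:
        # Linear ramp from 0 at step 0 to peak_lr at step warmup_steps.
        return peak_lr * step / warmup_steps
    # Fraction of the cosine phase completed, clamped to 1 past total_steps.
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_lr - end_value) * cosine


for s in (0, 5, 10, 55, 100):
    print(s, warmup_cosine_lr(s, peak_lr=1e-3, warmup_steps=10, total_steps=100))
```

The cosine factor goes from 1 to 0 as progress goes from 0 to 1, so the learning rate moves from peak_lr to end_value over the second phase.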
In Optax
Optax's warmup_cosine_decay_schedule returns a function step → lr:
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    end_value=0.0,
)
lr_at_500 = schedule(500)  # call as a function
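Important: check how decay_steps is counted. The call above passes
decay_steps = T - W for a run of T total steps with W warmup steps, treating it
as the length of the cosine phase. The Optax API documentation, however,
describes decay_steps as the total length of the schedule, warmup included, so
that the cosine phase lasts decay_steps - warmup_steps steps. Under that
definition, pass decay_steps = T if the cosine should reach end_value at step T,
and confirm the behavior against the Optax version you have installed.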
Why are schedules callables?
A schedule isn't a chunk of state; it's a pure function,
step ↦ lr. That makes it trivially composable: you can build new
schedules by combining existing ones (optax.join_schedules,
optax.linear_schedule, etc.) without touching any optimizer state.
The schedule plugs into optax.adam(schedule) or optax.sgd(schedule)
directly: Optax evaluates the schedule at its internal step count on each update.
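The composability point can be sketched without any library at all. The helpers below (linear, cosine, join) are hypothetical stand-ins written for this example, not the Optax functions of similar names; join assumes that each piece sees a step counted from its own starting boundary.

```python
import math


def linear(init_value, end_value, transition_steps):
    """Linear interpolation from init_value to end_value, then constant."""
    def schedule(step):
        frac = min(max(step / transition_steps, 0.0), 1.0)
        return init_value + (end_value - init_value) * frac
    return schedule


def cosine(init_value, decay_steps, end_value=0.0):
    """Half-cosine from init_value to end_value over decay_steps, then constant."""
    def schedule(step):
        frac = min(max(step / decay_steps, 0.0), 1.0)
        factor = 0.5 * (1.0 + math.cos(math.pi * frac))
        return end_value + (init_value - end_value) * factor
    return schedule


def join(schedules, boundaries):
    """Switch to schedules[i] once step reaches boundaries[i - 1]."""
    def schedule(step):
        index = sum(step >= b for b in boundaries)
        # Each piece is evaluated at a step measured from its own start.
        offset = boundaries[index - 1] if index > 0 else 0
        return schedules[index](step - offset)
    return schedule


warmup_cosine = join([linear(0.0, 1e-3, 10), cosine(1e-3, 90)], boundaries=[10])
print(warmup_cosine(5), warmup_cosine(55))
```

No piece holds mutable state, so combining them is just function composition; the step counter lives entirely outside the schedule.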
Worked numbers
With peak_lr = 0.001, warmup_steps = 10, total_steps = 100, and the cosine phase running from step 10 to step 100:
- step = 0: 0.0 (start of warmup)
- step = 5: 0.0005 (halfway through warmup)
- step = 10: 0.001 (peak; warmup just completed)
- step = 55: ≈ 0.0005 (halfway through cosine)
- step = 100: 0.0 (end of cosine)
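These values follow from the usual closed form of a warmup-cosine schedule, written here as a sketch with s for the step, W for warmup_steps, and T for total_steps (the symbols are shorthand introduced for this derivation, and the cosine phase is assumed to span steps W through T):

```latex
\mathrm{lr}(s) =
\begin{cases}
\mathrm{peak\_lr}\cdot\dfrac{s}{W}, & 0 \le s < W,\\[6pt]
\mathrm{end\_value} + (\mathrm{peak\_lr}-\mathrm{end\_value})\cdot
\dfrac{1}{2}\left(1+\cos\!\left(\pi\,\dfrac{s-W}{T-W}\right)\right), & W \le s \le T.
\end{cases}
\qquad
\text{At } s = 55:\quad \dfrac{55-10}{100-10} = \dfrac{1}{2},\quad
\cos\!\left(\dfrac{\pi}{2}\right) = 0,\quad
\mathrm{lr}(55) = 0.001\cdot\dfrac{1}{2} = 0.0005 .
```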
Common pitfalls
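- Misreading decay_steps: confirm whether it counts only the cosine phase or the whole schedule, warmup included. The Optax API documentation describes the latter, which puts the end of the cosine at step decay_steps rather than at warmup_steps + decay_steps.
- Forgetting init_value=0.0: this is where warmup starts. warmup_cosine_decay_schedule takes it as a required argument, so pass it explicitly.
- Calling at step > total_steps: the cosine has already reached end_value and stays there; not an error, just clamped.
- Treating the schedule as state: it's just a function. Calling schedule(7) twice always returns the same value.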
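The first and third pitfalls are easy to check numerically. The sketch below uses a hypothetical helper, lr_at, and compares two possible conventions for decay_steps; which one a given library follows is something to confirm in its documentation.

```python
import math


def lr_at(step, peak_lr, warmup_steps, cosine_steps):
    """Linear warmup, then a half-cosine to 0 lasting cosine_steps steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min((step - warmup_steps) / cosine_steps, 1.0)  # clamp past the end
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


peak_lr, warmup_steps, decay_steps = 1e-3, 10, 90

# Convention A: decay_steps is the length of the cosine phase alone.
end_a = warmup_steps + decay_steps
# Convention B: decay_steps is the total schedule length, warmup included.
end_b = decay_steps

for name, end in (("A", end_a), ("B", end_b)):
    cosine_steps = end - warmup_steps
    print(name, "cosine ends at step", end,
          "| lr(55) =", lr_at(55, peak_lr, warmup_steps, cosine_steps),
          "| lr(150) =", lr_at(150, peak_lr, warmup_steps, cosine_steps))
```

With the same decay_steps argument, the two conventions place the end of the schedule warmup_steps apart, so the learning rate at any step in the cosine phase differs. Past the end, both stay at the final value.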
Problem
Build a warmup_cosine_decay_schedule and evaluate it at step:
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps - warmup_steps,
    end_value=0.0,
)
lr = schedule(step)
Return the result as a 1-D array of shape (1,).
Inputs:
- step: float (cast to int), the current global step.
- peak_lr: float, the peak learning rate.
- warmup_steps: float (cast to int).
- total_steps: float (cast to int).
Output: 1-D array of shape (1,) containing [lr].