NNX LR Schedule
Why this matters
Linear warmup followed by cosine decay is a standard learning-rate schedule for training large language models; GPT-3, LLaMA, and Qwen all use a variant of it. Starting from zero and warming up linearly helps prevent early instability while the optimizer's momentum buffers are still empty; cosine decay then smoothly lowers the LR toward a small final value (0.0 in this problem) by the end of training.
Schedules in JAX are plain step -> lr functions with nothing framework-specific about them: optax.warmup_cosine_decay_schedule works the same whether you're using nnx, Linen, or no framework at all. This problem is intentionally pure JAX/optax to make that point. LR scheduling is one place where nnx makes no difference, because there is no model state to manage.
The two phases
optax.warmup_cosine_decay_schedule composes two sub-schedules:
- Warmup (steps 0 to warmup_steps): linear ramp from init_value (0.0 here) to peak_value.
- Cosine decay (steps warmup_steps to decay_steps): cosine annealing from peak_value to end_value (0.0 here). In optax, decay_steps is the total length of the schedule, warmup included, so this phase lasts decay_steps - warmup_steps steps.

step = 0            → init_value (0.0 here)
step = warmup_steps → peak_value
step = decay_steps  → end_value (0.0 here, end of decay)

After step decay_steps the LR stays clamped at end_value.
Total steps vs decay steps
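In optax, decay_steps is the total length of the schedule, warmup included, not the length of the cosine portion alone: the cosine phase runs for decay_steps - warmup_steps steps and reaches end_value at step decay_steps. If you want the schedule to span total_steps with warmup_steps of warmup, then:
decay_steps = total_steps
This problem instead passes decay_steps = total_steps - warmup_steps, computed from its total_steps input, so its LR reaches end_value at step total_steps - warmup_steps.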
How it’s used in a training loop
Compose it with the optimizer:
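```python
import optax
from flax import nnx

# Assumes model, peak_lr, warmup_steps and total_steps are already defined.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=peak_lr,
    warmup_steps=warmup_steps,
    # optax's decay_steps includes warmup: this ends the decay at total_steps - warmup_steps.
    decay_steps=total_steps - warmup_steps, end_value=0.0,
)
optimizer = nnx.Optimizer(model, optax.sgd(schedule), wrt=nnx.Param)
```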
optax.sgd accepts either a scalar lr or a schedule function — when
given a schedule, it reads its own step counter on every update
and queries the schedule. This problem doesn’t go that far; we just
evaluate the schedule at a single step.
Common pitfalls
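- Misreading decay_steps. In optax it is the total length of the schedule, warmup included. Passing total_steps makes the LR reach end_value exactly at the end of training; passing total_steps - warmup_steps (as this problem does) makes it reach end_value warmup_steps steps early.
- Forgetting to cast step to int. The inputs arrive as floats; cast them so the schedule is evaluated at a whole step, the way an optimizer's integer step counter queries it.
- Treating warmup as part of the cosine phase. Cosine decay starts only after warmup ends; step 0 always returns init_value.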
Problem
Implement lr_at_step(step, peak_lr, warmup_steps, total_steps):
- Build schedule = optax.warmup_cosine_decay_schedule(init_value=0.0, peak_value=peak_lr, warmup_steps=int(warmup_steps), decay_steps=int(total_steps) - int(warmup_steps), end_value=0.0).
- Evaluate it at int(step) and return the result as a 1-D array of shape (1,): jnp.array([float(schedule(int(step)))]).
Pure JAX/optax — no nnx model needed.
Inputs:
- step: float (cast to int).
- peak_lr: float, the peak LR reached at the end of warmup.
- warmup_steps: float (cast to int).
- total_steps: float (cast to int), equal to warmup_steps + decay_steps under this problem's convention.
Output: 1-D (1,) — [lr_at_step].