
NNX LR Schedule

Why this matters

Linear warmup followed by cosine decay is a common default for training large language models; GPT-3 and LLaMA, for example, both use it. Ramping the LR up from zero avoids large, poorly informed updates in the first steps, when the optimizer’s running statistics (momentum, or Adam’s moment estimates) are built from only a handful of gradients. Cosine decay then lowers the LR smoothly toward a small final value (exactly 0.0 in this problem) by the end of training.

Schedules in optax are just functions from a step count to a learning rate. Nothing about them is framework-specific: optax.warmup_cosine_decay_schedule works the same whether you’re using nnx, Linen, or no neural-network library at all. This problem is intentionally pure JAX/optax to make that point: LR scheduling is one place where nnx makes no difference, because a schedule has no model state to manage.

The two phases

optax.warmup_cosine_decay_schedule composes two sub-schedules:

  1. Warmup (steps 0 to warmup_steps): linear ramp from init_value (0.0 here) to peak_value.
  2. Cosine decay (steps warmup_steps to decay_steps): cosine annealing from peak_value to end_value (0.0 here). optax counts decay_steps from step 0, warmup included, so this phase lasts decay_steps - warmup_steps steps.

step=0               → init_value (0.0 here)
step=warmup_steps    → peak_value
step=decay_steps     → end_value (0.0 here; end of decay)

After step decay_steps the LR stays clamped at end_value.

Total steps vs decay steps

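In optax, the decay_steps argument is the total length of the schedule, warmup included, not the length of the cosine portion alone. If you want the schedule to span total_steps training steps with warmup_steps of warmup, pass:

decay_steps = total_steps

The cosine phase then runs from step warmup_steps to step total_steps, i.e. for total_steps - warmup_steps steps.

The Problem below instead passes decay_steps = total_steps - warmup_steps, as its spec requires. Under optax's definition that shortens the cosine phase to total_steps - 2 * warmup_steps steps: the LR reaches end_value at step total_steps - warmup_steps and stays there.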

How it’s used in a training loop

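Pass the schedule to the optimizer in place of a constant learning rate:

import optax
from flax import nnx

# Assumes model (an nnx.Module), peak_lr, warmup_steps and total_steps are defined elsewhere.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,  # optax: total schedule length, warmup included
    end_value=0.0,
)
optimizer = nnx.Optimizer(model, optax.sgd(schedule), wrt=nnx.Param)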

optax.sgd accepts either a scalar learning rate or a schedule. Given a schedule, it keeps its own step count in the optimizer state, evaluates the schedule at that count on every update (starting from 0), and then increments it. This problem doesn’t go that far; it just evaluates the schedule at a single step.

Common pitfalls

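  • Subtracting warmup_steps from decay_steps in a real training run. optax’s decay_steps already includes warmup, so passing total_steps - warmup_steps makes the LR reach end_value warmup_steps steps before training ends. (The Problem’s spec uses that convention, so follow it there.)
  • Forgetting the int casts. The inputs arrive as floats; the spec calls for int(step), int(warmup_steps) and int(total_steps), which truncate toward zero, matching the integer step counter optax keeps during training.
  • Expecting the cosine to start at step 0. Cosine decay begins only after warmup ends, so step 0 always returns init_value (0.0 here), not peak_value.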

Problem

Implement lr_at_step(step, peak_lr, warmup_steps, total_steps):

  1. Build schedule = optax.warmup_cosine_decay_schedule(init_value=0.0, peak_value=peak_lr, warmup_steps=int(warmup_steps), decay_steps=int(total_steps) - int(warmup_steps), end_value=0.0).
  2. Evaluate it at int(step) and return the result as a 1-D array of shape (1,): jnp.array([float(schedule(int(step)))]).

Pure JAX/optax — no nnx model needed.

Inputs:

  • step: float (cast to int).
  • peak_lr: float — the peak LR reached at end of warmup.
  • warmup_steps: float (cast to int).
  • total_steps: float (cast to int) — warmup_steps + decay_steps.

Output: 1-D array of shape (1,) holding the learning rate at step.

