Train Step with Warmup Schedule

Why this matters

Starting training with a large learning rate can destabilise early updates when gradients are noisy. A linear warmup schedule ramps the LR from 0 to peak_lr over warmup_steps steps, then holds or decays. This problem covers the core pattern: evaluate the schedule at a given step, build an optimizer with that LR, and take one SGD step.

The recipe

schedule   = optax.linear_schedule(0.0, peak_lr, int(warmup_steps))
current_lr = float(schedule(int(step)))
optimizer  = optax.sgd(current_lr)
opt_state  = optimizer.init(params)
updates, _ = optimizer.update(grads, opt_state, params)
return optax.apply_updates(params, updates)

For jit-compiled loops, use optax.inject_hyperparams to avoid retracing on every step: the scheduled LR becomes a runtime value carried in the optimizer state rather than a compile-time constant.

Common pitfalls

  • schedule(step) returns a JAX scalar; cast it to float before passing it to optax.sgd to avoid a tracer leak. The cast needs a concrete step value, so it only works outside jit.
  • The schedule is purely time-driven; it does not depend on the gradients.
  • At step=0, linear_schedule(0.0, ...) returns exactly 0.0 → no update.

Inputs

  • params: 1-D JAX array - model parameters.
  • grads: 1-D JAX array - gradients.
  • step: scalar - current training step (0-indexed).
  • peak_lr: scalar - target learning rate at end of warmup.
  • warmup_steps: scalar - number of steps to warm up over.

Output

1-D array - params after one SGD step at the scheduled learning rate.
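For checking a solution by hand, the expected output can also be written in closed form: params - peak_lr * min(step / warmup_steps, 1) * grads. The sketch below is a plain JAX reference under that formula; the name reference_step and the sample numbers are hypothetical, and it assumes warmup_steps > 0.

```python
import jax.numpy as jnp


def reference_step(params, grads, step, peak_lr, warmup_steps):
    # Fraction of warmup completed, clipped to [0, 1]; assumes warmup_steps > 0.
    frac = jnp.clip(step / warmup_steps, 0.0, 1.0)
    # Plain SGD step at the scheduled learning rate peak_lr * frac.
    return params - peak_lr * frac * grads


# Hypothetical inputs: step 3 of a 4-step warmup to peak_lr = 0.2.
params = jnp.array([1.0, 1.0, 1.0])
grads = jnp.array([1.0, 2.0, 3.0])
out = reference_step(params, grads, step=3, peak_lr=0.2, warmup_steps=4)
print(out)
```

Here the scheduled learning rate is 0.2 * 0.75 = 0.15, so the result is params minus 0.15 times grads.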

Hints

optax training warmup schedule
