hard
primitives
Train Step with Warmup Schedule
Why this matters
Starting training with a large learning rate can destabilise early updates
when gradients are noisy. A linear warmup schedule ramps the LR from 0
to peak_lr over warmup_steps steps, then holds or decays. This problem
covers the core pattern: evaluate the schedule at a given step, build an
optimizer with that LR, and take one SGD step.
The recipe
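import optax

def train_step(params, grads, step, peak_lr, warmup_steps):
    # Linear ramp from 0.0 to peak_lr over warmup_steps steps.
    schedule = optax.linear_schedule(0.0, peak_lr, int(warmup_steps))
    # Evaluate the schedule eagerly and convert to a Python float.
    current_lr = float(schedule(int(step)))
    optimizer = optax.sgd(current_lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates)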
For jit-compiled loops, use optax.inject_hyperparams to avoid retracing
on every step (the scheduled LR becomes a runtime value, not a constant).
Common pitfalls
- schedule(step) returns a JAX scalar; cast to float before passing to optax.sgd to avoid a tracer leak.
- The schedule is purely time-driven; it does not depend on the gradients.
- At step=0, linear_schedule(0.0, ...) returns exactly 0.0, so no update is applied.
Inputs
- params: 1-D JAX array of model parameters.
- grads: 1-D JAX array of gradients.
- step: scalar, the current training step (0-indexed).
- peak_lr: scalar, the target learning rate at the end of warmup.
- warmup_steps: scalar, the number of steps to warm up over.
Output
1-D array: params after one SGD step at the scheduled learning rate.
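For checking a solution by hand, the expected output can be written in closed form: a linearly warmed-up LR followed by plain SGD, params - lr * grads. The helper name reference_step and the sample values are hypothetical, and the sketch assumes warmup_steps > 0.

```python
import jax.numpy as jnp

def reference_step(params, grads, step, peak_lr, warmup_steps):
    """Closed-form reference: linear warmup LR, then params - lr * grads."""
    # Fraction of the warmup completed, clipped to [0, 1].
    # Assumes warmup_steps > 0 (hypothetical precondition for this sketch).
    progress = min(max(int(step), 0), int(warmup_steps)) / int(warmup_steps)
    lr = float(peak_lr) * progress
    return params - lr * grads

# Sample inputs chosen for illustration.
params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.5, -1.0, 2.0])

at_start = reference_step(params, grads, step=0, peak_lr=0.1, warmup_steps=100)
halfway = reference_step(params, grads, step=50, peak_lr=0.1, warmup_steps=100)
after = reference_step(params, grads, step=500, peak_lr=0.1, warmup_steps=100)
print(at_start, halfway, after)
```

At step 0 the parameters come back unchanged, halfway through warmup the step uses half of peak_lr, and past the warmup window it uses the full peak_lr.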
Hints
optax
training
warmup
schedule