easy primitives

Linear Schedule

Why this matters

optax.linear_schedule(init_value, end_value, transition_steps) linearly interpolates from init_value to end_value over transition_steps, then clamps at end_value forever after. It is the building block for linear warmup (ramp from 0 to peak_lr) and linear decay (ramp from peak_lr down to 0).

The schedule behaviour

For step t and T = transition_steps:

t < T  →  init_value + (end_value - init_value) * t / T
t >= T →  end_value  (clamped)

Example (init=1.0, end=0.0, T=100):

step   LR
0      1.0
50     0.5
100    0.0
150    0.0
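
The piecewise rule and the example table can be checked with a few lines of plain Python. This is a minimal sketch of the formula as stated here, not optax's own implementation. The function name linear_schedule_value is hypothetical, and the handling of a non-positive transition_steps and of negative steps is an assumption made only to keep the sketch total.

```python
def linear_schedule_value(step, init_value, end_value, transition_steps):
    """Pure-Python sketch of the piecewise linear rule (hypothetical helper)."""
    t = int(step)
    T = int(transition_steps)
    if T <= 0:
        # Assumption: with no transition window, stay at init_value.
        return float(init_value)
    if t >= T:
        # Past the transition window: clamp at end_value.
        return float(end_value)
    t = max(t, 0)  # Assumption: negative steps are treated as step 0.
    return init_value + (end_value - init_value) * t / T


# Evaluate the schedule at the four steps used in the example table.
for step in (0, 50, 100, 150):
    print(step, linear_schedule_value(step, 1.0, 0.0, 100))
```

Running it prints the same four rows as the table. The explicit clamp branch returns end_value itself rather than recomputing it, which avoids floating-point round-off at t = T.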

Common pitfalls

  • After transition_steps the value is clamped at end_value; it does not wrap around or reset.
  • For a linear warmup, set init_value=0.0 and end_value=peak_lr.
  • Cast both step and transition_steps to int.
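
As an illustration of the warmup and decay use cases, the sketch below chains two linear ramps in plain Python: 0 up to peak_lr, then peak_lr down to 0. The names linear and warmup_then_decay, and the parameters warmup_steps and decay_steps, are hypothetical and chosen for this example only. It assumes both step counts are positive and does not use optax.

```python
def linear(step, init_value, end_value, transition_steps):
    """Linear interpolation with clamping at both ends (hypothetical helper)."""
    t = max(int(step), 0)
    T = int(transition_steps)  # Assumption: T > 0.
    if t >= T:
        return float(end_value)  # Clamped: no wrap-around, no reset.
    return init_value + (end_value - init_value) * t / T


def warmup_then_decay(step, peak_lr, warmup_steps, decay_steps):
    """Ramp 0 -> peak_lr over warmup_steps, then peak_lr -> 0 over decay_steps."""
    step = int(step)
    if step < warmup_steps:
        return linear(step, 0.0, peak_lr, warmup_steps)
    # Shift the step so the decay ramp starts counting from 0.
    return linear(step - warmup_steps, peak_lr, 0.0, decay_steps)


# Sample the combined schedule at a few steps.
for step in (0, 5, 10, 55, 100, 500):
    print(step, warmup_then_decay(step, peak_lr=0.1, warmup_steps=10, decay_steps=90))
```

Because the decay ramp clamps at its end_value, every step past warmup_steps + decay_steps returns 0.0 instead of restarting the warmup.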

Inputs

  • step: scalar (cast to int).
  • init_value: starting LR.
  • end_value: final LR (clamped after transition_steps).
  • transition_steps: number of steps over which to interpolate (cast to int).

Output

Scalar: the LR at step.

Hints

optax schedule linear
