inject_hyperparams for Runtime LR
Why this matters
When a jit-compiled training step closes over an optimizer built with a
fixed Python-float learning rate, that value is baked in as a compile-time
constant, so changing the LR forces JAX to retrace and recompile the step.
optax.inject_hyperparams avoids this by storing the optimizer's numeric
hyperparameters (such as learning_rate) in the optimizer state as arrays.
They become runtime values that can change from one step to the next
without retracing.
The recipe
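import optax

def adam_step(params, grads, lr):  # function name is illustrative
    # inject_hyperparams stores learning_rate in opt_state as an array.
    optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates)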
For a single step, the result matches plain optax.adam(lr) up to
floating-point rounding. The difference shows up when you use different lr
values across steps inside a jitted function: no recompile is needed.
Common pitfalls
- The hyperparameter name must match the optimizer constructor's keyword
  argument: use learning_rate=lr, not lr=lr (the latter raises a TypeError).
- inject_hyperparams works with any optax optimizer constructor, not just
  Adam.
- The wrapped state stores the hyperparams (in opt_state.hyperparams), so
  they can be updated in place between steps.
Inputs
- params: 1-D JAX array, the model parameters.
- grads: 1-D JAX array, the gradients.
- lr: scalar, the learning rate passed at runtime.
Output
1-D array: params after one Adam step at the given lr.