medium primitives

inject_hyperparams for Runtime LR

Why this matters

When you jit-compile a training step whose learning rate is captured as a Python constant, the value is baked into the compiled program, so changing it means retracing and recompiling. optax.inject_hyperparams avoids this by storing the selected hyperparameters in the optimizer state as arrays, which lets them change from step to step without retracing.

The recipe

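import optax

def adam_step(params, grads, lr):  # function name is illustrative
    # Wrap the Adam constructor so learning_rate is stored in the optimizer state.
    optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates)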

For a single step, the result matches plain optax.adam(lr). The difference shows up across steps: the learning rate lives in opt_state.hyperparams['learning_rate'], so a jitted step function can see a different value on each call without being recompiled.

Common pitfalls

  • The hyperparam name must match the optimizer constructor’s keyword argument: use learning_rate=lr, not lr=lr.
  • inject_hyperparams works with any optimizer constructor, not just Adam.
  • The wrapped state exposes the hyperparameters as a dict, opt_state.hyperparams, so you can overwrite an entry such as opt_state.hyperparams['learning_rate'] between steps.

Inputs

  • params: 1-D JAX array — model parameters.
  • grads: 1-D JAX array — gradients.
  • lr: scalar — learning rate passed at runtime.

Output

1-D array — params after one Adam step at the given lr.

Hints

optax inject-hyperparams
