multi_transform per-Param Group
Why this matters
Real-world training rarely applies a uniform learning rate to every
parameter. Fine-tuning LLMs typically uses a 10–100× smaller LR for
pre-trained layers vs. freshly initialised heads; some layers are frozen
(lr = 0). optax.multi_transform is the clean primitive for this.
The recipe
optax.multi_transform(transforms, param_labels) applies a different
transform to each parameter group identified by a matching label.
For a 1-D array split into two halves, the idiomatic approach is to convert params into a dict pytree so each half gets its own label:
half = params.shape[0] // 2  # params has even length, so both halves are equal
params_split = {'a': params[:half], 'b': params[half:]}
grads_split = {'a': grads[:half], 'b': grads[half:]}
optimizer = optax.multi_transform(
    {'a': optax.sgd(lr_a), 'b': optax.sgd(lr_b)},  # label -> transform
    {'a': 'a', 'b': 'b'},                          # same structure as params_split
)
opt_state = optimizer.init(params_split)
updates, _ = optimizer.update(grads_split, opt_state, params_split)
new_params = optax.apply_updates(params_split, updates)
return jnp.concatenate([new_params['a'], new_params['b']])
Common pitfalls
- Labels must form a pytree that matches the params pytree structure.
- Each unique label needs exactly one entry in the transforms dict.
- With a 1-D array as params, multi_transform sees a single leaf, so you need to split params into a dict first.
Inputs
- params: 1-D JAX array of even length.
- grads: 1-D JAX array of gradients (same shape).
- lr_a: learning rate for the first half of params.
- lr_b: learning rate for the second half of params.
Output
1-D array: updated params (first half stepped with lr_a, second half with lr_b).
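Because plain SGD computes new = params - lr * grads, the expected output can be cross-checked without optax. The helper below, reference_update, is a hypothetical name introduced here only as a sanity check; it is not part of the exercise or of optax.

```python
import jax.numpy as jnp


def reference_update(params, grads, lr_a, lr_b):
    """Plain-SGD reference: per-element learning rates, no optimizer state."""
    n = params.shape[0]
    half = n // 2
    # Build a per-element learning-rate vector: lr_a for the first half,
    # lr_b for the second half.
    lrs = jnp.concatenate([jnp.full(half, lr_a), jnp.full(n - half, lr_b)])
    return params - lrs * grads


expected = reference_update(
    jnp.array([1.0, 2.0, 3.0, 4.0]), jnp.ones(4), lr_a=0.1, lr_b=0.01
)
print(expected)  # approximately [0.9, 1.9, 2.99, 3.99]
```

Comparing a multi_transform-based solution against this reference with jnp.allclose is a quick way to confirm that each half received the intended learning rate.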