hard primitives

multi_transform per-Param Group

Why this matters

Real-world training rarely applies a single learning rate to every parameter. When fine-tuning LLMs, pre-trained layers often get a learning rate 10–100× smaller than freshly initialised heads, and some layers are frozen outright (lr = 0). optax.multi_transform is the natural primitive for expressing this.

The recipe

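optax.multi_transform(transforms, param_labels) takes a dict mapping each label to a gradient transformation, plus param_labels: a pytree of labels with the same structure as params (or a function that builds such a pytree from params). Each parameter leaf is updated by the transform whose key matches its label.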

For a 1-D array split into two halves, the idiomatic approach is to convert params into a dict pytree so each half gets its own label:

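import jax.numpy as jnp
import optax

def update_halves(params, grads, lr_a, lr_b):
    half = params.shape[0] // 2  # params has even length
    # Split the flat arrays into dict pytrees: one leaf (and one label) per half.
    params_split = {'a': params[:half], 'b': params[half:]}
    grads_split  = {'a': grads[:half],  'b': grads[half:]}
    # One SGD transform per label; the label dict mirrors params_split.
    optimizer = optax.multi_transform(
        {'a': optax.sgd(lr_a), 'b': optax.sgd(lr_b)},
        {'a': 'a', 'b': 'b'},
    )
    opt_state = optimizer.init(params_split)
    updates, _ = optimizer.update(grads_split, opt_state, params_split)
    new_params = optax.apply_updates(params_split, updates)
    # Reassemble the 1-D result in the original order.
    return jnp.concatenate([new_params['a'], new_params['b']])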

Common pitfalls

  • Labels must form a pytree that matches the params pytree structure.
  • Every label that appears in param_labels must be a key in the transforms dict.
  • With a 1-D array as params, multi_transform sees a single leaf and can assign it only one label, so split params into a dict first.
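The last pitfall is easy to confirm directly: JAX treats a bare array as a single pytree leaf, while the dict version has one leaf per half, and a label dict of strings has the same structure. A small sketch, with arbitrary array values:

```python
import jax
import jax.numpy as jnp

params = jnp.arange(6.0)
n_leaves_flat = len(jax.tree_util.tree_leaves(params))  # 1: whole array is one leaf

half = params.shape[0] // 2
params_split = {'a': params[:half], 'b': params[half:]}
n_leaves_split = len(jax.tree_util.tree_leaves(params_split))  # 2: one per half

# Strings are pytree leaves, so this label dict mirrors params_split.
labels = {'a': 'a', 'b': 'b'}
same_structure = (jax.tree_util.tree_structure(labels)
                  == jax.tree_util.tree_structure(params_split))
```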

Inputs

  • params: 1-D JAX array of even length.
  • grads: 1-D JAX array of gradients (same shape as params).
  • lr_a: learning rate for the first half of params.
  • lr_b: learning rate for the second half of params.

Output

1-D array: the updated params (first half stepped with lr_a, second half with lr_b).
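Since both groups use plain SGD without momentum, the expected output can be written in closed form, with n the (even) length of params, p the params and g the gradients:

```latex
p'_i =
\begin{cases}
  p_i - \mathrm{lr}_a \, g_i, & 0 \le i < n/2, \\
  p_i - \mathrm{lr}_b \, g_i, & n/2 \le i < n.
\end{cases}
```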

Hints

optax multi-transform masking
