Per-Param Learning Rate Multipliers
Why this matters
A single learning rate is rarely optimal for the whole model. The most common case where this bites:
Fine-tuning: the pretrained trunk wants a small LR (so it doesn't forget what it already learned), while the freshly initialized head wants a large LR (so it catches up fast). A 10× higher LR on the head is a typical ratio.
Other cases:
- Embeddings vs. transformer blocks: embeddings often need lower LR to avoid catastrophic drift.
- Bias / LayerNorm γ: some recipes use 2× LR on biases since they’re 1-D and have less “inertia” than full matrices.
- Decoder-only LLMs: the LM head is often tied to the embeddings (shared weights); when it is not tied, it often gets its own LR.
PyTorch handles this via parameter groups in the optimizer:
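import torch

# trunk and head are nn.Modules; each dict is one parameter group with its own LR.
optimizer = torch.optim.SGD([
    {"params": trunk.parameters(), "lr": 1e-5},
    {"params": head.parameters(), "lr": 1e-3},
])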
Flax delegates optimization to Optax, which handles this with optax.multi_transform.
How optax.multi_transform works
optax.multi_transform(
    transforms={"label_a": tx_a, "label_b": tx_b, ...},
    param_labels=labels_pytree,
)
- transforms: a dict mapping each label string to a GradientTransformation.
- param_labels: a pytree with the same structure as params, where each leaf is a label string naming the transform to apply to that leaf.
Each leaf gets routed to its labeled transform; updates are applied
independently per group. So optax.multi_transform({"kernel": optax.sgd(0.01), "bias": optax.sgd(0.1)}, labels) runs two separate SGD optimizers,
one for kernels (lr=0.01) and one for biases (lr=0.1).
Building the labels pytree
jax.tree_util.tree_map_with_path is the canonical tool. For the
“kernel vs bias” split:
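import jax

# Label each leaf "kernel" if any key on its path is "kernel", else "bias".
labels = jax.tree_util.tree_map_with_path(
    lambda path, _v: "kernel" if any(getattr(k, "key", None) == "kernel" for k in path) else "bias",
    params,
)
# For params = {"kernel": ..., "bias": ...}, labels is {"kernel": "kernel", "bias": "bias"}.
# For Flax's model.init output, labels mirrors it: {"params": {"kernel": "kernel", "bias": "bias"}}.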
The leaves of labels are strings that match the keys of the transforms dict. When the optimizer is initialized, Optax checks that every label has a matching key in transforms, so a typo raises a ValueError.
When you’d use 3+ groups
A common ViT fine-tuning recipe has three groups:
- Pretrained backbone: lr * 0.1 (slow drift).
- Pretrained classifier head (if one exists): lr * 1.0.
- Newly initialized layers: lr * 10.0.
multi_transform scales naturally: add one entry per group to transforms, and assign the matching label to the right leaves in param_labels.
Pitfalls
- Label typo: if transforms has no entry for some label, init raises a ValueError listing the parameter labels and the transform keys. Compare the two lists to spot the misspelling.
- Forgetting int(steps): doesn't apply here, but it is a common mistake elsewhere (Flax static fields can't be JAX scalars).
- Different label structures across runs: the labels are fixed when you build multi_transform, and the optimizer state created by init follows them. If the labels (or the param tree) change, build and init a fresh optimizer. This is easy if you follow the standard "init the optimizer once" pattern.
Problem
Initialize nn.Dense(4) with PRNGKey(seed) and x. Build an optimizer that runs optax.sgd(lr_kernel) on kernel leaves and optax.sgd(lr_bias) on bias leaves via optax.multi_transform.
Run one MSE training step on (x, y), comparing y.reshape(-1) against model.apply(...).reshape(-1), and return the loss after the step as a 1-D array of shape (1,).
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D).
- y: 1-D, shape (N * 4,), the flattened target matching the Dense(4) output.
- lr_kernel, lr_bias: floats, separate learning rates per group.
Output: 1-D, shape (1,): [loss_after_step].