
Per-Param Learning Rate Multipliers

Why this matters

A single learning rate is rarely optimal for the whole model. The most common case where this bites:

Fine-tuning: the pretrained trunk wants a small LR (so it doesn't forget what it already learned), while the freshly initialized head wants a large LR (so it catches up fast). A 10× higher LR on the head is a typical ratio.

Other cases:

  • Embeddings vs. transformer blocks: embeddings often need lower LR to avoid catastrophic drift.
  • Biases / LayerNorm γ: some recipes give these 1-D parameters a 2× LR, on the argument that they have less "inertia" than full weight matrices.
  • Decoder-only LLMs: the LM head is often tied to the input embeddings (shared weights); when it is untied, it often gets its own LR.

PyTorch handles this via parameter groups in the optimizer:

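import torch
from torch import nn

trunk, head = nn.Linear(16, 16), nn.Linear(16, 4)  # stand-ins for a pretrained trunk and a new head
optimizer = torch.optim.SGD([
    {"params": trunk.parameters(), "lr": 1e-5},  # small LR for pretrained weights
    {"params": head.parameters(),  "lr": 1e-3},  # 100x larger LR for the fresh head
], lr=1e-3)  # default LR for any group that doesn't set its own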

Flax leaves optimizers to optax, where the equivalent is optax.multi_transform.

How optax.multi_transform works

optax.multi_transform(
    transforms={"label_a": tx_a, "label_b": tx_b, ...},
    param_labels=labels_pytree,
)
  • transforms: dict mapping label string to a GradientTransformation.
  • param_labels: a pytree with the same structure as params, where each leaf is a label string saying which transform to apply. It can also be a function that takes params and returns such a pytree.

Each leaf is routed to the transform named by its label, and each transform keeps its own optimizer state and updates only its own leaves. So optax.multi_transform({"kernel": optax.sgd(0.01), "bias": optax.sgd(0.1)}, labels) runs two separate SGD optimizers: one for kernels (lr=0.01) and one for biases (lr=0.1).

Building the labels pytree

jax.tree_util.tree_map_with_path is the usual tool: it calls your function with each leaf's key path as well as the leaf. For the "kernel vs bias" split:

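import jax
import jax.numpy as jnp
import flax.linen as nn
params = nn.Dense(4).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))["params"]
labels = jax.tree_util.tree_map_with_path(
    # "kernel" for any leaf with a "kernel" key on its path; everything else falls back to "bias".
    lambda path, _v: "kernel" if any(getattr(k, "key", None) == "kernel" for k in path) else "bias",
    params,
)
# labels is now: {"kernel": "kernel", "bias": "bias"}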

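The leaves of labels are strings matching the keys of the transforms dict. When the optimizer is initialized, optax checks that every label has a matching transform, so a typo raises a ValueError. Note that the lambda above sends every non-kernel leaf to "bias", which is only correct when the params contain nothing but kernels and biases.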

When you’d use 3+ groups

A common ViT fine-tuning recipe has THREE groups:

  1. Pretrained backbone: lr * 0.1 (slow drift).
  2. Pretrained classifier head (if one exists): lr * 1.0.
  3. Newly-initialized layers: lr * 10.0.

multi_transform scales naturally: add one entry per group to transforms and use the new label names in param_labels.

Pitfalls

  • Label typo: if a label has no matching key in transforms, tx.init raises a ValueError. Check the spelling on both sides.
  • Forgetting int(steps): not an issue here, but a common mistake elsewhere (Flax static fields must be hashable Python values, not JAX scalars).
  • Labels that change between runs: the labels are fixed when you build the transform, and tx.init builds one inner optimizer state per label. If the labels (or the param structure) change, build a fresh optimizer and re-run init. (Easy if you follow the standard "init the optimizer once" pattern.)

Problem

Initialize nn.Dense(4) with PRNGKey(seed), using x as the sample input. Build an optimizer that runs optax.sgd(lr_kernel) on kernel leaves and optax.sgd(lr_bias) on bias leaves via optax.multi_transform.

Run one MSE training step on (x, y), comparing model.apply(...).reshape(-1) against y.reshape(-1). Then recompute the loss with the updated params and return it as [final_loss], a 1-D array of shape (1,).

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N * 4,) — flattened target matching the Dense(4) output.
  • lr_kernel, lr_bias: floats — separate learning rates per group.

Output: 1-D array of shape (1,) containing [loss_after_step].
