hard primitives

Per-Param Weight Decay Mask

Why this matters

Weight decay (closely related to L2 regularization) is one of the cheapest and most effective regularizers, but applying it to every parameter is usually a mistake. The standard rule, used in most modern transformer recipes (BERT, GPT, ViT, LLaMA):

Apply weight decay to matrix-shaped params (kernel, embedding, attention). Do NOT apply to bias, LayerNorm γ, LayerNorm β, RMSNorm γ, or any other 1-D parameter.

Why? A few reasons:

  • Biases: a bias only shifts a unit's output, so pulling it toward zero does little to limit capacity, and zero is rarely the optimal bias anyway.
  • LayerNorm/RMSNorm γ: these scales set feature magnitudes, so decaying them pushes the normalized features toward zero. Several papers report training instability when WD is applied here.
  • Embeddings: usually decayed (large parameter count, easy to overfit).
  • Output Linear bias: usually NOT decayed, for the same reason as other biases.
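The matrix-vs-1-D rule can also be turned into a mask directly from parameter shapes, without matching on names. A minimal sketch, assuming params is a plain nested dict (as Flax's init returns); the parameter names in the example tree are illustrative, not taken from any particular model.

```python
import jax
import jax.numpy as jnp


def decay_mask_by_ndim(params):
    """True for leaves with ndim >= 2 (decay them), False for 1-D/scalar leaves."""
    return jax.tree_util.tree_map(lambda p: p.ndim >= 2, params)


# Illustrative parameter tree shaped like a tiny transformer block.
params = {
    "embed": {"embedding": jnp.zeros((100, 16))},
    "attn": {"kernel": jnp.zeros((16, 16)), "bias": jnp.zeros((16,))},
    "ln": {"scale": jnp.ones((16,)), "bias": jnp.zeros((16,))},
}

mask = decay_mask_by_ndim(params)
print(mask)
```

The shape rule decays the embedding table and the attention kernel and skips every bias and norm scale. It needs no knowledge of module naming, but it will also decay any 2-D parameter that is not a weight matrix, so it is worth checking against your own model.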

Empirical effect: removing WD from biases/norms typically gains 0.1-0.3 % on downstream eval and slightly stabilizes early training. Small but free.

The Optax tool: optax.masked

optax.masked(transformation, mask) applies transformation only to the leaves where mask is True and passes the updates for all other leaves through unchanged. The mask is a pytree of booleans with the same structure as your params (or a callable that returns one):

import jax

# params: the "params" collection of an nn.Dense, i.e. variables["params"]
mask = jax.tree_util.tree_map_with_path(
    lambda path, _v: any(getattr(k, "key", None) == "kernel" for k in path),
    params,
)
# mask is now: {"kernel": True, "bias": False}

Then chain it into your optimizer:

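import optax

# Masked decay first, then Adam: only kernel leaves get wd * p added to their gradient.
tx = optax.chain(
    optax.masked(optax.add_decayed_weights(wd), mask=mask),
    optax.adam(lr),
)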

Why “decoupled” weight decay specifically

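Notice where the decay sits in the chain. optax.add_decayed_weights(wd) adds wd * p to the incoming updates, so its effect depends on what runs after it (optax.adam itself takes no weight_decay argument). There's a real distinction:

  • Coupled WD (L2 reg in the loss): loss += (wd / 2) * ||p||^2, whose gradient is wd * p. Adam then divides that term by its per-parameter second-moment estimate, so the effective decay is non-uniform across params. Placing add_decayed_weights before optax.adam, as in the chain above, is exactly this case: the wd * p term passes through Adam's preconditioner.
  • Decoupled WD (AdamW): the decay is applied as a separate step, p -= lr * wd * p, independent of Adam's preconditioner, so it is uniform across params. In Optax this means putting add_decayed_weights after optax.scale_by_adam and before the learning-rate scaling.

AdamW (decoupled) is the default in modern recipes. optax.adamw builds that decoupled chain and accepts a mask argument (a pytree of booleans or a callable), so optax.adamw(lr, weight_decay=wd, mask=mask) gives masked AdamW directly.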

tree_map_with_path quick reference

jax.tree_util.tree_map_with_path(fn, tree)

Calls fn(path, leaf) for every leaf. path is a tuple of key entries such as DictKey(key="...") for dict keys, SequenceKey(idx=...) for list/tuple positions, and GetAttrKey(name=...) for attributes. To check whether the path mentions a particular dict key:

any(getattr(k, "key", None) == "kernel" for k in path)

getattr(k, "key", None) returns None for entries without a key attribute (such as SequenceKey or GetAttrKey), so the comparison is simply False for them and any() moves on.
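A small nested tree with both dict keys and a list shows the path entries in action. A minimal sketch; has_key is a hypothetical helper name, and the tree contents are illustrative.

```python
import jax

tree = {
    "Dense_0": {"kernel": 1.0, "bias": 2.0},
    "blocks": [{"kernel": 3.0, "scale": 4.0}],
}


def has_key(path, name):
    """Hypothetical helper: True if any DictKey in path equals name."""
    # DictKey entries carry .key; SequenceKey entries carry .idx and no .key.
    return any(getattr(k, "key", None) == name for k in path)


# Print each leaf's path in readable form, e.g. ['blocks'][0]['kernel'].
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf)

mask = jax.tree_util.tree_map_with_path(lambda path, _v: has_key(path, "kernel"), tree)
print(mask)
```

The list index in "blocks" contributes a SequenceKey, which the getattr default skips, so the kernel inside the list is still matched.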

Problem

Init nn.Dense(4) with PRNGKey(seed) and x. Build an optimizer that applies weight decay via optax.add_decayed_weights(wd) only to the kernel, not the bias, then runs optax.adam(lr).

Run one training step on (x, y) with MSE loss, where the target is y.reshape(-1) and the prediction is model.apply(...).reshape(-1). Then recompute the MSE on the same (x, y) with the updated params and return it as [final_loss].

Inputs:

  • seed: float (cast to int).
  • x: 2-D (N, D).
  • y: 1-D (N * 4,) — flattened target matching the Dense(4) output.
  • lr, wd: floats.

Output: 1-D array of shape (1,): [loss_after_step].

Hints

flax optax regularization
