Per-Param Weight Decay Mask
Why this matters
Weight decay (closely related to L2 regularization) is one of the cheapest, most effective regularizers, but applying it to every parameter is a mistake. The textbook rule, used in most modern transformer recipes (e.g. BERT, GPT, ViT, LLaMA):
Apply weight decay to matrix-shaped params (kernel, embedding, attention). Do NOT apply it to biases, LayerNorm γ, LayerNorm β, RMSNorm γ, or any other 1-D parameter.
Why? A few reasons:
- Biases: pulling a bias toward zero only shifts the layer's output distribution. It doesn't reduce capacity in any meaningful sense, and a zero bias is rarely the optimum.
- LayerNorm/RMSNorm γ: these scales control feature magnitudes; pushing them toward zero shrinks the normalized features toward zero. Training instability has been reported when WD is applied here.
- Embeddings: usually decayed (large param count, easy to overfit).
- Output Linear bias: usually NOT decayed, for the same reason as other biases.
Empirical effect: excluding biases and norm parameters from WD typically gains 0.1-0.3% on downstream evals and slightly stabilizes early training. Small, but free.
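The "matrix-shaped vs. 1-D" rule can also be encoded by shape instead of by parameter name. Below is a minimal sketch, assuming params is a nested dict of JAX arrays; the helper decay_mask_by_ndim and the toy parameter names are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def decay_mask_by_ndim(params):
    # True for leaves with ndim >= 2 (kernels, embeddings, attention projections);
    # False for 1-D (or scalar) leaves such as biases and norm scales/offsets.
    return jax.tree_util.tree_map(lambda p: p.ndim >= 2, params)

# Hypothetical toy params: a Dense layer, a LayerNorm, and an embedding table.
params = {
    "dense": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
    "norm": {"scale": jnp.ones((4,)), "bias": jnp.zeros((4,))},
    "embed": {"embedding": jnp.ones((10, 4))},
}
mask = decay_mask_by_ndim(params)
print(mask)
```

A shape-based mask needs no knowledge of layer naming, but it decays every tensor with two or more dimensions, including embeddings and conv kernels; check that this matches your model before relying on it.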
The Optax tool: optax.masked
optax.masked(transformation, mask) applies transformation only to
the leaves where mask is True; updates for all other leaves pass
through unchanged. The mask is a pytree of booleans with the same
structure as your params (or a callable that builds one from them):
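import jax

# True for every leaf whose path contains a dict key named "kernel".
mask = jax.tree_util.tree_map_with_path(
    lambda path, _v: any(getattr(k, "key", None) == "kernel" for k in path),
    params,
)
# For a single nn.Dense layer's params, mask is {"kernel": True, "bias": False}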
Then chain it into your optimizer:
import optax

tx = optax.chain(
    optax.masked(optax.add_decayed_weights(wd), mask=mask),  # decay kernels only
    optax.adam(lr),
)
Coupled vs. decoupled weight decay
optax.add_decayed_weights(wd) adds wd * p to the incoming updates, and
where it sits in the chain decides which kind of weight decay you get.
There's a real distinction:
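- Coupled WD (L2 regularization in the loss): loss += wd * ||p||^2. Adam's adaptive scaling then rescales the L2 gradient differently per param, so the effective decay is non-uniform. Placing add_decayed_weights before optax.adam, as in the chain above, behaves the same way: the wd * p term goes through Adam's preconditioner (equivalent to an L2 penalty of (wd / 2) * ||p||^2).
- Decoupled WD (AdamW): the decay is applied as a separate step, p -= lr * wd * p, outside Adam's preconditioner, so it is uniform across params. In Optax this means placing add_decayed_weights after optax.scale_by_adam and before the learning-rate scaling.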
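Written out for a parameter p at step t (with λ ≡ wd, η ≡ lr, and bias-corrected Adam moments m̂_t = m_t / (1 − β₁ᵗ), v̂_t = v_t / (1 − β₂ᵗ)), the two variants differ only in where the λp term enters. The coupled form is shown for an L2 penalty (λ/2)‖p‖², whose gradient λp is exactly what add_decayed_weights adds:

```latex
\begin{aligned}
\text{Coupled (Adam + L2):}\quad
  & g_t = \nabla_p \mathcal{L}(p_{t-1}) + \lambda\, p_{t-1}, \\
  & m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \quad
    v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
  & p_t = p_{t-1} - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}. \\[6pt]
\text{Decoupled (AdamW):}\quad
  & g_t = \nabla_p \mathcal{L}(p_{t-1}), \quad m_t,\ v_t \text{ as above}, \\
  & p_t = p_{t-1} - \eta \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, p_{t-1} \right).
\end{aligned}
```

In the coupled form λp is divided by √v̂_t like any other gradient, so parameters with a large gradient history are decayed less; in the decoupled form every decayed parameter is shrunk by the same factor (1 − ηλ) on top of the Adam step.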
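AdamW (decoupled) is the default in modern recipes. optax.adamw
is a convenience wrapper for the decoupled chain (scale_by_adam, then
add_decayed_weights, then scaling by -lr) and takes the same mask
through its mask= argument.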
tree_map_with_path quick reference
jax.tree_util.tree_map_with_path(fn, tree)
Calls fn(path, leaf) for every leaf. path is a tuple of
DictKey(key="...") / SequenceKey(idx=...) / etc. To check
whether the path mentions a particular dict key:
any(getattr(k, "key", None) == "kernel" for k in path)
getattr(k, "key", None) returns None for path entries that have no
key attribute (e.g. SequenceKey), so the comparison is simply False.
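Here is a small sketch of that path check on a hypothetical nested params tree that mixes dict keys and a list index (the layer names are made up, loosely mimicking a Flax variables dict). The helper has_key is hypothetical; jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr are used only to print the paths.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested params: dict keys plus a list of sub-layers.
params = {
    "params": {
        "Dense_0": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
        "blocks": [{"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}],
    }
}

def has_key(path, name):
    # DictKey entries expose .key; SequenceKey (a list index) has no .key,
    # so getattr returns None and the comparison is False.
    return any(getattr(k, "key", None) == name for k in path)

mask = jax.tree_util.tree_map_with_path(lambda path, _v: has_key(path, "kernel"), params)

# Inspect the paths themselves; keystr renders a key path as a readable string.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf.shape)
```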
Problem
Initialize nn.Dense(4) with PRNGKey(seed) and x. Build an optimizer
that applies weight decay via optax.add_decayed_weights(wd) only to
the kernel, not the bias, then runs optax.adam(lr). Run one training
step on (x, y) with MSE loss (flatten both sides: compare
model.apply(...).reshape(-1) against y.reshape(-1)), then return the
loss evaluated with the updated params as [loss_after_step].
Inputs:
- seed: float (cast to int).
- x: 2-D, shape (N, D).
- y: 1-D, shape (N * 4,), the flattened target matching the Dense(4) output.
- lr, wd: floats.
Output: [loss_after_step], a 1-D array of shape (1,).