medium
primitives
Adam + Weight Decay (Chain)
Why this matters
Weight decay, often implemented as L2 regularization, is one of the most
widely used regularization techniques, and it is commonly paired with
Adam. Composing it yourself via optax.chain also deepens your
understanding of what optax.adamw does (and doesn’t do) under the hood.
The recipe
optax.add_decayed_weights(wd) adds wd * params to the updates it receives.
Placed first in a chain, it modifies the gradient before the optimizer
sees it, which amounts to L2 regularization applied in gradient space:
effective_grad = original_grad + wd * params
Chaining it before Adam gives the gradient-coupled weight decay variant:
optimizer = optax.chain(
    optax.add_decayed_weights(wd),
    optax.adam(lr),
)
Common pitfalls
- This is gradient-coupled weight decay (L2-style), NOT the decoupled weight decay that optax.adamw implements. For true decoupled AdamW use optax.adamw(lr, weight_decay=wd).
- add_decayed_weights modifies grads (adds wd * params), so the order in the chain matters.
Inputs
- params: 1-D JAX array of current parameter values.
- grads: 1-D JAX array of gradients (same shape).
- lr: Adam learning rate.
- wd: weight decay coefficient.
Output
1-D array — updated parameter values.
Hints
optax
weight-decay
chain