
Adam + Weight Decay (Chain)

Why this matters

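Weight decay, a penalty that pulls parameters toward zero, is one of the most widely used regularization techniques, and it is very commonly combined with Adam. With plain SGD, adding wd * params to the gradient (L2 regularization) and shrinking the weights directly (decoupled weight decay) are equivalent, up to rescaling the coefficient by the learning rate. With Adam they are not equivalent, because Adam rescales everything that enters its gradient input. Composing the decay yourself via optax.chain makes this difference concrete and shows what optax.adamw does (and doesn't do) under the hood.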

The recipe

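optax.add_decayed_weights(wd) adds wd * params to whatever updates it receives. As the first link in a chain it receives the raw gradients. The optimizer after it therefore sees the gradient of the loss plus the gradient of an L2 penalty of (wd / 2) * ||params||²: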

effective_grad = original_grad + wd * params

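Chaining it before Adam gives the gradient-coupled variant of weight decay. In this variant the decay term flows through Adam's moment estimates and normalization like any other gradient component: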

import optax

optimizer = optax.chain(
    optax.add_decayed_weights(wd),  # grads -> grads + wd * params
    optax.adam(lr),                 # Adam step on the decayed gradients
)
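Written out, one step t of this chain computes the update below. It uses λ = wd and η = lr, with Adam's β₁, β₂, ε; Optax's defaults are 0.9, 0.999 and 1e-8, and eps_root is assumed to be 0, which is also its default. For contrast, the last rule is the decoupled update of optax.adamw. There the moments are built from g_t alone, and the decay is added after normalization.

```latex
\begin{aligned}
\text{coupled (this chain):}\quad
  & \tilde{g}_t = g_t + \lambda\,\theta_{t-1}, \\
  & m_t = \beta_1 m_{t-1} + (1-\beta_1)\,\tilde{g}_t, \qquad
    v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\tilde{g}_t^{2}, \\
  & \hat{m}_t = \frac{m_t}{1-\beta_1^{t}}, \qquad
    \hat{v}_t = \frac{v_t}{1-\beta_2^{t}}, \\
  & \theta_t = \theta_{t-1} - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}, \\[4pt]
\text{decoupled (optax.adamw):}\quad
  & \theta_t = \theta_{t-1} - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\,\theta_{t-1}\right),
    \quad m_t, v_t \text{ built from } g_t \text{ instead of } \tilde{g}_t.
\end{aligned}
```

On the first step (t = 1, zero-initialized moments), the bias corrections give m̂₁ = g̃₁ and v̂₁ = g̃₁². The coupled update is then -η g̃₁ / (|g̃₁| + ε), which is roughly -η times the sign of g̃₁.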

Common pitfalls

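  • This is gradient-coupled (L2-style) weight decay, NOT the decoupled weight decay that optax.adamw implements. The decay term is divided by Adam's running RMS of the gradients, so parameters with a large gradient history are decayed less than wd suggests. optax.adamw instead adds wd * params after Adam's normalization and then scales by lr, so every parameter shrinks by the same fraction. For true decoupled AdamW, use optax.adamw(lr, weight_decay=wd).
  • add_decayed_weights adds wd * params to whatever updates reach it, so its position in the chain matters. It must come before optax.adam(lr). adam already multiplies its output by -lr, so a decay placed after it is neither scaled by the learning rate nor sign-flipped, and it grows the weights instead of shrinking them.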

Inputs

  • params: 1-D JAX array of current parameter values.
  • grads: 1-D JAX array of gradients (same shape).
  • lr: Adam learning rate.
  • wd: weight decay coefficient.

Output

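1-D array of updated parameter values (same shape as params) after a single optimizer step. No optimizer state is passed in, so the step starts from a freshly initialized state.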

Hints

optax weight-decay chain
