AdamW with Decoupled Weight Decay
Why this matters
AdamW is the standard optimizer for transformer training (e.g. BERT, GPT, LLaMA). The key difference from Adam + L2 is decoupled weight decay: the regularisation term is applied directly to the parameters rather than being added to the gradients before the adaptive step.
Decoupled weight decay
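Adam + L2 regularisation adds `wd * params` to the gradients before
the adaptive scaling, so the decay term is divided by the same
per-parameter `sqrt(v_hat)` as the gradient: weights with large
historical gradients are regularised less than weights with small ones.
AdamW sidesteps this by applying the decay directly to the parameters: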
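```python
# Adam step on the gradients only; adam_update(g) stands for the
# bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps).
params_adam = params - lr * adam_update(grads)
# Decoupled weight decay, applied separately (uses the pre-step params):
params_new = params_adam - lr * wd * params
```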
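When `wd > 0` and `grads = 0` (with no momentum left from earlier steps), the Adam term is zero, yet each
parameter is still multiplied by `(1 - lr * wd)` per step, i.e. it shrinks by `lr * wd * params`. This is useful for pulling unused weights toward zero.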
Typical hyperparameters
| Use case | lr | weight_decay |
|---|---|---|
| BERT fine-tune | 2e-5 | 0.01 |
| GPT pre-train | 3e-4 | 0.1 |
| ViT | 1e-3 | 0.3 |
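Because the decay is scaled by the learning rate, what matters per step is the product `lr * wd`. A rough way to compare the rows above is the "half-life" of a weight that receives no gradient: the number of steps `n` with `(1 - lr * wd) ** n == 0.5`. This sketch assumes a constant learning rate (real runs usually use warmup and decay schedules) and ignores the Adam term entirely.

```python
import math

# (lr, weight_decay) pairs copied from the table above.
presets = {
    "BERT fine-tune": (2e-5, 0.01),
    "GPT pre-train": (3e-4, 0.1),
    "ViT": (1e-3, 0.3),
}

half_lives = {}
for name, (lr, wd) in presets.items():
    factor = 1.0 - lr * wd  # per-step multiplier on a weight whose gradient is zero
    # Solve factor**n == 0.5 for n; log1p keeps precision when lr * wd is tiny.
    half_lives[name] = math.log(2.0) / -math.log1p(-lr * wd)
    print(f"{name:15s} factor={factor:.7f} half-life ~ {half_lives[name]:,.0f} steps")
```

The half-lives come out at roughly 3.5 million, 23 thousand and 2.3 thousand steps, so `weight_decay` values are only comparable across setups once `lr` is taken into account.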
Optax usage
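```python
import jax.numpy as jnp
import optax

lr, wd = 1e-3, 1e-2
params, grads = jnp.array([1.0, -2.0, 3.0]), jnp.array([0.5, 0.1, -0.3])  # example values
optimizer = optax.adamw(lr, weight_decay=wd)
opt_state = optimizer.init(params)
updates, new_state = optimizer.update(grads, opt_state, params)  # params needed for decay
new_params = optax.apply_updates(params, updates)
```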
Common pitfalls
- Adam vs AdamW: using `optax.adam` with a manually added L2 term in the loss is NOT the same as `optax.adamw`, because the gradient-adaptive scaling also rescales the L2 gradient.
- `wd=0.0` gives Adam: a weight decay of zero degenerates to pure Adam.
- Zero grads still decay: setting `grads=0` with `wd>0` multiplies params by `(1 - lr * wd)` per step, i.e. shrinks them by `lr * wd * params`.
Inputs
- `params`: 1-D array of current parameter values.
- `grads`: 1-D array of gradients, same shape as `params`.
- `lr`: scalar learning rate.
- `wd`: scalar weight decay coefficient.
Output
1-D array: updated parameter values after one AdamW step.
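The Inputs/Output specification can be met with a from-scratch function. This sketch is not the problem's reference solution: the name `adamw_step` is hypothetical, it assumes the step starts from a fresh optimizer state (zero moment estimates, step count `t = 1`), and it assumes Optax's defaults `b1=0.9`, `b2=0.999`, `eps=1e-8`. At `t = 1` the bias corrections make `m_hat == grads` and `v_hat == grads**2`, so the Adam part of the first step is roughly `sign(grads)`.

```python
import jax.numpy as jnp

def adamw_step(params, grads, lr, wd, b1=0.9, b2=0.999, eps=1e-8):
    """One AdamW step from a freshly initialised state (m = v = 0, t = 1)."""
    t = 1
    m = (1 - b1) * grads               # first-moment EMA after one step (m0 = 0)
    v = (1 - b2) * grads**2            # second-moment EMA after one step (v0 = 0)
    m_hat = m / (1 - b1**t)            # bias correction; equals grads at t = 1
    v_hat = v / (1 - b2**t)            # equals grads**2 at t = 1
    adam_update = m_hat / (jnp.sqrt(v_hat) + eps)
    # Decoupled decay: uses the pre-step params and is NOT divided by sqrt(v_hat).
    return params - lr * (adam_update + wd * params)
```

For example, `adamw_step(jnp.array([1.0, -2.0]), jnp.array([0.5, 0.0]), lr=0.1, wd=0.1)` gives approximately `[0.89, -1.98]`: the first entry moves by `lr * (1 + wd * 1.0)`, while the second, with a zero gradient, moves only by the decay term `lr * wd * (-2.0)`.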