
AdamW with Decoupled Weight Decay

Why this matters

AdamW is the standard optimizer for most transformer training (e.g. BERT, GPT-style models, LLaMA). Its key difference from Adam with L2 regularisation is decoupled weight decay: the decay term is applied directly to the parameters rather than being added to the gradients before the adaptive step.

Decoupled weight decay

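Adam with L2 regularisation adds wd * params to the gradients before the adaptive scaling. The decay term is therefore divided by the same per-coordinate factor (sqrt(v_hat) + eps) as the gradient, so weights with large gradient history are decayed less and the effective regularisation strength depends on gradient magnitude. AdamW sidesteps this by applying the decay directly to the parameters: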

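# Adam step on the gradients only
# (adam_update(g) = m_hat / (sqrt(v_hat) + eps), from Adam's moment estimates):
params_adam = params - lr * adam_update(grads)

# Decoupled weight decay, applied to the pre-step params and scaled by lr only:
params_new = params_adam - lr * wd * params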

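When wd > 0 and grads = 0 (with Adam's moment estimates also at zero, as on a first step), the Adam term vanishes and each parameter still shrinks by a factor of (1 - lr * wd) per step, i.e. by lr * wd * params. This is useful for pulling unused weights toward zero.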
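The shrink factor compounds across steps. The plain-Python sketch below uses a hypothetical helper, decay_only, and assumes a fresh optimizer state with all-zero gradients, so that only the decoupled decay term is active.

```python
def decay_only(params, lr, wd, steps):
    """Run `steps` AdamW updates with all-zero gradients from a fresh state.

    Hypothetical helper. Adam's moments stay at zero, so its term is
    0 / (sqrt(0) + eps) = 0 and only the decoupled decay acts: each step
    multiplies every weight by (1 - lr * wd).
    """
    out = [float(p) for p in params]
    for _ in range(steps):
        out = [p - lr * wd * p for p in out]
    return out


print(decay_only([1.0, -4.0], lr=0.1, wd=0.1, steps=1))   # about [0.99, -3.96]
print(decay_only([1.0], lr=0.1, wd=0.1, steps=100))       # about [0.366]
```

With lr = wd = 0.1, a weight keeps about 0.99**100 ≈ 0.366 of its value after 100 zero-gradient steps.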

Typical hyperparameters

Use case          lr      weight_decay
BERT fine-tune    2e-5    0.01
GPT pre-train     3e-4    0.1
ViT               1e-3    0.3

Optax usage

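import optax

# params, grads: arrays of the same shape; lr, wd: Python floats.
optimizer = optax.adamw(learning_rate=lr, weight_decay=wd)
opt_state = optimizer.init(params)
# params must be passed to update(): the decay term is computed from them.
updates, new_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)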
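The optax calls above follow an init / update / apply_updates pattern. The NumPy sketch below mimics that contract so the optimizer state is visible; it is not optax's implementation. The function names, the dict-based state, and the defaults b1=0.9, b2=0.999, eps=1e-8 are assumptions. As in optax, the returned updates already carry the minus sign, so applying them is a plain addition.

```python
import numpy as np


def adamw_init(params):
    """Hypothetical: zero first/second moments plus a step counter."""
    zeros = np.zeros_like(np.asarray(params, dtype=np.float64))
    return {"count": 0, "mu": zeros, "nu": zeros.copy()}


def adamw_update(grads, state, params, lr, wd, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical: return (updates, new_state); updates are added to params."""
    grads = np.asarray(grads, dtype=np.float64)
    params = np.asarray(params, dtype=np.float64)
    count = state["count"] + 1
    mu = b1 * state["mu"] + (1 - b1) * grads          # first moment
    nu = b2 * state["nu"] + (1 - b2) * grads ** 2     # second moment
    mu_hat = mu / (1 - b1 ** count)                   # bias corrections
    nu_hat = nu / (1 - b2 ** count)
    # Adam direction plus decoupled decay, then scaled by -lr.
    updates = -lr * (mu_hat / (np.sqrt(nu_hat) + eps) + wd * params)
    return updates, {"count": count, "mu": mu, "nu": nu}


def apply_updates(params, updates):
    """Hypothetical: updates already include the sign, so just add them."""
    return np.asarray(params, dtype=np.float64) + updates


params = np.array([1.0, -2.0])
state = adamw_init(params)
for _ in range(3):
    updates, state = adamw_update(np.array([0.5, 0.0]), state, params,
                                  lr=0.1, wd=0.01)
    params = apply_updates(params, updates)
```

Running three steps from [1.0, -2.0] with constant grads [0.5, 0.0], lr=0.1 and wd=0.01 ends near [0.6973, -1.9940], with state["count"] equal to 3.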

Common pitfalls

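  • Adam vs AdamW: using optax.adam with a manually added L2 term in the loss is NOT the same as optax.adamw. The L2 gradient wd * params passes through Adam's adaptive scaling, so its effect is rescaled per coordinate instead of being a uniform lr * wd shrinkage.
  • wd=0.0 gives Adam: with zero weight decay, AdamW reduces exactly to Adam. Note that optax.adamw defaults to weight_decay=1e-4, so omitting the argument does not give plain Adam.
  • Zero grads still decay: with grads=0 and wd>0 (and zero moment estimates), params shrink by a factor of (1 - lr * wd) per step.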
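To see why the first pitfall matters, the sketch below compares the two rules on Adam's first step from zero state, where the bias-corrected update simplifies to g / (|g| + eps). Only that first step is modelled; the helper names first_adam_direction, l2_step and decoupled_step are hypothetical, and eps=1e-8 is an assumed default.

```python
import numpy as np


def first_adam_direction(g, eps=1e-8):
    """Adam's step-1 update from zero state: m_hat / (sqrt(v_hat) + eps)."""
    return g / (np.abs(g) + eps)


def l2_step(params, grads, lr, wd):
    # L2 in the loss: the penalty gradient wd * params is mixed into grads
    # *before* Adam's per-coordinate normalisation.
    return params - lr * first_adam_direction(grads + wd * params)


def decoupled_step(params, grads, lr, wd):
    # AdamW: Adam sees the raw grads; decay is applied to params separately.
    return params - lr * first_adam_direction(grads) - lr * wd * params


params = np.array([1.0, 1.0])
grads = np.array([10.0, 0.0])   # one busy weight, one idle weight
print(l2_step(params, grads, lr=0.1, wd=0.1))
print(decoupled_step(params, grads, lr=0.1, wd=0.1))
```

Here l2_step gives about [0.9, 0.9] while decoupled_step gives about [0.89, 0.99]. Under L2 the decay is swallowed by the large gradient on the first weight (it moves essentially as it would with wd=0) and is normalised up to a full lr-sized step on the idle second weight. AdamW shrinks both by the same relative amount, lr * wd = 1%. With wd=0 the two functions agree.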

Inputs

  • params: 1-D array of current parameter values.
  • grads: 1-D array of gradients, same shape as params.
  • lr: scalar learning rate.
  • wd: scalar weight decay coefficient.

Output

1-D array: the updated parameter values after one AdamW step from a freshly initialised optimizer state.

Hints

optax adamw weight-decay
