Chain: Clip + SGD
Why this matters
Gradient clipping is a staple of training deep networks — especially RNNs
and transformers — where exploding gradients can destabilize training.
optax.chain lets you compose gradient transforms into a single optimizer,
so clipping and the optimizer step run through one init/update pair that
carries a single combined state.
The recipe
optax.chain(t1, t2, ...) composes gradient transforms left to right.
For clip-then-SGD:
import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm),  # applied first
    optax.sgd(lr),                        # applied second
)
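clip_by_global_norm computes the global L2 norm of the gradients (taken over
every array in the gradient pytree). If that norm exceeds max_norm, every
gradient is multiplied by max_norm / global_norm, so the norm becomes exactly
max_norm; otherwise the gradients pass through unchanged. SGD then turns the
(possibly clipped) grads into updates equal to -lr * clipped_grads, and
optax.apply_updates adds them to the parameters, giving
params - lr * clipped_grads.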
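To make the arithmetic concrete, here is a minimal hand-written sketch of what the chain computes for a single 1-D array, using only jax.numpy. It illustrates the math above and is not Optax's actual source; the function name clip_then_sgd_manual is hypothetical, and because there is only one array, the global norm reduces to that array's L2 norm.

```python
import jax.numpy as jnp


def clip_then_sgd_manual(params, grads, lr, max_norm):
    """Hypothetical hand-written equivalent of chain(clip_by_global_norm, sgd)."""
    # Global L2 norm; with a single array this is just the array's norm.
    g_norm = jnp.sqrt(jnp.sum(grads ** 2))
    # Rescale by max_norm / g_norm only when the norm exceeds max_norm.
    scale = jnp.where(g_norm > max_norm, max_norm / g_norm, 1.0)
    clipped = grads * scale
    # Plain SGD update, then apply it: params + (-lr * clipped).
    updates = -lr * clipped
    return params + updates


params = jnp.array([1.0, 2.0])
grads = jnp.array([3.0, 4.0])  # global norm 5.0
print(clip_then_sgd_manual(params, grads, lr=0.1, max_norm=1.0))  # ≈ [0.94, 1.92]
```

With grads = [3, 4] (norm 5) and max_norm = 1, the gradients are scaled by 1/5 to [0.6, 0.8] before the lr = 0.1 step.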
Common pitfalls
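- Order matters: clip BEFORE the optimizer, not after. Placed after sgd, the
clip would act on the already-scaled updates -lr * grads, so max_norm would no
longer bound the gradient norm.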
- clip_by_global_norm operates on gradients, NOT on parameters.
- Don't forget optax.apply_updates(params, updates) at the end.
Inputs
- params: 1-D JAX array of current parameter values.
- grads: 1-D JAX array of gradients (same shape as params).
- lr: learning rate (float).
- max_norm: maximum allowed global L2 norm for gradients.
Output
1-D array — updated parameter values after one clip + SGD step.