medium primitives

Chain: Clip + SGD

Why this matters

Gradient clipping is a staple of training deep networks, especially RNNs and transformers, where exploding gradients can destabilize training. optax.chain lets you compose gradient transforms into a single optimizer, so clipping and the optimizer step run together in one update call behind a single init/update API.

The recipe

optax.chain(t1, t2, ...) composes gradient transforms left to right. For clip-then-SGD:

optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm),   # applied first
    optax.sgd(lr),                          # applied second
)

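clip_by_global_norm leaves grads untouched when their global L2 norm is already at most max_norm; otherwise it rescales them by max_norm / norm so that the norm equals max_norm. SGD then turns the (possibly clipped) grads into updates of -lr * clipped_grads, and optax.apply_updates adds those updates to the parameters, giving params - lr * clipped_grads.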
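The arithmetic behind one clip + SGD step can be worked by hand. The sketch below uses plain jax.numpy and made-up example values; it is meant to mirror the formula, not optax's internal implementation.

```python
import jax.numpy as jnp

# Example values chosen for illustration only.
params = jnp.array([1.0, 2.0])
grads = jnp.array([3.0, 4.0])   # global L2 norm = 5.0
lr, max_norm = 0.1, 1.0

# Global L2 norm of the gradient vector.
g_norm = jnp.sqrt(jnp.sum(grads ** 2))

# Shrink only when the norm exceeds max_norm; otherwise the scale is 1.
scale = jnp.minimum(1.0, max_norm / g_norm)
clipped = grads * scale          # approximately [0.6, 0.8], norm 1.0

# Plain SGD step on the clipped gradients.
new_params = params - lr * clipped   # approximately [0.94, 1.92]
```

Here the gradient norm (5.0) exceeds max_norm (1.0), so the gradients are scaled by 0.2 before the learning rate is applied.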

Common pitfalls

  • Order matters: clip BEFORE the optimizer, not after.
  • clip_by_global_norm operates on gradients, NOT on parameters.
  • Don’t forget optax.apply_updates(params, updates) at the end: optimizer.update returns updates, not new parameters.

Inputs

  • params: 1-D JAX array of current parameter values.
  • grads: 1-D JAX array of gradients (same shape).
  • lr: learning rate (float).
  • max_norm: maximum allowed global L2 norm for gradients.

Output

1-D array — updated parameter values after one clip + SGD step.

Hints

optax chain clipping
