hard
primitives
Train Step with Global-Norm Clipping
Why this matters
Gradient explosion is a pervasive problem in RNN and transformer training.
optax.clip_by_global_norm rescales the entire gradient pytree so its
L2-norm does not exceed max_norm, then the optimizer applies its update
rule. Chaining these two transforms is the standard defensive pattern used
in BERT, GPT, and most production JAX training loops.
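To make the rescaling concrete, here is a minimal hand-written sketch of global-norm clipping in plain JAX. The helper name `clip_by_global_norm_manual` and the example gradient dictionary are hypothetical and used only for illustration; this is not the optax source, just the same arithmetic as described above.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm_manual(grads, max_norm):
    """Illustrative re-implementation (hypothetical helper, not the optax code)."""
    leaves = jax.tree_util.tree_leaves(grads)
    # One L2-norm over every element of every leaf in the pytree.
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # Shrink only when the norm exceeds max_norm; otherwise scale is 1.
    scale = jnp.minimum(1.0, max_norm / global_norm)
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, global_norm

# Example pytree: squared entries sum to 1 + 4 + 4 = 9, so the global norm is 3.
grads = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([2.0])}
clipped, norm = clip_by_global_norm_manual(grads, max_norm=1.5)
# scale = 1.5 / 3 = 0.5, so w becomes [0.5, 1.0] and b becomes [1.0].
```

Every leaf is multiplied by the same factor, so the direction of the overall gradient is preserved and only its length changes.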
The recipe
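import optax

def train_step(params, grads, lr, max_norm):
    # Clip the raw gradients first, then let SGD scale them by -lr.
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm),
        optax.sgd(lr),
    )
    opt_state = optimizer.init(params)
    # Plain SGD carries no state, so the returned state is discarded.
    updates, _ = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates)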
Common pitfalls
- Clip must come first in chain; placed after SGD, it would clip the learning-rate-scaled updates rather than the raw gradients.
- clip_by_global_norm operates on the global L2-norm of all gradient tensors together, not per-tensor.
- max_norm=1.0 is the conventional value for BERT/GPT-class training.
- When global_norm <= max_norm, the gradients are passed through unchanged.
Inputs
- params: 1-D JAX array, the model parameters.
- grads: 1-D JAX array, the gradients.
- lr: scalar, the SGD learning rate.
- max_norm: scalar, the maximum allowed gradient global norm.
Output
1-D array: params after one clipped-SGD step.
Hints
optax
training
clipping