
Global-Norm Gradient Clipping

Why this matters

Gradient clipping by global L2 norm is a standard stability technique for training RNNs, transformers, and other deep networks prone to exploding gradients. It can also be used standalone, without a full optimizer, when you want to inspect or transform gradients before passing them to a different system.

The formula

Given gradient entries g_i with global L2 norm ||g|| = sqrt(sum_i g_i^2):

  • If ||g|| > max_norm: scale ALL gradients by the same factor max_norm / ||g||, so the result keeps its direction and has norm exactly max_norm.
  • If ||g|| <= max_norm: return grads unchanged.
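The rule above can be written directly in jax.numpy. This is a minimal sketch, not optax's own implementation, and the function name clip_by_global_norm_manual is hypothetical. It uses jnp.where instead of a Python if, so the same code can also run under jax.jit.

```python
import jax.numpy as jnp

def clip_by_global_norm_manual(grads: jnp.ndarray, max_norm: float) -> jnp.ndarray:
    """Hypothetical helper: rescale grads so their L2 norm is at most max_norm."""
    g = jnp.sqrt(jnp.sum(grads ** 2))  # global L2 norm ||g||
    # One shared scale factor: max_norm / ||g|| when clipping is needed, else 1.
    # jnp.where avoids Python branching on a (possibly traced) value.
    scale = jnp.where(g > max_norm, max_norm / g, 1.0)
    return grads * scale

grads = jnp.array([3.0, 4.0])                   # ||g|| = 5
print(clip_by_global_norm_manual(grads, 1.0))   # approximately [0.6, 0.8]
print(clip_by_global_norm_manual(grads, 10.0))  # unchanged: [3., 4.]
```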

optax.clip_by_global_norm(max_norm) implements this as a standalone gradient transform with the familiar init / update interface.

Common pitfalls

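  • Per-element clipping is different: optax.clip(max_delta) clamps each element to [-max_delta, max_delta] individually instead of scaling the whole vector, which can change the gradient's direction.
  • Per-block clipping: optax.clip_by_block_rms(threshold) rescales each block (each leaf of the gradient pytree, e.g. one layer's weight matrix) so that its RMS does not exceed threshold.
  • The transform is stateless: init just returns an empty optax.EmptyState(), and only update does the real work.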

Inputs

  • grads: 1-D JAX array of gradient values.
  • max_norm: maximum allowed global L2 norm (float).

Output

1-D array: the clipped (or unchanged) gradient values.

Hints

optax clipping global-norm
