
zero_nans for NaN-Safe Training

Why this matters

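Mixed-precision training (e.g. bfloat16) occasionally produces NaN gradients. Typically an intermediate value overflows to Inf and a later operation such as Inf - Inf or 0 * Inf turns it into NaN; underflow can do the same, for example when a small value rounds to zero and ends up in a 0 / 0. Rather than letting one NaN element propagate into the parameters or aborting the training step, optax.zero_nans() replaces every NaN gradient element with zero and lets the step proceed. It is a common defensive transform in production training pipelines.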

The recipe

With real NaN gradients, chain zero_nans in front of the optimizer:

```python
import optax

optimizer = optax.chain(optax.zero_nans(), optax.sgd(lr))
```

This problem uses an explicit mask (0.0/1.0) instead of true NaN values because JSON cannot represent NaN. The mask-based logic is:

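```python
import jax.numpy as jnp
import optax

def masked_sgd_step(params, grads, mask, lr):  # hypothetical wrapper name
    cleaned = jnp.where(mask >= 0.5, 0.0, grads)  # zero out simulated "NaN" positions
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)
    updates, _ = optimizer.update(cleaned, opt_state, params)  # SGD: updates = -lr * cleaned
    return optax.apply_updates(params, updates)
```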

Common pitfalls

  • zero_nans zeros NaN gradients but does NOT fix their root cause (e.g. exploding activations). Diagnose the source separately.
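  • zero_nans leaves ±Inf untouched (jnp.isnan(inf) is False). Gradient clipping (optax.clip_by_global_norm, optax.adaptive_grad_clip) keeps large but finite gradients bounded, but a single Inf entry makes the norm Inf, so rescaling cannot turn it back into a finite value. If Inf can occur, filter non-finite values explicitly (e.g. with jnp.isfinite).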
  • The test contract uses a float mask (0.0/1.0) to simulate NaN positions; mask >= 0.5 identifies them.

Inputs

  • params: 1-D JAX array of model parameters.
  • grads: 1-D JAX array of gradients; positions flagged by mask simulate NaN.
  • mask: 1-D JAX array of 0.0/1.0 flags; 1.0 marks a simulated NaN position.
  • lr: scalar SGD learning rate.

Output

1-D array: params after one SGD step, with the gradient at every masked position treated as zero (so those parameters are left unchanged).
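Element-wise, the expected output can be written as below, using θ for params, g for grads, m for mask and η for lr (this notation is introduced here for convenience).

```latex
\theta'_i =
\begin{cases}
  \theta_i, & m_i \ge 0.5 \\
  \theta_i - \eta\, g_i, & m_i < 0.5
\end{cases}
```

For instance, with the illustrative inputs params = [1.0, 2.0, 3.0], grads = [0.5, 9.0, -1.0], mask = [0.0, 1.0, 0.0] and lr = 0.1, the result is [0.95, 2.0, 3.1]: the middle parameter is untouched even though its gradient is large.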

Hints

optax zero-nans stability
