medium
primitives
zero_nans for NaN-Safe Training
Why this matters
Mixed-precision training (bfloat16) occasionally produces NaN gradients
from numerical overflow or underflow. Rather than aborting the training
step, optax.zero_nans() zeroes out any NaN gradient element and lets
the step proceed. This is a standard defensive transform in production
training pipelines.
The recipe
In real code:
optimizer = optax.chain(optax.zero_nans(), optax.sgd(lr))
This problem uses an explicit mask (0.0/1.0) instead of true NaN values
because JSON cannot represent NaN. The mask-based logic is:
cleaned = jnp.where(mask >= 0.5, 0.0, grads) # zero out "NaN" positions
optimizer = optax.sgd(lr)
opt_state = optimizer.init(params)
updates, _ = optimizer.update(cleaned, opt_state, params)
return optax.apply_updates(params, updates)
Common pitfalls
- zero_nans zeros NaN gradients but does NOT fix their root cause (e.g. exploding activations). Diagnose the source separately.
- For Inf gradients, use optax.clip_by_global_norm or optax.adaptive_grad_clip in addition.
- The test contract uses a float mask (0.0/1.0) to simulate NaN positions; mask >= 0.5 identifies them.
Inputs
- params: 1-D JAX array of model parameters.
- grads: 1-D JAX array of gradients (some positions simulate NaN via mask).
- mask: 1-D JAX array of 0.0/1.0 values; 1.0 marks the simulated NaN positions.
- lr: scalar SGD learning rate.
Output
1-D array: params after one SGD step in which the gradients at masked positions are zeroed, so those parameters keep their original values.
Hints
optax
zero-nans
stability