NaN-Safe Mean (with mask)
Why this matters
NaN values are a production hazard in ML pipelines. A single NaN in the loss propagates through the backward pass and turns every gradient that depends on it into NaN, silently if you aren't monitoring for it. Common sources: the log of a non-positive number, 0/0 division, numerical overflow in softmax, or reading a masked-out position in a sequence.
JAX provides several NaN-safe utilities:
- jnp.isnan(x): boolean mask of NaN positions.
- jnp.nan_to_num(x, nan=0.0): replaces NaN with a fill value.
- jnp.nansum(x) / jnp.nanmean(x): reductions that ignore NaN.
- The jnp.where(valid_mask, x, 0.0) + explicit count pattern: the most flexible option. It works under jit and lets you choose how to handle the edge case when all values are masked.
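A quick check of the built-in utilities on a small, made-up array. The last lines show the all-NaN case: jnp.nansum returns 0.0 (an empty sum), while jnp.nanmean divides by a count of zero and returns NaN.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])

mask = jnp.isnan(x)                  # [False, True, False]
filled = jnp.nan_to_num(x, nan=0.0)  # [1.0, 0.0, 3.0]
total = jnp.nansum(x)                # 4.0 (NaN entries skipped)
avg = jnp.nanmean(x)                 # 2.0 (4.0 / 2 non-NaN entries)

# All-NaN input: the sum of nothing is 0, but the mean of nothing is 0/0.
all_nan = jnp.array([jnp.nan, jnp.nan])
empty_sum = jnp.nansum(all_nan)      # 0.0
empty_mean = jnp.nanmean(all_nan)    # nan

print(mask, filled, total, avg, empty_sum, empty_mean)
```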
Worked mini-example
import jax.numpy as jnp
x = jnp.array([1.0, float('nan'), 3.0, float('nan')])
is_valid = ~jnp.isnan(x) # [T, F, T, F]
cleaned = jnp.where(is_valid, x, 0.0) # [1, 0, 3, 0]
n_valid = jnp.sum(is_valid.astype(jnp.float32)) # 2.0
mean = jnp.sum(cleaned) / jnp.maximum(n_valid, 1.0)
# → 2.0
The jnp.maximum(n_valid, 1.0) guard prevents division by zero when all
positions are masked: the numerator is then 0, so the result is 0 instead of NaN.
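To see what the guard buys, the sketch below runs the same computation with and without jnp.maximum on an input where nothing is valid. The function names mean_unguarded and mean_guarded are hypothetical, and both are wrapped in jax.jit to show that the pattern compiles without any Python-side branching.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mean_unguarded(x, is_valid):
    cleaned = jnp.where(is_valid, x, 0.0)
    n_valid = jnp.sum(is_valid.astype(jnp.float32))
    return jnp.sum(cleaned) / n_valid                    # 0/0 -> NaN if nothing is valid

@jax.jit
def mean_guarded(x, is_valid):
    cleaned = jnp.where(is_valid, x, 0.0)
    n_valid = jnp.sum(is_valid.astype(jnp.float32))
    return jnp.sum(cleaned) / jnp.maximum(n_valid, 1.0)  # 0/1 -> 0.0 if nothing is valid

x = jnp.array([1.0, 2.0, 3.0])
none_valid = jnp.array([False, False, False])
some_valid = jnp.array([True, False, True])

print(mean_unguarded(x, none_valid))  # nan
print(mean_guarded(x, none_valid))    # 0.0
print(mean_guarded(x, some_valid))    # 2.0
```

When at least one position is valid, n_valid is already at least 1, so the guard leaves the result unchanged.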
Common pitfalls
- Forgetting the all-masked edge case: if every element is masked, n_valid = 0 and you get 0/0 = NaN. Protect with jnp.maximum(n_valid, 1.0).
- NaN in JSON test cases: JSON cannot represent NaN or Infinity, so this problem uses an explicit nan_mask argument instead. Real code uses jnp.isnan(x).
- jnp.nanmean limitations: convenient, but it gives you no control over the all-masked edge case (it returns NaN there); use the explicit pattern for production code.
- Both branches in where: jnp.where(cond, x, 0.0) still evaluates x everywhere. If x contains NaN at masked positions, cleaned will be 0 at those positions (the select overwrites the NaN), which is fine and intended for the forward value. Gradients are less forgiving: if x is computed by an operation that is NaN or infinite at masked positions (for example jnp.log of a masked zero), the gradient flowing back through that operation can still become NaN, so sanitize its input as well.
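The forward value of jnp.where is safe, but gradients can still pick up NaN: the unselected branch receives a zero cotangent, and combining that zero with jnp.log's infinite derivative at 0 yields NaN. The sketch below uses a zero entry as a stand-in for a masked position (the function names are hypothetical) and shows the usual workaround, often called the "double where" trick and discussed in the JAX FAQ: sanitize the input before the risky operation, then mask the output as before.

```python
import jax
import jax.numpy as jnp

def sum_log_single_where(y):
    # jnp.log(y) is evaluated everywhere, including y == 0 where it is -inf.
    return jnp.sum(jnp.where(y > 0, jnp.log(y), 0.0))

def sum_log_double_where(y):
    # Inner where: give log a harmless input (1.0) at masked positions.
    safe_y = jnp.where(y > 0, y, 1.0)
    # Outer where: still discard those positions in the result.
    return jnp.sum(jnp.where(y > 0, jnp.log(safe_y), 0.0))

y = jnp.array([1.0, 0.0, 2.0])

g_single = jax.grad(sum_log_single_where)(y)  # [1.0, nan, 0.5]
g_double = jax.grad(sum_log_double_where)(y)  # [1.0, 0.0, 0.5]
print(g_single, g_double)
```

Both functions return the same forward value; only the gradient at the masked position differs.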
Problem
Implement safe_mean_with_nan(x, nan_mask) that computes the mean of x
excluding positions where nan_mask >= 0.5 (treating those as NaN).
- Returns a scalar.
- If all positions are masked, return 0.0 (not NaN).
The nan_mask convention: 0.0 = valid, 1.0 = masked (pretend NaN).
In real code you would use jnp.isnan(x) directly.
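As a sketch of how data in this problem's (x, nan_mask) convention could be derived from a real array containing NaN: jnp.isnan supplies the mask, and jnp.nan_to_num removes the NaN so that Python's json module can serialize the pair in strict mode. The variable names are illustrative, and the 0.0 placeholder written into x is an arbitrary choice, since masked values are ignored anyway.

```python
import json
import jax.numpy as jnp

raw = jnp.array([2.0, jnp.nan, 6.0])

# Convention from the problem: 1.0 = masked (was NaN), 0.0 = valid.
nan_mask = jnp.isnan(raw).astype(jnp.float32)  # [0.0, 1.0, 0.0]
x = jnp.nan_to_num(raw, nan=0.0)               # [2.0, 0.0, 6.0]

# Both arrays are now finite, so strict JSON serialization succeeds.
case = json.dumps({"x": x.tolist(), "nan_mask": nan_mask.tolist()}, allow_nan=False)

# The raw array cannot be serialized strictly: JSON has no NaN literal.
try:
    json.dumps(raw.tolist(), allow_nan=False)
    raw_ok = True
except ValueError:
    raw_ok = False

print(case, raw_ok)
```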
Two illustrative examples (not from the test set):
- x = [2.0, 4.0, 6.0], nan_mask = [0.0, 1.0, 0.0]: valid values = [2, 6]; mean = 4.0.
- x = [5.0], nan_mask = [1.0]: all masked; return 0.0.
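One possible sketch of safe_mean_with_nan, reusing the where-plus-count pattern from the worked example with the 0.5 threshold from the problem statement. It is not the site's reference solution; the jax.jit decorator is optional and only shows that the function traces cleanly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_mean_with_nan(x, nan_mask):
    """Mean of x over positions where nan_mask < 0.5; returns 0.0 if none."""
    is_valid = nan_mask < 0.5                            # True = keep this position
    cleaned = jnp.where(is_valid, x, 0.0)                # masked entries contribute 0
    n_valid = jnp.sum(is_valid.astype(x.dtype))          # number of kept entries
    return jnp.sum(cleaned) / jnp.maximum(n_valid, 1.0)  # guard the all-masked case

print(safe_mean_with_nan(jnp.array([2.0, 4.0, 6.0]), jnp.array([0.0, 1.0, 0.0])))  # 4.0
print(safe_mean_with_nan(jnp.array([5.0]), jnp.array([1.0])))                     # 0.0
```

The two calls reproduce the illustrative examples above. Because jnp.where overwrites masked entries, the result is the same even if x holds a real NaN at a masked position.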