NaN-Safe Mean (with mask)

Why this matters

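NaN values are a production hazard in ML pipelines. A single NaN in the loss propagates through the backward pass and typically turns every gradient that shares computation with it into NaN, silently if you are not monitoring. Common sources: the log of a negative number (log(0) gives -inf, which becomes NaN in later arithmetic such as 0 * -inf or inf - inf), 0/0 division, overflow in a naively computed softmax (inf / inf), or reading a masked-out position in a sequence.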
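A toy sketch of that propagation, assuming a hypothetical two-feature linear regression whose inputs contain one NaN. The bad value sits only in the second feature of one row, yet both weight gradients come out NaN: that row's residual is NaN, and the gradient of every weight sums over all rows' residuals.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: a single NaN in one input feature.
X = jnp.array([[1.0, 2.0],
               [3.0, jnp.nan],   # the only bad value
               [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.1, 0.2])


def loss_fn(w):
    pred = X @ w                       # pred[1] becomes NaN
    return jnp.mean((pred - y) ** 2)   # the NaN residual poisons the mean


loss = loss_fn(w)                      # NaN
grads = jax.grad(loss_fn)(w)           # NaN for both weights, not just the second
```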

JAX provides several NaN-safe utilities:

  • jnp.isnan(x): boolean mask of NaN positions.
  • jnp.nan_to_num(x, nan=0.0): replace NaN with a fill value.
  • jnp.nansum(x) / jnp.nanmean(x): reductions that ignore NaN entries.
  • The jnp.where(valid_mask, x, 0.0) + explicit count pattern: the most flexible option. It works under jit and lets you choose how to handle the edge case where every value is masked.
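A quick check of the built-in utilities on a small array (values chosen for illustration), including the all-NaN edge case where jnp.nansum and jnp.nanmean behave differently:

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0, jnp.nan])
all_nan = jnp.array([jnp.nan, jnp.nan])

is_nan = jnp.isnan(x)                  # [False, True, False, True]
filled = jnp.nan_to_num(x, nan=0.0)    # [1.0, 0.0, 3.0, 0.0]
total = jnp.nansum(x)                  # 4.0
avg = jnp.nanmean(x)                   # 2.0 (sum of non-NaN / count of non-NaN)

# Edge case: no valid values at all.
total_all_nan = jnp.nansum(all_nan)    # 0.0 (empty sum)
avg_all_nan = jnp.nanmean(all_nan)     # NaN (0 / 0)
```

The last line is the reason the explicit pattern exists: jnp.nanmean gives no hook for choosing a fallback when nothing is valid.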

Worked mini-example

import jax.numpy as jnp

x        = jnp.array([1.0, float('nan'), 3.0, float('nan')])
is_valid = ~jnp.isnan(x)                           # [T, F, T, F]
cleaned  = jnp.where(is_valid, x, 0.0)             # [1, 0, 3, 0]
n_valid  = jnp.sum(is_valid.astype(jnp.float32))   # 2.0
mean     = jnp.sum(cleaned) / jnp.maximum(n_valid, 1.0)
# → 2.0

The jnp.maximum(n_valid, 1.0) guard prevents division-by-zero when all positions are masked, returning 0 instead of NaN.
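The same pattern extends to batched data, for example per-sequence means over padded rows. The sketch below is one way to do it under jit; the name masked_mean and its axis parameter are hypothetical, not a JAX API. The guard keeps a fully masked row at 0.0 instead of NaN.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="axis")
def masked_mean(x, valid, axis=None):
    """Hypothetical helper: mean of x over `axis`, counting only valid positions."""
    cleaned = jnp.where(valid, x, 0.0)                   # invalid slots contribute 0
    n_valid = jnp.sum(valid, axis=axis, dtype=x.dtype)   # count of valid slots
    return jnp.sum(cleaned, axis=axis) / jnp.maximum(n_valid, 1.0)


x = jnp.array([[1.0, jnp.nan, 3.0],
               [jnp.nan, jnp.nan, jnp.nan]])  # second row is fully masked
valid = ~jnp.isnan(x)

per_row = masked_mean(x, valid, axis=1)   # [2.0, 0.0]
overall = masked_mean(x, valid)           # 2.0
```

axis is marked static because it changes the output shape, so each distinct axis value compiles separately.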

Common pitfalls

  • Forgetting the all-masked edge case: if every element is masked, n_valid = 0 and you get 0/0 = NaN. Protect with jnp.maximum(n, 1.0).
  • NaN in JSON test cases: JSON cannot represent NaN or Infinity, so this problem uses an explicit nan_mask argument instead. Real code uses jnp.isnan(x).
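  • jnp.nanmean limitations: convenient, but it returns NaN when every element is NaN, so you cannot pick the all-masked fallback; use the explicit pattern for production code.
  • Both branches in where: jnp.where(cond, x, 0.0) still evaluates x everywhere. If x contains NaN at masked positions, cleaned is 0 there because the select discards the NaN; for the forward pass this is fine and intended. Gradients are a different story: if x is itself computed from another array (say x = jnp.log(y) with y = 0 at a masked slot), the backward pass multiplies the zero cotangent from where by an infinite or NaN local derivative, and the gradient at that slot becomes NaN. The common fix is to replace invalid inputs with a safe value before the risky operation and mask again afterwards (the "double where" pattern).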
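The gradient hazard of evaluating both branches of jnp.where can be reproduced in a few lines. In this sketch, with input values assumed for illustration, jnp.log is applied to an array with a zero at a masked slot: the naive version yields a NaN gradient there, while sanitizing the input first (the so-called "double where" pattern) keeps it at 0.

```python
import jax
import jax.numpy as jnp

y = jnp.array([4.0, 0.0, 9.0])
valid = y > 0.0                      # index 1 is masked: log(0) is -inf


def naive(y):
    # jnp.log runs on every element, including the masked zero.
    return jnp.sum(jnp.where(valid, jnp.log(y), 0.0))


def double_where(y):
    # First where: swap masked inputs for a safe, in-domain value (1.0 is arbitrary).
    safe_y = jnp.where(valid, y, 1.0)
    # Second where: discard the placeholder results, as before.
    return jnp.sum(jnp.where(valid, jnp.log(safe_y), 0.0))


g_naive = jax.grad(naive)(y)         # NaN at index 1 (0 / 0 in the log derivative)
g_safe = jax.grad(double_where)(y)   # [0.25, 0.0, 1/9]
```

Both functions return the same forward value; only the backward pass differs.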

Problem

Implement safe_mean_with_nan(x, nan_mask) that computes the mean of x excluding positions where nan_mask >= 0.5 (treating those as NaN).

  • Returns a scalar.
  • If all positions are masked, return 0.0 (not NaN).

The nan_mask convention: 0.0 = valid, 1.0 = masked (pretend NaN). In real code you would use jnp.isnan(x) directly.
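To see how the float mask relates to jnp.isnan, here is a small sketch (the array values and the placeholder 4.0 are made up) that builds a nan_mask in this convention from an array that really contains NaN, then recovers the boolean validity mask through the 0.5 threshold.

```python
import jax.numpy as jnp

x = jnp.array([2.0, jnp.nan, 6.0])

# Real code: validity comes straight from the data.
valid_real = ~jnp.isnan(x)

# Problem convention: a float mask where 1.0 means "pretend NaN" and 0.0 means valid.
nan_mask = jnp.isnan(x).astype(jnp.float32)   # [0.0, 1.0, 0.0]
valid_from_mask = nan_mask < 0.5              # masked iff nan_mask >= 0.5

# In a JSON test case the NaN slot would hold some finite placeholder instead.
x_json = jnp.where(valid_from_mask, x, 4.0)   # placeholder value is arbitrary
```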

Two illustrative examples (not from the test set):

  • x = [2.0, 4.0, 6.0], nan_mask = [0.0, 1.0, 0.0]: valid values = [2, 6]; mean = 4.0.

  • x = [5.0], nan_mask = [1.0]: all masked; return 0.0.
