medium primitives

Fix: Python if → jnp.where

Why this matters

Under jax.jit, array values are abstract tracers: they carry a shape and dtype but no concrete value. Python’s if statement needs a concrete boolean, so branching on a traced value raises a ConcretizationTypeError (in recent JAX versions, its subclass TracerBoolConversionError). This is one of the most common errors beginners hit when porting code to JAX.

The BROKEN pattern:

import jax
import jax.numpy as jnp

@jax.jit
def broken_threshold(x, threshold):
    if x[0] > threshold:          # ConcretizationTypeError!
        return jnp.ones_like(x)
    else:
        return jnp.zeros_like(x)
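
The broken function makes a single decision for the whole array based on x[0], so a faithful repair keeps a scalar predicate. The sketch below assumes a recent JAX release, and the names fixed_threshold_where and fixed_threshold_cond are hypothetical. It first confirms that the Python if fails under jit by catching jax.errors.ConcretizationTypeError, which also covers the TracerBoolConversionError subclass. It then shows two jit-safe rewrites: jnp.where with a broadcast scalar predicate, and lax.cond.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def broken_threshold(x, threshold):
    # Python `if` calls bool() on a tracer, which fails under jit.
    if x[0] > threshold:
        return jnp.ones_like(x)
    else:
        return jnp.zeros_like(x)


@jax.jit
def fixed_threshold_where(x, threshold):
    # Hypothetical name. The scalar predicate broadcasts over the array,
    # so the whole result depends on x[0] only, as in the broken version.
    return jnp.where(x[0] > threshold, jnp.ones_like(x), jnp.zeros_like(x))


@jax.jit
def fixed_threshold_cond(x, threshold):
    # Hypothetical name. lax.cond needs a scalar boolean predicate;
    # the chosen branch function is called with the operand x.
    return lax.cond(x[0] > threshold, jnp.ones_like, jnp.zeros_like, x)


x = jnp.array([2.0, -1.0, 0.5])

try:
    broken_threshold(x, 1.0)
except jax.errors.ConcretizationTypeError as err:
    print("raised:", type(err).__name__)

print(fixed_threshold_where(x, 1.0))  # [1. 1. 1.]
print(fixed_threshold_cond(x, 1.0))   # [1. 1. 1.]
```

Both rewrites trace without error because the condition stays a traced value instead of being converted to a Python bool.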

Worked mini-example

import jax.numpy as jnp

x = jnp.array([-1.0, 0.5, 2.0, 5.0])
threshold = 1.0

# BROKEN: Python if on tracer → ConcretizationTypeError under jit
# if x[0] > threshold: return jnp.ones_like(x)

# FIXED: element-wise, fully traced, jit-safe
out = jnp.where(x > threshold, 1.0, 0.0)
# → [0., 0., 1., 1.]
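
The mini-example above runs eagerly. To exercise the jit-safety claim directly, the same expression can be wrapped in jax.jit with threshold passed as a traced argument. This is a minimal sketch; the function name threshold_jit is chosen for illustration only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def threshold_jit(x, threshold):
    # Hypothetical name. The comparison and the selection are traced ops,
    # so a traced threshold is fine. A new threshold value of the same
    # dtype should reuse the compiled function rather than retrace.
    return jnp.where(x > threshold, 1.0, 0.0)


x = jnp.array([-1.0, 0.5, 2.0, 5.0])
print(threshold_jit(x, 1.0))  # [0. 0. 1. 1.]
print(threshold_jit(x, 0.0))  # [0. 1. 1. 1.]
```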

Common pitfalls

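  • jnp.where is element-wise: x > threshold produces a boolean array, and where applies the condition to every element independently. This differs from the broken example, which makes one decision for the whole array from x[0]; to keep that behavior, pass a scalar predicate such as x[0] > threshold, which broadcasts.
  • Both branches are always evaluated: jnp.where computes the true and false values before selecting. The forward result is still correct if the unselected branch produces inf or NaN, but gradients through jnp.where can become NaN; sanitize the input with an inner jnp.where, or use jax.lax.cond when the predicate is a scalar.
  • For scalar branching use lax.cond: lax.cond(pred, true_fun, false_fun, *operands) runs only the selected branch (unlike jnp.where). Under vmap with a batched predicate, however, it is converted to a select and both branches run.
  • Boundary values: x > threshold is strict, so elements exactly equal to threshold map to 0.0; use >= if they should map to 1.0.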
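
The NaN pitfall can be made concrete with a small gradient check. The sketch below shows the common "double where" workaround (described in the JAX FAQ entry on NaN gradients with where); the functions naive_log_or_zero and safe_log_or_zero are hypothetical. The forward values match, but at x = 0 the naive gradient is NaN: the zero cotangent of the unselected branch is combined with the infinite derivative of log at 0, and 0 / 0 gives NaN.

```python
import jax
import jax.numpy as jnp


def naive_log_or_zero(x):
    # Hypothetical helper. jnp.log(x) is computed for every element,
    # including x <= 0, even though those results are discarded.
    return jnp.where(x > 0, jnp.log(x), 0.0)


def safe_log_or_zero(x):
    # Hypothetical helper. "Double where": replace invalid inputs with a
    # harmless value before jnp.log, so no inf/NaN is produced at all.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)


x = jnp.array([-1.0, 0.0, 2.0])

# Forward values agree: the selection hides the bad branch.
print(naive_log_or_zero(x))
print(safe_log_or_zero(x))

# Gradients differ: the naive version leaks NaN at x = 0.
naive_grad = jax.grad(lambda v: naive_log_or_zero(v).sum())(x)
safe_grad = jax.grad(lambda v: safe_log_or_zero(v).sum())(x)
print(naive_grad)
print(safe_grad)  # [0.  0.  0.5]
```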

Problem

Implement safe_threshold(x, threshold) that returns 1.0 where x > threshold and 0.0 elsewhere, using jnp.where so it is jit-safe.

  • x: 1-D JAX array.
  • threshold: scalar.

Returns: 1-D array the same shape as x with values in {0.0, 1.0}.

jax concretization fix-broken
