medium
primitives
Fix: Python if -> jnp.where
Why this matters
Under jax.jit, function arguments are replaced by abstract tracers that carry a shape and dtype but no concrete value. Python's if statement
requires a concrete boolean, so branching on a traced value raises a
ConcretizationTypeError. This is one of the most common errors beginners
hit when porting code to JAX.
The BROKEN pattern:
import jax
import jax.numpy as jnp

@jax.jit
def broken_threshold(x, threshold):
    if x[0] > threshold:  # ConcretizationTypeError!
        return jnp.ones_like(x)
    else:
        return jnp.zeros_like(x)
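To see that the failure comes from jit tracing rather than from the logic itself, the sketch below runs the same Python-if logic eagerly (where it works) and then under jax.jit (where it fails). The function name threshold_py_if is hypothetical; the exception is caught as jax.errors.ConcretizationTypeError, and recent JAX versions may report the more specific TracerBoolConversionError subclass.

```python
import jax
import jax.numpy as jnp


def threshold_py_if(x, threshold):
    # Same logic as broken_threshold, but left un-jitted so we can compare.
    if x[0] > threshold:
        return jnp.ones_like(x)
    return jnp.zeros_like(x)


x = jnp.array([-1.0, 0.5, 2.0, 5.0])

# Eager mode: x[0] > threshold is a concrete value, so Python's if works.
eager_out = threshold_py_if(x, 1.0)

# Under jit: the comparison yields a tracer, and bool() on it fails.
try:
    jax.jit(threshold_py_if)(x, 1.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions may raise TracerBoolConversionError, a subclass.
    jit_failed = True
    print(type(err).__name__)

print(eager_out)   # [0. 0. 0. 0.]
print(jit_failed)  # True
```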
Worked mini-example
import jax.numpy as jnp
x = jnp.array([-1.0, 0.5, 2.0, 5.0])
threshold = 1.0
# BROKEN: Python if on tracer -> ConcretizationTypeError under jit
# if x[0] > threshold: return jnp.ones_like(x)
# FIXED: element-wise, fully traced, jit-safe
out = jnp.where(x > threshold, 1.0, 0.0)
# -> [0., 0., 1., 1.]
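The worked example above runs eagerly; the sketch below wraps the same jnp.where call in jax.jit to confirm it compiles, and adds a hypothetical variant, threshold_on_first, that keeps broken_threshold's original semantics (one decision based on x[0]) by letting a scalar condition broadcast over full-shape branches. Both names are illustrative, not part of the exercise.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.5, 2.0, 5.0])
threshold = 1.0

# The element-wise fix compiles under jit because no Python bool is needed.
elementwise = jax.jit(lambda v, t: jnp.where(v > t, 1.0, 0.0))


@jax.jit
def threshold_on_first(v, t):
    # Hypothetical variant with broken_threshold's semantics (one decision
    # from v[0]); the scalar condition broadcasts over both branches.
    return jnp.where(v[0] > t, jnp.ones_like(v), jnp.zeros_like(v))


print(elementwise(x, threshold))         # [0. 0. 1. 1.]
print(threshold_on_first(x, threshold))  # [0. 0. 0. 0.]
```

The element-wise version answers "which entries exceed the threshold?", while threshold_on_first answers "does the first entry exceed it?"; pick the one that matches the behavior you are porting.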
Common pitfalls
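- jnp.where is element-wise: x > threshold produces a boolean array, and where applies the condition to every element independently. This differs from broken_threshold, which makes a single decision based on x[0].
- Both branches are always evaluated: jnp.where computes the true and false values before selecting. The forward result is still correct if the unselected branch produces NaN or inf, but that value can leak into gradients as NaN, and expensive branches cost extra work; for a scalar predicate, consider jax.lax.cond instead.
- For scalar branching use lax.cond: lax.cond(pred, true_fun, false_fun, *operands) runs only the selected branch (unlike jnp.where). Under jax.vmap with a batched predicate, however, it is converted to a select and both branches run.
- Boundary values: x > threshold is strict, so elements equal to threshold map to 0.0.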
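The two branching pitfalls above can be sketched as follows. threshold_with_cond reproduces broken_threshold's x[0]-based decision with lax.cond, passing x as the operand to whichever branch runs. The sqrt functions illustrate the NaN-gradient issue: at x = 0 the unselected sqrt branch has an infinite derivative, and jnp.where's backward pass multiplies it by zero. The "double where" fix (sanitizing the input before the risky operation) is a widely used workaround, not something this exercise prescribes; all function names here are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def threshold_with_cond(x, threshold):
    # lax.cond needs a scalar boolean predicate; each branch is a function.
    return lax.cond(
        x[0] > threshold,
        lambda v: jnp.ones_like(v),   # true branch
        lambda v: jnp.zeros_like(v),  # false branch
        x,                            # operand passed to the branch that runs
    )


def sqrt_or_zero_naive(x):
    # Forward value is fine, but d/dx sqrt(x) is inf at 0, and the backward
    # pass of jnp.where multiplies it by 0, which gives NaN.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def sqrt_or_zero_safe(x):
    # "Double where": replace unsafe inputs before sqrt so neither branch
    # produces inf or NaN, then select exactly as before.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)


x = jnp.array([-1.0, 0.5, 2.0, 5.0])
print(threshold_with_cond(x, 1.0))        # [0. 0. 0. 0.]
print(jax.grad(sqrt_or_zero_naive)(0.0))  # nan
print(jax.grad(sqrt_or_zero_safe)(0.0))   # 0.0
```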
Problem
Implement safe_threshold(x, threshold) that returns 1.0 where
x > threshold and 0.0 elsewhere, using jnp.where so it is jit-safe.
- x: 1-D JAX array.
- threshold: scalar.
Returns: 1-D array the same shape as x with values in {0.0, 1.0}.
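One way to self-check an attempt is a small harness that runs the candidate under jax.jit on a few hand-picked cases, including a value exactly equal to the threshold. The sketch below assumes your function is named safe_threshold as in the problem statement; check_safe_threshold is a hypothetical helper, and the candidate shown is one possible implementation, not necessarily the site's reference solution.

```python
import jax
import jax.numpy as jnp


def safe_threshold(x, threshold):
    # Candidate implementation: strict >, so x == threshold maps to 0.0.
    return jnp.where(x > threshold, 1.0, 0.0)


def check_safe_threshold(fn):
    # Hypothetical helper: runs fn under jax.jit on a few cases and checks
    # both the output shape and the exact 0.0 / 1.0 values.
    jitted = jax.jit(fn)
    cases = [
        (jnp.array([-1.0, 0.5, 2.0, 5.0]), 1.0, [0.0, 0.0, 1.0, 1.0]),
        (jnp.array([1.0, 1.0]), 1.0, [0.0, 0.0]),  # boundary: equal -> 0.0
        (jnp.array([3.0]), -2.0, [1.0]),
    ]
    for x, threshold, expected in cases:
        out = jitted(x, threshold)
        assert out.shape == x.shape, (out.shape, x.shape)
        assert out.tolist() == expected, (out.tolist(), expected)
    return True


print(check_safe_threshold(safe_threshold))  # True
```

A version that still uses a Python if fails inside the harness with a ConcretizationTypeError, which is exactly the error this exercise is about.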