Multi-Condition where
Why this matters
jnp.where(cond, a, b) is JAX’s vectorized ternary. Element by element, it
takes the value from a where cond is True and from b where cond is False.
Multi-branch decisions come up constantly in ML: piecewise activation
functions, reward shaping (clip to a range, zero out values below a
threshold), masking tokens by type, and label-smoothing thresholds.
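A small sketch of two of these uses with a single condition each. The leaky slope of 0.1 and the pad id of 0 are illustrative assumptions, not values from this page.

```python
import jax.numpy as jnp

# Piecewise activation: a leaky ReLU with an assumed slope of 0.1 for x <= 0.
x = jnp.array([-2.0, 0.0, 3.0])
leaky = jnp.where(x > 0, x, 0.1 * x)  # keep x where x > 0, else 0.1 * x

# Token masking: assume pad id 0; padded positions get a very negative logit.
token_ids = jnp.array([5, 7, 0, 0])
logits = jnp.array([1.0, 2.0, 3.0, 4.0])
masked = jnp.where(token_ids == 0, -1e9, logits)  # scalar -1e9 broadcasts
```

In both cases the condition, the "true" value, and the "false" value are combined elementwise, and a scalar on either side broadcasts to the shape of the condition.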
Because jit traces your code with abstract values, you cannot use Python
if/elif for data-dependent branching: the condition would be a tracer,
not a concrete Python bool. jnp.where (and its sibling jnp.select) are the
correct vectorized equivalents.
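A minimal sketch of the failure and two working alternatives. The function names are hypothetical; each returns 1.0 where x > 0 and -1.0 otherwise. lax.cond is included because the pitfalls list mentions it. Note that it takes a single scalar predicate and chooses between two branch functions, whereas jnp.where selects elementwise.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def branch_python_if(x):
    # Data-dependent Python branching: under jit, `x > 0` is a tracer, not a bool.
    if x > 0:
        return 1.0
    return -1.0


@jax.jit
def branch_where(x):
    # Vectorized ternary: works elementwise on arrays of any shape.
    return jnp.where(x > 0, 1.0, -1.0)


@jax.jit
def branch_cond(x):
    # lax.cond needs a scalar boolean predicate and picks one branch function.
    return lax.cond(x > 0, lambda v: jnp.ones_like(v), lambda v: -jnp.ones_like(v), x)


try:
    branch_python_if(2.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this error.
    python_if_failed = True

print(python_if_failed)  # True
```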
Worked mini-example
import jax.numpy as jnp
x = jnp.array([-3.0, 0.0, 5.0])
lo, hi = -1.0, 4.0
# 3-way: x < lo → -1; x > hi → 1; else → 0
out = jnp.where(x < lo, -1.0, jnp.where(x > hi, 1.0, 0.0))
# → array([-1., 0., 1.])
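For more than two or three branches, prefer the flat form, which returns the
value paired with the first True condition and default where none holds:
jnp.select([cond1, cond2, cond3], [v1, v2, v3], default=0.0)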
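A sketch of the mini-example rewritten with jnp.select, followed by a case with overlapping conditions to show the ordering rule. The second pair of conditions is made up purely for illustration.

```python
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 5.0])
lo, hi = -1.0, 4.0

# Same 3-way rule as the nested-where example, written as a flat list.
out = jnp.select([x < lo, x > hi], [-1.0, 1.0], default=0.0)

# Overlapping conditions: they are checked in order and the first True one wins,
# so 20.0 maps to 1.0 (from y > 0), not 2.0 (from y > 10).
y = jnp.array([1.0, 20.0])
first_wins = jnp.select([y > 0, y > 10], [1.0, 2.0], default=0.0)
```

Because the first match wins, order the conditions from most specific to least specific when they can overlap.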
Common pitfalls
- Python if/elif under jit: raises a ConcretizationTypeError (in recent JAX
versions, its subclass TracerBoolConversionError) because the condition is a
traced abstract value. Use jnp.where or lax.cond.
- Deep nesting: nesting where 4+ levels deep is hard to read and debug.
Switch to jnp.select.
- Broadcasting gotchas: a and b must be broadcast-compatible with cond.
Scalars broadcast fine; mismatched shapes raise an error.
- Both branches are evaluated: jnp.where does not short-circuit. JAX computes
the true and false branches in full and only then selects. For example,
jnp.where(x > 0, jnp.log(x), 0) still computes log for x ≤ 0. The masked
forward output is correct, but the hidden -inf/NaN values can leak into
gradients as NaNs.
Problem
Implement piecewise_classify(x, lower, upper) that maps each element of
x to one of three classes:
- -1.0 if x < lower
- 1.0 if x > upper
- 0.0 otherwise (x in the closed range [lower, upper])
It returns a 1-D array of the same shape as x with values in {-1, 0, 1}.
Two illustrative examples (not from the test set):
- x = [-5, 0, 5], lower = -2, upper = 2: out = [-1, 0, 1]
- x = [1, 2, 3], lower = 0, upper = 2: out = [0, 0, 1] (2 lies on the
boundary, so it maps to 0)
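One possible sketch that satisfies the specification above. It is not the site's hidden reference solution. It is jit-compiled with lower and upper passed as traced arguments, which works because jnp.where never needs a concrete Python bool.

```python
import jax
import jax.numpy as jnp


@jax.jit
def piecewise_classify(x, lower, upper):
    # -1.0 below lower, 1.0 above upper, 0.0 on the closed range [lower, upper].
    return jnp.where(x < lower, -1.0, jnp.where(x > upper, 1.0, 0.0))


# The two illustrative examples from the problem statement.
ex1 = piecewise_classify(jnp.array([-5.0, 0.0, 5.0]), -2.0, 2.0)  # values -1, 0, 1
ex2 = piecewise_classify(jnp.array([1.0, 2.0, 3.0]), 0.0, 2.0)    # values 0, 0, 1
```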