
Multi-Condition where

Why this matters

jnp.where(cond, a, b) is JAX’s vectorized ternary — it selects a where cond is True and b where it is False. Multi-branch decisions arise constantly in ML: piecewise activation functions, reward shaping (clip to range, zero below threshold), masking tokens by type, and label-smoothing thresholds.
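A small sketch of two of the uses listed above. The reward threshold of 0.1 and the convention that token id 0 marks padding are hypothetical choices made only for illustration.

```python
import jax.numpy as jnp

# Reward shaping (hypothetical threshold 0.1): zero out rewards below the
# threshold, and clip the remaining ones into [0, 1].
rewards = jnp.array([-0.2, 0.05, 0.7, 1.5])
shaped = jnp.where(rewards < 0.1, 0.0, jnp.clip(rewards, 0.0, 1.0))

# Token masking (assumption: id 0 is padding): push padded positions' logits
# to -inf so that a later softmax assigns them zero probability.
token_ids = jnp.array([5, 0, 7, 0])
logits = jnp.array([1.0, 2.0, 3.0, 4.0])
masked_logits = jnp.where(token_ids == 0, -jnp.inf, logits)
```

In both cases the condition is an array of the same shape as the data, so the selection happens independently for every element.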

Because JAX traces through your code under jit, you cannot use Python if/elif for data-dependent branching — the condition would be a tracer value, not a Python bool. jnp.where (and its sibling jnp.select) are the correct vectorized equivalents.
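A minimal sketch of the difference, using two hypothetical ReLU helpers (`relu_python_if` and `relu_where` are illustrative names, not library functions). The Python `if` fails as soon as it is traced under `jax.jit`, while the `jnp.where` version compiles and selects elementwise.

```python
import jax
import jax.numpy as jnp

def relu_python_if(x):
    # Python branch on a data-dependent value: under jit, `x > 0` is a tracer,
    # so `if` cannot convert it into a concrete Python bool.
    if x > 0:
        return x
    return 0.0

def relu_where(x):
    # Vectorized selection: traces fine because no Python bool is needed.
    return jnp.where(x > 0, x, 0.0)

try:
    jax.jit(relu_python_if)(jnp.float32(2.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this.
    python_if_failed = True

relu_out = jax.jit(relu_where)(jnp.array([-1.0, 0.5, 2.0]))
```

Catching the base class `ConcretizationTypeError` keeps the sketch working across JAX versions that raise either the base error or its subclass.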

Worked mini-example

import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 5.0])
lo, hi = -1.0, 4.0

# 3-way: x < lo → -1; x > hi → 1; else → 0
out = jnp.where(x < lo, -1.0, jnp.where(x > hi, 1.0, 0.0))
# → array([-1., 0., 1.])

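For more than two or three branches, prefer the flat jnp.select form. Conditions are checked in order: each element takes the value paired with the first True condition, and default fills elements where no condition holds: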

jnp.select([cond1, cond2, cond3], [v1, v2, v3], default=0.0)
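Applied to the same three-way rule as the worked example, a sketch with jnp.select looks like this. The second call uses deliberately overlapping conditions (the thresholds 0 and 3 are arbitrary) to show which value wins when more than one condition is True.

```python
import jax.numpy as jnp

x = jnp.array([-3.0, 0.0, 5.0])
lo, hi = -1.0, 4.0

# One condition/value pair per branch; default covers the "otherwise" case.
out = jnp.select([x < lo, x > hi], [-1.0, 1.0], default=0.0)

# Overlapping conditions: 5.0 satisfies both x > 0 and x > 3, and the earlier
# entry in the list (value 10.0) is the one selected.
first_wins = jnp.select([x > 0, x > 3], [10.0, 20.0], default=0.0)
```

Because order matters, list the most specific conditions first when branches overlap.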

Common pitfalls

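  • Python if/elif under jit: branching on a traced array raises a ConcretizationTypeError (recent JAX versions raise its subclass, TracerBoolConversionError) because the condition is an abstract tracer, not a concrete bool. Use jnp.where for elementwise selection, or lax.cond when a single scalar predicate picks between two computations.
  • Deep nesting: where calls nested four or more levels deep are hard to read and debug. Switch to jnp.select.
  • Broadcasting gotchas: cond, a, and b are broadcast against one another, and the result takes the broadcast shape. Scalars broadcast fine; shapes that cannot be broadcast together raise an error.
  • Both branches are evaluated: jnp.where does not short-circuit, so a and b are both computed for every element before the selection. The forward result is still correct, because discarded values (including NaN or -inf) are masked out, but they can poison gradients: jax.grad of jnp.where(x > 0, jnp.log(x), 0.0) is NaN at x = 0 even though that element selects 0.0. Guard the input as well, for example by applying log to jnp.where(x > 0, x, 1.0) (the "double where" trick).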
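The NaN-gradient case can be reproduced in a few lines, following the pattern described in the JAX FAQ. `log_or_zero_naive` and `log_or_zero_safe` are hypothetical names, and the replacement value 1.0 is just an arbitrary input on which log is well behaved.

```python
import jax
import jax.numpy as jnp

def log_or_zero_naive(x):
    # Both branches run: jnp.log(0.0) = -inf is computed, then masked out.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def log_or_zero_safe(x):
    # "Double where": first swap unsafe inputs for a harmless value (1.0), so
    # the discarded log branch never sees x <= 0 and its derivative is finite.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

forward_naive = log_or_zero_naive(0.0)        # 0.0: the forward pass is fine
grad_naive = jax.grad(log_or_zero_naive)(0.0)  # NaN: 0 cotangent / x where x == 0
grad_safe = jax.grad(log_or_zero_safe)(0.0)    # 0.0
```

The NaN arises because the masked-out branch still receives a zero cotangent, and log's derivative divides it by x, giving 0/0 at x = 0.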

Problem

Implement piecewise_classify(x, lower, upper) that maps each element of x to one of three classes:

  • -1.0 if x < lower
  • 1.0 if x > upper
  • 0.0 otherwise, i.e. when lower ≤ x ≤ upper (both bounds inclusive)

The function returns a 1-D array with the same shape as x, containing only the values -1.0, 0.0, and 1.0.

Two illustrative examples (not from the test set):

  • x = [-5, 0, 5], lower = -2, upper = 2: out = [-1, 0, 1]

  • x = [1, 2, 3], lower = 0, upper = 2: out = [0, 0, 1]
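One possible sketch of piecewise_classify, reusing the nested-where pattern from the worked example and checking it against the two illustrative cases above. It assumes float inputs and produces a float32 result; the exact dtype expected by the checker is an assumption this sketch does not confirm.

```python
import jax
import jax.numpy as jnp

@jax.jit
def piecewise_classify(x, lower, upper):
    # Strict comparisons: values equal to lower or upper fall through to 0.0.
    return jnp.where(x < lower, -1.0, jnp.where(x > upper, 1.0, 0.0))

# The two illustrative cases from the problem statement.
ex1 = piecewise_classify(jnp.array([-5.0, 0.0, 5.0]), -2.0, 2.0)
ex2 = piecewise_classify(jnp.array([1.0, 2.0, 3.0]), 0.0, 2.0)
```

jnp.select([x < lower, x > upper], [-1.0, 1.0], default=0.0) is an equivalent flat form; with only three classes, the nested where stays readable.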

Hints

jax where select branching
