Binary Classification via lax.cond
Why this matters
lax.cond(pred, true_fn, false_fn) is JAX's jit-safe conditional: the
functional replacement for Python's if/else inside traced code. Under
jit, a Python if statement works only when its condition is a concrete Python
bool known at trace time. As soon as the predicate depends on a traced JAX value
(e.g. the result of jnp.sum on an argument), the branch must be expressed with
lax.cond or jnp.where.
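Unlike jnp.where, which computes both candidate arrays in full and then selects
elementwise, lax.cond takes a single scalar predicate and executes only the
chosen branch at runtime, although both branches are traced ahead of time.
Two consequences follow: branches must be pure, because any Python side effect
runs during tracing rather than at runtime, and an expensive unused branch is
skipped instead of being computed on the accelerator. One caveat: under
jax.vmap with a batched predicate, lax.cond is converted to a select, so both
branches do run.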
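The contrast between the two tools can be sketched as follows. The array values and the helper name double_or_negate are illustrative assumptions, not part of the exercise; the sketch shows jnp.where selecting per element, lax.cond selecting one whole branch from a scalar predicate, and the vmap case in which the predicate is batched.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, 0.5, 3.0])

# jnp.where: both candidate arrays are computed for every element,
# then one value is selected per element.
where_out = jnp.where(x > 0, jnp.sqrt(jnp.abs(x)), -x)

# lax.cond: one scalar predicate selects a whole branch (hypothetical helper).
@jax.jit
def double_or_negate(v):
    return lax.cond(jnp.sum(v) > 0, lambda: v * 2.0, lambda: -v)

cond_out = double_or_negate(x)  # sum is 1.5 > 0, so the result is v * 2.0

# Under vmap the predicate is batched, so lax.cond is turned into a select:
# both branches are evaluated and one result per element is kept.
signs = jax.vmap(lambda s: lax.cond(s > 0, lambda: 1.0, lambda: -1.0))(x)

print(where_out)  # approximately [2.0, 0.7071, 1.7321]
print(cond_out)   # [-4.0, 1.0, 6.0]
print(signs)      # [-1.0, 1.0, 1.0]
```

Note that jnp.where works elementwise on a boolean array, while lax.cond needs its predicate to be a scalar (outside of vmap).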
Worked mini-example
from jax import lax
import jax.numpy as jnp

def sign(x):
    # Returns 1.0 if x > 0, else -1.0; jit-safe
    return lax.cond(x > 0, lambda: 1.0, lambda: -1.0)

sign(jnp.array(3.0))   # → 1.0
sign(jnp.array(-2.0))  # → -1.0
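Because no operands are passed here, true_fn and false_fn are zero-argument
callables that return the branch value; any extra arguments given to lax.cond
after the two functions are forwarded to whichever branch runs. Both branches
must return outputs with the same structure, shape, and dtype: lax.cond lowers
to a single XLA conditional whose result type is fixed, so mismatched outputs
cause a TypeError at trace time.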
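A short sketch of the operand-passing form and of the type check, assuming a hypothetical helper clip_or_double; only lax.cond, jnp.clip, and astype are real JAX APIs here. The last part deliberately returns int32 from one branch and float32 from the other to show the rejection.

```python
import jax.numpy as jnp
from jax import lax

def clip_or_double(pred, x):
    # Operands after the two branch functions (here just x) are passed
    # to whichever branch is selected. Both branches return float32 arrays
    # with x's shape, so their output types match.
    return lax.cond(pred, lambda v: jnp.clip(v, 0.0, 1.0), lambda v: v * 2.0, x)

x = jnp.array([0.5, 1.5, -0.5], dtype=jnp.float32)
clipped = clip_or_double(jnp.array(True), x)   # [0.5, 1.0, 0.0]
doubled = clip_or_double(jnp.array(False), x)  # [1.0, 3.0, -1.0]

# Mismatched branch dtypes (int32 vs float32) are rejected when lax.cond
# traces both branches, even without jit.
try:
    lax.cond(jnp.array(True), lambda v: v.astype(jnp.int32), lambda v: v, x)
    mismatch_rejected = False
except TypeError as err:
    mismatch_rejected = True
    print("TypeError:", err)
```

Passing data as operands, rather than closing over it, keeps each branch a plain function of its inputs; both styles are accepted by lax.cond.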
Common pitfalls
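- Python if under jit. Writing if jnp.sum(x) > threshold: inside a jitted function raises a ConcretizationTypeError, because the traced sum has no concrete True/False value at trace time. Use lax.cond instead.
- Shape/dtype mismatch. Pairing lambda: 1 (an int) with lambda: -1.0 (a float) gives branch outputs with different dtypes, which lax.cond can reject with a TypeError even outside jit, since it traces both branches either way. Always use matching literals (1.0/-1.0).
- Both branches traced. Any Python side effect (print, mutation) inside either branch runs at trace time, for both branches, regardless of which branch is taken at runtime.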
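The first and third pitfalls can be reproduced in a few lines. The names bad_classify, noisy_sign, and trace_log are hypothetical; the sketch relies only on jax.jit, lax.cond, and jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def bad_classify(x, threshold):
    # Python `if` needs a concrete bool, but under jit jnp.sum(x) is a tracer.
    if jnp.sum(x) > threshold:
        return 1.0
    return -1.0

try:
    bad_classify(jnp.array([1.0, 2.0]), 0.0)
    if_failed = False
except jax.errors.ConcretizationTypeError:
    if_failed = True

trace_log = []

def noisy_sign(p):
    def true_branch():
        trace_log.append("true")   # Python side effect: runs while tracing
        return 1.0
    def false_branch():
        trace_log.append("false")  # also runs while tracing, even if never taken
        return -1.0
    return lax.cond(p, true_branch, false_branch)

noisy_sign_jit = jax.jit(noisy_sign)
first = noisy_sign_jit(jnp.array(True))    # traces both branches once
second = noisy_sign_jit(jnp.array(False))  # same input shape/dtype: cached, no retrace
print(if_failed, sorted(trace_log), first, second)
```

The log records both branches after the first call and nothing after the second, which is why print or list mutation inside a branch is not a reliable way to observe which branch ran.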
Problem
Implement conditional_classify(x, threshold) that returns 1.0 if
sum(x) > threshold, else -1.0, using lax.cond.
- x: 1-D JAX array.
- threshold: scalar.
- Returns: scalar float.
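One possible sketch of conditional_classify, not the site's reference solution. It assumes the function may be wrapped in jax.jit and uses explicit jnp.float32 branch values so both outputs share a dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def conditional_classify(x, threshold):
    # Reduce to a scalar so the predicate is a single (traced) bool.
    pred = jnp.sum(x) > threshold
    # Both branches return float32 scalars, so the output types match.
    return lax.cond(pred, lambda: jnp.float32(1.0), lambda: jnp.float32(-1.0))

x = jnp.array([1.0, 2.0, 3.0])
print(conditional_classify(x, 5.0))  # 1.0  (sum 6.0 > 5.0)
print(conditional_classify(x, 6.0))  # -1.0 (6.0 > 6.0 is False)
```

Because the comparison is strict, a sum exactly equal to the threshold falls into the -1.0 branch.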