
Binary Classification via lax.cond

Why this matters

lax.cond(pred, true_fun, false_fun, *operands) is JAX’s jit-safe conditional, the functional replacement for Python’s if/else inside traced code. Under jit, a Python if works only when the condition is a concrete Python bool (for example, one derived from a static argument). As soon as the predicate depends on a traced JAX array (e.g. the result of jnp.sum(x)), the branch must be expressed with lax.cond or jnp.where.
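A minimal sketch of the three cases, assuming a recent JAX version; the function names (python_if, with_cond, with_where) and the 10.0 threshold are illustrative. The Python if fails under jit because jnp.sum(x) > 10.0 is a tracer, while lax.cond and jnp.where both accept it:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def python_if(x):
    # The predicate is a traced array, so Python cannot convert it to a bool.
    if jnp.sum(x) > 10.0:
        return 1.0
    return -1.0

@jax.jit
def with_cond(x):
    # lax.cond accepts a traced scalar boolean predicate.
    return lax.cond(jnp.sum(x) > 10.0, lambda: 1.0, lambda: -1.0)

@jax.jit
def with_where(x):
    # jnp.where picks between two values that are both already computed.
    return jnp.where(jnp.sum(x) > 10.0, 1.0, -1.0)

x = jnp.arange(6.0)  # sum is 15.0

try:
    python_if(x)
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX raises TracerBoolConversionError, a subclass of this error.
    print("Python if failed:", type(err).__name__)

print(with_cond(x), with_where(x))
```

For a single scalar decision like this, both forms give the same value; the difference lies in what gets computed, as the next paragraph explains.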

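Unlike jnp.where, which selects between two values that have both already been computed, lax.cond executes only the selected branch at runtime, although both branches are traced at compile time. This gives short-circuit semantics: on an accelerator, the unused path is not computed. Because both branches are traced, neither may rely on Python side effects. One caveat: under jax.vmap with a batched predicate, lax.cond is converted into a select, so both branches run.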
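The sketch below inspects the traced program with jax.make_jaxpr, assuming a recent JAX version; classify is an illustrative name and the exact jaxpr text varies between versions. With a scalar predicate the jaxpr holds a single cond primitive, so only one branch runs; under jax.vmap with a batched predicate, JAX rewrites it into select_n, which evaluates both branches and picks results element-wise.

```python
import jax
import jax.numpy as jnp
from jax import lax

def classify(x):
    return lax.cond(x > 0, lambda: 1.0, lambda: -1.0)

# Scalar predicate: one `cond` primitive; only the chosen branch executes.
scalar_jaxpr = jax.make_jaxpr(classify)(jnp.float32(2.0))
print(scalar_jaxpr)

# Batched predicate: the cond is lowered to `select_n` over both branch results.
batch = jnp.array([2.0, -3.0])
batched_jaxpr = jax.make_jaxpr(jax.vmap(classify))(batch)
print(batched_jaxpr)

print(jax.vmap(classify)(batch))
```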

Worked mini-example

from jax import lax
import jax.numpy as jnp

def sign(x):
    # Returns 1.0 if x > 0, else -1.0 (so x == 0 maps to -1.0); jit-safe
    return lax.cond(x > 0, lambda: 1.0, lambda: -1.0)

sign(jnp.array(3.0))   # → 1.0
sign(jnp.array(-2.0))  # → -1.0

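In this example the two branches are zero-argument callables because no operands are passed; in general, lax.cond calls the selected branch with the operands listed after false_fun. Both branches must return outputs with the same structure, shape, and dtype, because lax.cond lowers to a single XLA conditional whose result type is fixed at compile time. A mismatch raises a TypeError when the branches are traced.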
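A short sketch of both points, assuming a recent JAX version; scale_or_negate is a hypothetical helper. The branches receive the operand v and return arrays matching its shape and dtype, while branches returning float32 and int32 are rejected at trace time:

```python
import jax.numpy as jnp
from jax import lax

def scale_or_negate(flag, v):
    # Each branch takes the operand and returns an array shaped like v.
    return lax.cond(flag, lambda u: u * 2.0, lambda u: -u, v)

v = jnp.array([1.0, 2.0, 3.0])
print(scale_or_negate(True, v))   # doubled values
print(scale_or_negate(False, v))  # negated values

# Branch outputs with different dtypes (float32 vs int32) are rejected.
try:
    lax.cond(True, lambda: jnp.float32(1.0), lambda: jnp.int32(-1))
except TypeError as err:
    print("TypeError:", err)
```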

Common pitfalls

  • Python if under jit. if jnp.sum(x) > threshold: raises a ConcretizationTypeError under jit. Use lax.cond instead.
  • Shape/dtype mismatch. lambda: 1 (int) vs lambda: 1.0 (float) gives branch outputs with different dtypes. Because lax.cond traces both branches even outside jit, the mismatch is reported in eager mode too, not only under jit; always use matching literals (1.0 / -1.0).
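  • Both branches traced. Any Python side effect (print, mutation) inside either branch runs at trace time, once for each branch, and not when the compiled function executes. Use jax.debug.print for values you want printed at runtime.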
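A minimal sketch of the trace-time pitfall, assuming a recent JAX version; trace_log, double, and negate are illustrative names. The Python list records one entry per branch on the first jitted call, and a second call with the same input type reuses the compiled function, so no new entries appear. jax.debug.print, by contrast, fires only when its branch actually executes.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_log = []  # records Python-level calls to each branch (illustrative)

def double(u):
    trace_log.append("true branch traced")   # Python side effect: trace time only
    return u * 2.0

def negate(u):
    trace_log.append("false branch traced")
    jax.debug.print("runtime value in false branch: {}", u)  # runs on execution
    return -u

@jax.jit
def f(x):
    return lax.cond(x > 0, double, negate, x)

f(jnp.float32(1.0))   # first call traces both branches
print(trace_log)      # one entry per branch
f(jnp.float32(-1.0))  # same input type: compiled function reused, no retrace
print(trace_log)      # unchanged
```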

Problem

Implement conditional_classify(x, threshold) that returns 1.0 if sum(x) > threshold, else -1.0, using lax.cond.

  • x: 1-D JAX array.
  • threshold: scalar.
  • Returns: scalar float.
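One possible sketch of conditional_classify, following the specification above and assuming the threshold is compared with a strict greater-than; wrapping it in jax.jit is an optional choice, not part of the stated interface:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def conditional_classify(x, threshold):
    # Under jit the sum is a traced scalar, so a Python `if` would fail here.
    total = jnp.sum(x)
    return lax.cond(total > threshold,
                    lambda: 1.0,    # sum strictly above threshold
                    lambda: -1.0)   # sum at or below threshold

print(conditional_classify(jnp.array([1.0, 2.0, 3.0]), 5.0))  # sum 6.0 > 5.0
print(conditional_classify(jnp.array([1.0, 2.0, 3.0]), 6.0))  # sum equals threshold
```

Both branches return float literals, so the output dtype matches whichever branch is selected.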

