medium end_to_end

Tracing: Shape vs Value

Why this matters

JAX has two very different kinds of branching, and mixing them up is one of the most common sources of cryptic errors for developers coming from PyTorch or NumPy.

Shape-based branching is safe. A JAX array's .shape is a plain Python tuple of integers, known as soon as the function is traced and before any real data arrives. Writing if x.shape[0] >= 3: is just a Python if on a Python int. It resolves at trace time, so only the chosen branch is recorded in the compiled program. Calling the jitted function with a different input shape triggers a fresh trace (and compilation), which may take the other branch.
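Trace-time resolution can be observed by recording when tracing happens. The sketch below uses a hypothetical function head_or_fill and a list that only grows while JAX traces (cached calls do not re-run the Python body). Two calls with the same shape share one trace; a new shape triggers a second trace that takes the other Python branch.

```python
import jax
import jax.numpy as jnp

traces = []  # appended to only while JAX traces the function, not on cached calls

def head_or_fill(x):
    traces.append(x.shape)  # x.shape is a plain Python tuple, even during tracing
    # Static check: an ordinary Python if, resolved at trace time
    if x.shape[0] >= 3:
        return x[:3]
    return jnp.full((3,), x.mean())

fast = jax.jit(head_or_fill)
long_out = fast(jnp.arange(5.0))          # first length-5 call: traces the x[:3] branch
long_again = fast(jnp.arange(5.0))        # same shape and dtype: reuses the cached program
short_out = fast(jnp.array([1.0, 3.0]))   # new shape: traces again, takes the fill branch
print(traces)
```

Because the branch is chosen per shape, each compiled program contains only one of the two bodies.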

Value-based branching is dangerous. Under jax.jit, array values are replaced by abstract “tracer” objects. A tracer knows its shape and dtype, but not its numeric value. If you write if jnp.sum(x) > threshold:, JAX must evaluate jnp.sum(x) > threshold to decide which branch to take — but the result is still a tracer, not a real Python bool. Python’s if tries to convert that to bool, which raises:

TracerBoolConversionError: Attempted boolean conversion of traced array
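The error comes from tracing, not from the comparison itself: the same Python if works eagerly, where values are concrete. In the sketch below, the function name above_threshold is hypothetical. The except clause catches jax.errors.ConcretizationTypeError, the base class of TracerBoolConversionError, on the assumption that this also covers older JAX versions that predate the more specific error.

```python
import jax
import jax.numpy as jnp

def above_threshold(x, threshold):
    # Python `if` on an array value: fine eagerly, fails under tracing
    if jnp.sum(x) > threshold:
        return x * 2.0
    return x + 1.0

x = jnp.array([1.0, 2.0, 3.0])

# Eager mode: the comparison is a concrete array, so bool() succeeds
eager_out = above_threshold(x, 5.0)

# Under jit, x and threshold are tracers, so bool() on the comparison raises
try:
    jax.jit(above_threshold)(x, 5.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:  # TracerBoolConversionError subclasses this
    jit_failed = True
    print(type(err).__name__)
```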

The fix is jax.lax.cond(predicate, true_fn, false_fn), a JAX primitive that works with traced values. Both branches are traced and recorded in the compiled program, and at runtime the actual value of predicate selects which one executes. Unlike jnp.where, which computes both candidate arrays and then selects elementwise, lax.cond runs only the selected branch, so the other branch's work is skipped. One caveat: under vmap with a batched predicate, lax.cond is converted into a select and both branches are evaluated.
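A minimal side-by-side sketch, assuming hypothetical function names: both versions give the same values, but the lax.cond version records a cond primitive whose branches are separate sub-programs, while the jnp.where version always computes x * 2.0 and -x. Here the operand is passed explicitly as the fourth argument of lax.cond, so each branch receives it as v.

```python
import jax
import jax.numpy as jnp
from jax import lax

def double_or_negate_cond(x):
    # Only the selected branch executes at runtime (predicate not batched by vmap)
    return lax.cond(jnp.sum(x) > 0, lambda v: v * 2.0, lambda v: -v, x)

def double_or_negate_where(x):
    # Both candidate arrays are computed, then one is selected
    return jnp.where(jnp.sum(x) > 0, x * 2.0, -x)

pos = jnp.array([1.0, 2.0, 3.0])
neg = jnp.array([-1.0, -2.0, -3.0])

for fn in (double_or_negate_cond, double_or_negate_where):
    jitted = jax.jit(fn)
    print(fn.__name__, jitted(pos), jitted(neg))

# The cond version's jaxpr contains a `cond` primitive holding both branches
print(jax.make_jaxpr(double_or_negate_cond)(pos))
```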

Distinguishing these two kinds is a core JAX skill. Get it wrong and you get cryptic errors or silent bugs; get it right and your function composes cleanly with jit, grad, and vmap.

Worked mini-example

import jax
import jax.numpy as jnp
from jax import lax, jit

# ✅ OK: branching on a static shape (resolved at trace time)
def f(x):
    if x.shape[-1] == 3:          # shape is a Python int; a Python if is fine
        return x.sum(axis=-1)
    else:
        return x.mean(axis=-1)

# ❌ BREAKS under jit: branching on a runtime value
def g(x):
    if x[0] > 0:                  # x[0] is a Tracer; bool() on it raises TracerBoolConversionError
        return x * 2
    else:
        return -x

# ✅ Fix with lax.cond: both branches are compiled, the runtime picks one
def g_fixed(x):
    return lax.cond(
        x[0] > 0,                 # predicate: a traced scalar bool
        lambda: x * 2,            # true_fn: called when the predicate is True
        lambda: -x,               # false_fn: called when the predicate is False
    )

print(jit(f)(jnp.ones(3)))        # works: the shape branch is resolved at trace time
try:
    jit(g)(jnp.ones(3))           # raises TracerBoolConversionError
except jax.errors.ConcretizationTypeError as err:  # base class of TracerBoolConversionError
    print(type(err).__name__)
print(jit(g_fixed)(jnp.ones(3)))  # works: lax.cond handles traced values

f is fine because x.shape[-1] is 3, a plain Python int, at trace time. g fails because x[0] is a Tracer, and Tracer.__bool__ raises. g_fixed works because lax.cond is a JAX primitive: it records both branches during tracing and emits an XLA conditional that evaluates at run time.
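The claim that a lax.cond version composes with jit, grad, and vmap can be checked directly. The sketch below reuses the g_fixed logic under a hypothetical name, double_or_negate. Under vmap the predicate is batched, so JAX evaluates both branches and selects per row; under grad, the derivative follows whichever branch the predicate selects.

```python
import jax
import jax.numpy as jnp
from jax import lax

def double_or_negate(x):
    # Same logic as g_fixed: the predicate depends on a runtime value
    return lax.cond(x[0] > 0, lambda: x * 2.0, lambda: -x)

batch = jnp.array([[1.0, 2.0, 3.0],
                   [-1.0, 2.0, 3.0]])

# vmap over rows, then jit: each row gets its own predicate
batched = jax.jit(jax.vmap(double_or_negate))(batch)

# grad of a scalar loss: d/dx sum(2x) = 2 or d/dx sum(-x) = -1, depending on x[0]
loss = lambda x: double_or_negate(x).sum()
grad_pos = jax.grad(loss)(jnp.array([1.0, 2.0, 3.0]))
grad_neg = jax.grad(loss)(jnp.array([-1.0, 2.0, 3.0]))
print(batched, grad_pos, grad_neg)
```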

Common pitfalls

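  • if on a Tracer: raises TracerBoolConversionError (a subclass of ConcretizationTypeError) with a message starting "Attempted boolean conversion of traced array", the key phrase to search for when debugging. Older JAX versions report a ConcretizationTypeError mentioning "Abstract tracer value encountered where concrete value is expected".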
  • x.shape returns a tuple of Python ints: this is static, computed at trace time; if x.shape[0] >= 3: is always safe and never involves a tracer.
  • jnp.sum(x) returns a Tracer under jit: comparing it with > produces another Tracer, and using that Tracer as an if condition fails just like g above.
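  • lax.cond(pred, true_fn, false_fn): both branches must return outputs with the same structure, shape, and dtype, because the compiled conditional has a single output type. In the examples here the branches take no arguments and close over x (and threshold); you can also pass operands explicitly, as in lax.cond(pred, true_fn, false_fn, x), in which case each branch receives them as arguments.
  • Don't return arrays of different lengths from the two branches: lax.cond does not broadcast branch outputs, so mismatched shapes raise a TypeError at trace time.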
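A small sketch of the shape rule, using hypothetical functions mismatched and matched: slicing x to different lengths in the two branches fails while tracing, and making both branches return shape (3,) fixes it. The fixed version also passes x as an explicit operand instead of closing over it.

```python
import jax
import jax.numpy as jnp
from jax import lax

def mismatched(x):
    # Branch outputs have shapes (2,) and (3,): cond does not broadcast them
    return lax.cond(jnp.sum(x) > 0, lambda: x[:2], lambda: x[:3])

try:
    jax.jit(mismatched)(jnp.ones(5))
    mismatch_error = None
except TypeError as err:
    mismatch_error = err
    print("TypeError:", str(err).splitlines()[0])

def matched(x):
    # Both branches return float32 arrays of shape (3,); x is passed as an operand
    return lax.cond(jnp.sum(x) > 0, lambda v: v[:3] * 2.0, lambda v: v[:3] + 1.0, x)

print(jax.jit(matched)(jnp.ones(5)))
```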

Problem

Implement shape_or_value_branch(x, threshold) that takes a 1-D JAX array x and a float scalar threshold, and returns a 1-D array of shape (3,):

  • If x.shape[0] >= 3 (static check — use a plain Python if):
    • If jnp.sum(x) > threshold (dynamic check — use lax.cond): return x[:3] * 2.0
    • Otherwise: return x[:3] + 1.0
  • If x.shape[0] < 3: return jnp.zeros(3) + threshold

The function must accept 1-D arrays of any length. The examples below illustrate the basic behaviors; an additional hidden test case checks the boundary where x.shape[0] == 3.

For intuition, two illustrative examples:

  • x=[1,2,3,4,5], threshold=5.0 → [2.0, 4.0, 6.0] (sum = 15 > 5, so x[:3] * 2)
  • x=[0.5, 0.5], threshold=1.0 → [1.0, 1.0, 1.0] (length 2 < 3, so zeros + threshold)
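One possible sketch of shape_or_value_branch, assuming float or integer 1-D inputs; it is not the official solution. The length check is a plain Python if, the sum comparison goes through lax.cond with the first three elements passed as the operand, and the short-input path broadcasts threshold into a length-3 array.

```python
import jax
import jax.numpy as jnp
from jax import lax

def shape_or_value_branch(x, threshold):
    # Static check: x.shape[0] is a Python int at trace time
    if x.shape[0] >= 3:
        head = x[:3]
        # Dynamic check: the sum depends on runtime values, so use lax.cond
        return lax.cond(
            jnp.sum(x) > threshold,
            lambda h: h * 2.0,   # both branches return shape (3,) with the same dtype
            lambda h: h + 1.0,
            head,
        )
    # Short input: broadcast the scalar threshold into a length-3 array
    return jnp.zeros(3) + threshold

fn = jax.jit(shape_or_value_branch)
print(fn(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]), 5.0))
print(fn(jnp.array([0.5, 0.5]), 1.0))
print(fn(jnp.array([1.0, 2.0, 3.0]), 10.0))  # boundary: length exactly 3
```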

Hints

jax tracing lax-cond
