Tracing: Shape vs Value
Why this matters
JAX has two completely different kinds of branching, and mixing them up is one of the most common sources of cryptic errors for developers coming from PyTorch or NumPy.
Shape-based branching is safe. A JAX array’s .shape is a plain Python
tuple of integers — it is known the moment the function is traced, before any
real data arrives. Writing if x.shape[0] >= 3: is just a Python if on a
Python int. JAX’s tracer never sees it; it resolves at trace time and bakes the
chosen branch permanently into the compiled program.
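One way to see that a shape branch is resolved at trace time is to count traces. In this illustrative sketch, the function name h and the global trace_count are hypothetical. A Python side effect inside a jitted function runs only while JAX traces it, so the counter increases once per new input shape, not once per call.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while tracing

@jax.jit
def h(x):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time, not at run time
    if x.shape[0] >= 3:  # static shape check, resolved while tracing
        return x[:3] * 2.0
    return jnp.zeros(3)

h(jnp.ones(5))  # first call with shape (5,): traces and bakes in the ">= 3" branch
h(jnp.ones(5))  # same shape and dtype: reuses the cached compiled program
h(jnp.ones(2))  # new shape (2,): traces again and bakes in the other branch
print(trace_count)  # 2
```

Each retrace yields a separate compiled program containing only the branch chosen for that shape.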
Value-based branching is dangerous. Under jax.jit, array values are
replaced by abstract “tracer” objects. A tracer knows its shape and dtype, but
not its numeric value. If you write if jnp.sum(x) > threshold:, JAX must
evaluate jnp.sum(x) > threshold to decide which branch to take — but the
result is still a tracer, not a real Python bool. Python’s if tries to
convert that to bool, which raises:
TracerBoolConversionError: Attempted boolean conversion of traced array
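The contrast is easy to reproduce: the same value-based if runs fine eagerly, where jnp.sum returns a concrete array, and fails only under jit. In this sketch the function name g_value and the threshold of 1.0 are illustrative choices.

```python
import jax
import jax.numpy as jnp

def g_value(x, threshold):
    # Python if on an array *value*: only works when the value is concrete
    if jnp.sum(x) > threshold:
        return x * 2
    return -x

x = jnp.ones(3)
eager = g_value(x, 1.0)  # eager mode: jnp.sum(x) is concrete, so bool() succeeds

try:
    jax.jit(g_value)(x, 1.0)  # under jit, jnp.sum(x) > threshold is a Tracer
    error_name = None
except jax.errors.TracerBoolConversionError as e:
    error_name = type(e).__name__

print(eager)       # [2. 2. 2.]
print(error_name)  # TracerBoolConversionError
```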
The fix is jax.lax.cond(predicate, true_fn, false_fn), which is a proper
JAX primitive that works with traced values. Both branches are represented in
the compiled program, and the runtime picks one based on the actual value of
predicate. Unlike jnp.where, which receives both of its inputs already
computed, only the selected branch executes at runtime; the other branch’s
side effects (if any) do not run. The exception is vmap with a batched
predicate: there lax.cond is converted into a select, and both branches are
evaluated.
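The examples in this section pass zero-argument lambdas that close over x. lax.cond also accepts the branch inputs as trailing operands, lax.cond(pred, true_fn, false_fn, *operands), in which case each branch receives them as arguments. A minimal sketch; the name scale_or_negate is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def scale_or_negate(x):
    # x is passed as an explicit operand; each branch takes it as its argument
    return lax.cond(
        x[0] > 0,
        lambda v: v * 2,  # true branch
        lambda v: -v,     # false branch
        x,
    )

out_pos = scale_or_negate(jnp.array([1.0, 2.0, 3.0]))   # x[0] > 0, so x * 2
out_neg = scale_or_negate(jnp.array([-1.0, 2.0, 3.0]))  # x[0] <= 0, so -x
```

The two forms trace to the same kind of conditional; choosing between closures and explicit operands is mostly a matter of style.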
Distinguishing these two kinds is a core JAX skill. Get it wrong and you get
cryptic errors or silent bugs; get it right and your function composes cleanly
with jit, grad, and vmap.
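As a sketch of that composability, the function below uses the lax.cond pattern for a value-dependent branch and is then passed through vmap, grad, and jit. The name double_or_negate and the input values are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def double_or_negate(x):
    # Value-dependent branch expressed with lax.cond, so it survives tracing
    return lax.cond(x[0] > 0, lambda: x * 2, lambda: -x)

xs = jnp.array([[1.0, 2.0, 3.0],
                [-1.0, 2.0, 3.0]])

# vmap: the predicate becomes batched, and each row gets its own branch's result
batched = jax.vmap(double_or_negate)(xs)

# grad: differentiates through whichever branch the predicate selects
loss = lambda x: double_or_negate(x).sum()
grad_pos = jax.grad(loss)(xs[0])  # d/dx of sum(2x)
grad_neg = jax.grad(loss)(xs[1])  # d/dx of sum(-x)

# jit on top of vmap works as well
jitted = jax.jit(jax.vmap(double_or_negate))(xs)
```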
Worked mini-example
```python
import jax
import jax.numpy as jnp
from jax import lax, jit

# ✅ OK: branching on a static shape (resolved at trace time)
def f(x):
    if x.shape[-1] == 3:  # shape is a Python int; a Python if is fine
        return x.sum(axis=-1)
    else:
        return x.mean(axis=-1)

# ❌ BREAKS under jit: branching on a runtime value
def g(x):
    if x[0] > 0:  # x[0] is a Tracer; raises TracerBoolConversionError
        return x * 2
    else:
        return -x

# ✅ Fix with lax.cond: both branches compiled, runtime picks one
def g_fixed(x):
    return lax.cond(
        x[0] > 0,        # predicate: a traced scalar bool
        lambda: x * 2,   # true_fn: called when predicate is True
        lambda: -x,      # false_fn: called when predicate is False
    )

print(jit(f)(jnp.ones(3)))        # works: shape branch resolved at trace time
try:
    jit(g)(jnp.ones(3))           # raises TracerBoolConversionError
except jax.errors.TracerBoolConversionError as e:
    print(type(e).__name__)
print(jit(g_fixed)(jnp.ones(3)))  # works: lax.cond handles traced values
```
f is fine because x.shape[-1] is 3, a plain Python int, at trace time.
g fails because x[0] is a Tracer, and Tracer.__bool__ raises.
g_fixed works because lax.cond is a JAX primitive: it records both branches
during tracing and emits an XLA conditional that evaluates at run time.
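One way to confirm this is to print the jaxpr, JAX's intermediate representation of a traced function. With jax.make_jaxpr, the shape-branching f traces to straight-line code containing only the selected reduction, while g_fixed contains a cond equation holding both branches. The definitions are repeated from the worked example so the block runs on its own.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    if x.shape[-1] == 3:
        return x.sum(axis=-1)
    else:
        return x.mean(axis=-1)

def g_fixed(x):
    return lax.cond(x[0] > 0, lambda: x * 2, lambda: -x)

jaxpr_f = jax.make_jaxpr(f)(jnp.ones(3))
jaxpr_g = jax.make_jaxpr(g_fixed)(jnp.ones(3))

print(jaxpr_f)  # a single reduce_sum: the Python if was resolved while tracing
print(jaxpr_g)  # a cond equation whose branches contain both lambdas
```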
Common pitfalls
- if on a Tracer: raises TracerBoolConversionError ("Attempted boolean conversion of traced array ..."). That error name is the key phrase to search for when debugging.
- x.shape returns a tuple of Python ints: it is static, computed at trace time, so if x.shape[0] >= 3: is always safe and never involves a tracer.
- jnp.sum(x) returns a Tracer (under jit): comparing it with > produces another Tracer, and an if on that Tracer fails just like g above.
- lax.cond(pred, true_fn, false_fn): both branches must return arrays with identical shape and dtype, because the compiled conditional has a single output type. In the examples here the branches take no positional arguments; they close over x (and, in the problem below, threshold).
- Don’t make the two branches return arrays of different lengths: lax.cond does not broadcast branch outputs, so mismatched shapes are rejected with an error at trace time.
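The matching-output rule is checked when lax.cond traces its branches, so a shape mismatch fails before anything runs. In this sketch the function names are hypothetical; the fixed version slices both branches to the same length.

```python
import jax
import jax.numpy as jnp
from jax import lax

def mismatched(x):
    # true branch returns shape (3,), false branch returns shape (2,)
    return lax.cond(jnp.sum(x) > 0, lambda: x[:3], lambda: x[:2])

def matched(x):
    # both branches return shape (3,) with the same dtype
    return lax.cond(jnp.sum(x) > 0, lambda: x[:3], lambda: -x[:3])

try:
    jax.jit(mismatched)(jnp.ones(5))
    mismatch_error = None
except TypeError as e:  # lax.cond reports mismatched branch output types
    mismatch_error = type(e).__name__

result = jax.jit(matched)(jnp.ones(5))  # sum is positive, so x[:3]
```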
Problem
Implement shape_or_value_branch(x, threshold) that takes a 1-D JAX array x
and a float scalar threshold, and returns a 1-D array of shape (3,):
- If x.shape[0] >= 3 (static check; use a plain Python if):
  - If jnp.sum(x) > threshold (dynamic check; use lax.cond): return x[:3] * 2.0
  - Otherwise: return x[:3] + 1.0
- If x.shape[0] < 3: return jnp.zeros(3) + threshold
The function accepts arrays of any 1-D length. The visible test cases below
show the basic behaviors; one additional hidden case verifies the boundary
where x.shape[0] == 3.
For intuition, two illustrative examples:
- x=[1, 2, 3, 4, 5], threshold=5.0 → [2.0, 4.0, 6.0] (sum=15 > 5, so x[:3] * 2)
- x=[0.5, 0.5], threshold=1.0 → [1.0, 1.0, 1.0] (shape < 3, so zeros + threshold)