Tracer Leak Detection
Why this matters
A “tracer leak” is what happens when JAX code tries to use the value of a
traced array in a Python control-flow context: a Python `if`, a `bool()` call,
indexing or slicing a Python list with traced indices, and so on. During
`jax.jit` tracing, array values are replaced by abstract “tracer” objects that
carry only shape and dtype information. If your code branches on a tracer’s
value, JAX cannot compile a single XLA program that is correct for all possible
values, so it raises a `ConcretizationTypeError` (older JAX) or its more
specific subclass `TracerBoolConversionError` (recent JAX).
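The failure described above can be reproduced with a minimal sketch like the following. The function name `is_first_negative` is illustrative only, and the snippet assumes a reasonably recent JAX release in which `jax.errors.ConcretizationTypeError` is the parent class of the bool-conversion error.

```python
import jax
import jax.numpy as jnp


def is_first_negative(x):
    # Under jit, x is a tracer: shape and dtype are known, values are not.
    if x[0] < 0:  # Python `if` calls bool() on a traced value
        return jnp.zeros_like(x)
    return x


x = jnp.array([-1.0, 2.0])

# Eagerly (no jit) the values are concrete, so the branch works.
eager_result = is_first_negative(x)

# Under jit the same code fails while it is being traced.
try:
    jax.jit(is_first_negative)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError as err:
    # Catching the parent class covers both the older error and the
    # more specific TracerBoolConversionError subclass.
    jit_failed = True
    error_name = type(err).__name__

print(eager_result, jit_failed)
```

Catching the parent class keeps the sketch independent of which of the two error names a given JAX version reports.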
The fix is to replace Python-level conditionals with array-aware operations
that JAX records as part of the trace: `jnp.where`, `jnp.maximum`,
`jax.lax.cond`, etc. `jnp.where` and `jnp.maximum` work element-wise, while
`jax.lax.cond` picks one of two branches based on a scalar predicate. In each
case the output shape and dtype are fixed regardless of concrete values, so JAX
can compile them without knowing the numeric contents at trace time.
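As a rough illustration of the difference between an element-wise select and a whole-array branch, here is a small sketch. The function names `relu_where` and `zero_if_any_negative` are made up for this example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_where(x):
    # Element-wise select: each output element is chosen independently.
    return jnp.where(x < 0, 0.0, x)


@jax.jit
def zero_if_any_negative(x):
    # Whole-array branch on a scalar traced predicate.
    # Both branches must return outputs with the same shape and dtype.
    return jax.lax.cond(
        jnp.any(x < 0),
        lambda v: jnp.zeros_like(v),  # taken when the predicate is True
        lambda v: v,                  # taken otherwise
        x,
    )


mixed = jnp.array([-1.0, 2.0])
positive = jnp.array([1.0, 2.0])

print(relu_where(mixed))
print(zero_if_any_negative(mixed))
print(zero_if_any_negative(positive))
```

`jnp.where` is usually the simpler choice when the decision is per element; `jax.lax.cond` fits when a single scalar condition decides which computation applies to the whole input.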
Worked mini-example
    import jax
    import jax.numpy as jnp

    # ❌ Tracer leak: Python `if` on a traced value
    def clip_v1(x):
        result = x
        if x[0] < 0:  # error: bool conversion of tracer
            result = result.at[0].set(0)
        return result

    # ❌ Index-based loop is also a leak under jit
    def clip_v2(x):
        for i in range(len(x)):
            if x[i] < 0:
                x = x.at[i].set(0)
        return x

    # ✅ Vectorized & jit-safe
    def clip_v3(x):
        return jnp.where(x < 0, 0, x)
        # Equivalent: jnp.maximum(x, 0)
`clip_v1` and `clip_v2` both raise `TracerBoolConversionError` under
`jax.jit` because `x[0] < 0` and `x[i] < 0` are traced booleans whose values
are unknown at trace time. `clip_v3` works because `jnp.where` is a traceable
JAX operation that lowers to a vectorized select without needing concrete
values.
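One way to confirm this behaviour is to run the vectorized version under `jax.jit` and look at the traced program. The sketch below redefines `clip_v3` so that it is self-contained; `jax.make_jaxpr` is used only for inspection.

```python
import jax
import jax.numpy as jnp


def clip_v3(x):
    return jnp.where(x < 0, 0, x)


x = jnp.array([3.0, -7.0, 0.5])

# Compiles and runs without any concretization error.
out = jax.jit(clip_v3)(x)

# jnp.maximum gives the same result for this input.
same_as_maximum = bool(jnp.array_equal(out, jnp.maximum(x, 0)))

# The jaxpr is the program JAX recorded during tracing; it contains the
# comparison and select steps but no concrete array values.
jaxpr_text = str(jax.make_jaxpr(clip_v3)(x))

print(out, same_as_maximum)
print(jaxpr_text)
```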
Common pitfalls
- `if x[0] < 0:` is a Python `if` on a traced value and raises `TracerBoolConversionError` / `ConcretizationTypeError`. Use `jnp.where` or `jax.lax.cond` instead.
- `bool(traced_array)` raises the same error. Boolean conversion of a tracer is impossible because the value is abstract.
- `for i in range(len(x))`: `len(x)` is fine (it returns the static shape), but per-element conditionals inside the loop invite the same tracer leak.
- `x.tolist()` forces materialization of concrete values, so it raises a concretization error under `jit`.
- `assert x[0] > 0` converts the condition to a Python bool, so it hits the same tracer leak.
- `if x.min() < 0:` still leaks: `x.min()` returns a 0-D traced array, and branching on it requires a concrete value.
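Two of these pitfalls have fairly direct jit-safe rewrites, sketched below under the assumption of a 1-D float input. The function names `clip_loop` and `shift_if_min_negative` are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_loop(x):
    # len(x) is static shape information, so it is safe as a loop bound.
    def body(i, acc):
        # Per-element decision made with jnp.where instead of a Python `if`.
        return acc.at[i].set(jnp.where(acc[i] < 0, 0.0, acc[i]))

    return jax.lax.fori_loop(0, len(x), body, x)


@jax.jit
def shift_if_min_negative(x):
    # x.min() is a 0-D traced array; use it as a select condition
    # rather than branching on it in Python.
    m = x.min()
    return jnp.where(m < 0, x - m, x)


print(clip_loop(jnp.array([3.0, -7.0, 0.5])))
print(shift_if_min_negative(jnp.array([-1.0, 2.0])))
```

The loop version is shown only to mirror the index-based pitfall; for this particular task a single vectorized `jnp.where` is shorter and usually faster.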
Problem
Implement safe_clip_positive(x) that replaces every negative element of a
1-D JAX array x with 0, leaving non-negative elements unchanged.
Two illustrative examples (not from the test set):
- `x = jnp.array([3.0, -7.0, 0.5])` → `[3.0, 0.0, 0.5]`
- `x = jnp.array([-0.1, 100.0, -0.1])` → `[0.0, 100.0, 0.0]`
The implementation must be jit-safe: no Python `if` on traced values.
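A candidate can be checked locally with a small harness along these lines. This is only a sketch: `check_jit_safe_clip`, `candidate`, and `leaky` are hypothetical names, the cases are the two illustrative examples above rather than the real test set, and `candidate` simply reuses the `jnp.maximum` form mentioned earlier as a stand-in.

```python
import jax
import jax.numpy as jnp


def check_jit_safe_clip(fn):
    """Return True if fn matches the illustrative cases when run under jax.jit."""
    cases = [
        (jnp.array([3.0, -7.0, 0.5]), jnp.array([3.0, 0.0, 0.5])),
        (jnp.array([-0.1, 100.0, -0.1]), jnp.array([0.0, 100.0, 0.0])),
    ]
    jitted = jax.jit(fn)
    try:
        return all(bool(jnp.array_equal(jitted(x), want)) for x, want in cases)
    except jax.errors.ConcretizationTypeError:
        # A Python-level branch on a traced value surfaces here.
        return False


def candidate(x):
    # Stand-in implementation used only to exercise the harness.
    return jnp.maximum(x, 0)


def leaky(x):
    # Deliberately not jit-safe: Python `if` on a traced value.
    if x[0] < 0:
        x = x.at[0].set(0)
    return x


print(check_jit_safe_clip(candidate), check_jit_safe_clip(leaky))
```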