disable_jit Context Manager
Why this matters
When you’re debugging a jitted function, JIT tracing can make things harder. Python print statements run only while the function is being traced, and they show abstract tracers instead of values. A Python if on a traced value raises an error instead of choosing a branch. Exceptions come with tracebacks cluttered with JAX-internal frames.
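Here is a minimal sketch of the first two symptoms. It assumes a recent JAX release, and the functions `show_values` and `branchy` are hypothetical. The `print` runs once at trace time and shows a tracer rather than numbers. The Python `if` fails because the comparison produces a traced value with no concrete truth value.

```python
import jax
import jax.numpy as jnp

@jax.jit
def show_values(x):
    # Runs only while tracing, and prints an abstract tracer rather than numbers.
    print("x inside jit:", x)
    return jnp.sum(x)

@jax.jit
def branchy(x):
    # A Python `if` needs a concrete bool, but x.sum() > 0 is a tracer here.
    if x.sum() > 0:
        return x
    return -x

x = jnp.array([1.0, -2.0, 3.0])
total = show_values(x)  # the print shows a tracer, not [1. -2. 3.]

branch_error = None
try:
    branchy(x)
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX raises TracerBoolConversionError, a subclass of this class.
    branch_error = err
    print("branchy failed with", type(err).__name__)
```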
jax.disable_jit() is a context manager that turns every jax.jit decorator and call into a no-op for the duration of the block. Inside the context, JAX operations execute eagerly and one at a time, much like NumPy. As a result, print shows concrete values, and assert and plain Python if work as you’d expect:
```python
with jax.disable_jit():
    result = my_jitted_function(x)  # runs eagerly; prints, asserts, and Python ifs all fire
```
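The next sketch is a slightly fuller, runnable version of the same idea. It uses a hypothetical `checked_branch` function that would fail under plain jit because of its value-dependent `assert` and `if`. Inside the context, it runs as ordinary Python.

```python
import jax
import jax.numpy as jnp

@jax.jit
def checked_branch(x):
    print("x:", x)  # eager mode: shows the concrete array values
    # A value-dependent check like this needs concrete data to evaluate.
    assert jnp.all(jnp.isfinite(x)), "x has NaN or inf"
    if x.sum() > 0:  # plain Python branch on a concrete value
        return x
    return -x

x = jnp.array([1.0, -2.0, 3.0])

with jax.disable_jit():
    result = checked_branch(x)  # x.sum() is positive, so x is returned unchanged
```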
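This is purely a debugging tool. Never ship code wrapped in disable_jit. It turns off all compilation, so every operation is dispatched individually with no fusion. Programs can run far slower than their jitted versions, sometimes slower than plain NumPy.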
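The sketch below is a rough way to see the cost on your own hardware. It assumes a hypothetical `many_small_ops` function built from many small elementwise operations that jit can fuse. Absolute numbers depend on the device and JAX version, so none are claimed here. The `block_until_ready()` call is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

def many_small_ops(x):
    # A chain of elementwise ops that jit can fuse into a few kernels.
    for _ in range(50):
        x = jnp.sin(x) * 0.5 + 0.1
    return jnp.sum(x)

jitted = jax.jit(many_small_ops)
x = jnp.ones((1000,))
jitted(x).block_until_ready()  # warm-up: trace and compile once

def time_call(fn, x, repeats=10):
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(x)
    out.block_until_ready()  # wait for the asynchronously dispatched work
    return (time.perf_counter() - start) / repeats

t_jit = time_call(jitted, x)
with jax.disable_jit():
    t_eager = time_call(jitted, x)  # same wrapper, now dispatched op by op

print(f"jit: {t_jit * 1e6:.1f} us/call, disable_jit: {t_eager * 1e6:.1f} us/call")
```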
Worked mini-example
```python
import jax
import jax.numpy as jnp

@jax.jit
def noisy_sum(x):
    print("tracing!")  # fires during trace, silent on cache hit
    return jnp.sum(x)

x = jnp.array([1.0, 2.0, 3.0])
noisy_sum(x)  # prints "tracing!" once (during tracing/compilation)
noisy_sum(x)  # silent (cache hit)

with jax.disable_jit():
    noisy_sum(x)  # prints "tracing!" on every call (eager mode)
```
Common pitfalls
- Don’t leave disable_jit in production: every call runs op by op without compilation or fusion, so throughput drops sharply.
- Nested jit calls are also disabled: every jax.jit within the context, including jitted functions inside libraries you call, becomes a no-op.
- Randomness still requires explicit keys: disable_jit does not enable NumPy-style stateful random number generation, so you still need a jax.random.PRNGKey.
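The following sketch illustrates the last two pitfalls, using hypothetical `inner` and `outer` functions. A Python counter in `inner` only increments when its body actually runs in Python. This shows that the nested jit is bypassed under disable_jit. Random sampling, by contrast, still goes through an explicit key.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def inner(x):
    global trace_count
    trace_count += 1  # Python side effect: under jit, runs only when inner is traced
    return x * 2.0

@jax.jit
def outer(x):
    return inner(x) + 1.0  # nested jitted call

x = jnp.arange(3.0)
outer(x)
outer(x)  # cache hit: nothing is retraced
count_jitted = trace_count

with jax.disable_jit():
    outer(x)
    outer(x)  # both levels run eagerly, so inner's body executes on every call
count_eager = trace_count - count_jitted

# Randomness is unchanged: an explicit key is still required.
key = jax.random.PRNGKey(0)
with jax.disable_jit():
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, (3,))
```

Under plain jit, the counter moves only when `outer` is first traced. Inside the context, it increases on every call.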
Problem
Implement with_jit_disabled_sum(x) that computes jnp.sum(x ** 2)
inside a jax.disable_jit() context and returns the result.
- x: 1-D JAX array.
- Returns: a scalar, sum(x**2).
Example (not from the test set):
- with_jit_disabled_sum(jnp.array([3.0, 4.0])) returns 25.0.
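One possible solution is sketched below; it is not the site's reference answer. The sum is computed inside the context, and the result is returned from within the `with` block. This is safe because the context manager restores the previous setting on exit.

```python
import jax
import jax.numpy as jnp

def with_jit_disabled_sum(x):
    # Everything inside the block runs eagerly, op by op.
    with jax.disable_jit():
        return jnp.sum(x ** 2)

print(with_jit_disabled_sum(jnp.array([3.0, 4.0])))  # 25.0
```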