medium primitives

disable_jit Context Manager

Why this matters

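When you’re debugging a jitted function, JIT tracing can get in the way. Python print statements run only while the function is being traced, so they fire once and show abstract tracers rather than values. Python if statements and asserts that depend on traced values raise a ConcretizationTypeError instead of running. Error messages describe abstract tracers instead of the concrete values that caused the problem.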
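The sketch below illustrates the Python if problem. relu_scalar is a hypothetical toy function, and the error is caught as jax.errors.ConcretizationTypeError (recent JAX versions raise the more specific TracerBoolConversionError, which subclasses it). The same call succeeds eagerly under jax.disable_jit().

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_scalar(x):
    # Under jit, `x` is an abstract tracer, so Python cannot decide
    # which branch of this `if` to take and JAX raises an error.
    if x > 0:
        return x
    return 0.0 * x

x = jnp.float32(2.0)

try:
    relu_scalar(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True  # tracing failed on the Python `if`

with jax.disable_jit():
    # Eagerly, `x > 0` is a concrete boolean array, so the `if` just works.
    eager_result = relu_scalar(x)
```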

jax.disable_jit() is a context manager that turns every jax.jit-decorated function and jax.jit call into a plain Python call for the duration of the block. Inside the context, JAX operations execute eagerly, one at a time, on concrete arrays, much like NumPy, so print, assert, and plain Python if all work as you’d expect:

with jax.disable_jit():
    result = my_jitted_function(x)
    # ^ runs eagerly; prints, asserts, and Python ifs all fire

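This is purely a debugging tool. Never ship code wrapped in disable_jit: it turns off all compilation, so every operation is dispatched separately with no fusion, and a JAX program can run far slower than its jitted version, often slower than equivalent NumPy code.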
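As a minimal sketch of what eager execution buys you, here is a hypothetical normalize function with a debug print and an assert. Under jit the print would show a tracer and the assert would fail to trace. Inside jax.disable_jit() both see the real total, so the assert actually checks the data.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # Debug-only lines: they are meaningful only when `total` is concrete.
    print("total =", total)
    assert total != 0, "cannot normalize a zero-sum vector"
    return x / total

with jax.disable_jit():
    out = normalize(jnp.array([1.0, 3.0]))  # prints the concrete total, not a tracer

with jax.disable_jit():
    try:
        normalize(jnp.array([1.0, -1.0]))   # sum is 0.0, so the assert fires
        caught = False
    except AssertionError:
        caught = True
```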

Worked mini-example

import jax
import jax.numpy as jnp

@jax.jit
def noisy_sum(x):
    print("tracing!")      # fires during trace, silent on cache hit
    return jnp.sum(x)

x = jnp.array([1.0, 2.0, 3.0])

noisy_sum(x)              # prints "tracing!" once (during tracing)
noisy_sum(x)              # silent (cache hit)

with jax.disable_jit():
    noisy_sum(x)          # prints "tracing!" every call (eager mode)
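The same behavior can be checked without reading printed output. This sketch swaps the print for a hypothetical global counter (trace_count) that increments each time the Python body actually runs: once under jit, then on every call inside disable_jit.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def counted_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when the body executes in Python
    return jnp.sum(x)

x = jnp.array([1.0, 2.0, 3.0])

counted_sum(x)
counted_sum(x)
after_jit = trace_count    # 1: traced once, the second call reused the compiled version

with jax.disable_jit():
    counted_sum(x)
    counted_sum(x)
after_eager = trace_count  # 3: each eager call ran the Python body again
```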

Common pitfalls

  • Don’t leave disable_jit in production: every call runs op by op, with no compilation or fusion, so throughput can drop sharply.
  • Nested jit calls are also disabled: all jax.jit within the context (including library calls) become no-ops.
  • Randomness still requires explicit keys: disable_jit does not enable NumPy-style stateful random — you still need jax.random.PRNGKey.
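The last two pitfalls can be seen together in a short sketch. The functions inner and outer and the counter inner_calls are hypothetical. The nested jit on inner is disabled too, so its body runs on every call, and jax.random still needs an explicit key: reusing the same key gives the same numbers.

```python
import jax
import jax.numpy as jnp

inner_calls = 0

@jax.jit
def inner(x):
    global inner_calls
    inner_calls += 1  # counts how often the Python body of `inner` runs
    return x * 2.0

@jax.jit
def outer(x):
    return inner(x) + 1.0  # nested jitted call

x = jnp.arange(3.0)

with jax.disable_jit():
    outer(x)
    outer(x)  # inner's body ran on both calls: its jit is disabled too

key = jax.random.PRNGKey(0)
with jax.disable_jit():
    a = jax.random.normal(key, (3,))
    b = jax.random.normal(key, (3,))  # same key, so identical samples
```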

Problem

Implement with_jit_disabled_sum(x) that computes jnp.sum(x ** 2) inside a jax.disable_jit() context and returns the result.

  • x: 1-D JAX array.

Returns: scalar — sum(x**2).

Example (not from the test set):

  • with_jit_disabled_sum(jnp.array([3.0, 4.0])) returns 25.0.
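One possible sketch of a solution, assuming only what the problem statement specifies: wrap the computation in the context manager and return from inside the block. The result is a 0-d JAX array.

```python
import jax
import jax.numpy as jnp

def with_jit_disabled_sum(x):
    # Everything inside the block runs eagerly, one operation at a time.
    with jax.disable_jit():
        return jnp.sum(x ** 2)

result = with_jit_disabled_sum(jnp.array([3.0, 4.0]))  # 0-d array holding 25.0
```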

Hints

jax disable-jit debugging
